You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

72 lines
2.8 KiB

4 years ago
4 years ago
  1. # maunium-stickerpicker - A fast and simple Matrix sticker picker widget.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, ClassVar
from datetime import datetime, timedelta
import hashlib
import hmac

from attr import dataclass
import asyncpg

from mautrix.types import UserID

from .base import Base
  23. @dataclass(kw_only=True)
  24. class AccessToken(Base):
  25. token_expiry: ClassVar[timedelta] = timedelta(days=1)
  26. user_id: UserID
  27. token_id: int
  28. token_hash: bytes
  29. last_seen_ip: str
  30. last_seen_date: datetime
  31. @classmethod
  32. async def get(cls, token_id: int) -> Optional['AccessToken']:
  33. q = ("SELECT user_id, token_hash, last_seen_ip, last_seen_date "
  34. "FROM access_token WHERE token_id=$1")
  35. row: asyncpg.Record = await cls.db.fetchrow(q, token_id)
  36. if row is None:
  37. return None
  38. return cls(**row, token_id=token_id)
  39. async def update_ip(self, ip: str) -> None:
  40. if self.last_seen_ip == ip and (self.last_seen_date.replace(second=0, microsecond=0)
  41. == datetime.now().replace(second=0, microsecond=0)):
  42. # Same IP and last seen on this minute, skip update
  43. return
  44. q = ("UPDATE access_token SET last_seen_ip=$2, last_seen_date=current_timestamp "
  45. "WHERE token_id=$1 RETURNING last_seen_date")
  46. self.last_seen_date = await self.db.fetchval(q, self.token_id, ip)
  47. self.last_seen_ip = ip
  48. def check(self, token: str) -> bool:
  49. return self.token_hash == hashlib.sha256(token.encode("utf-8")).digest()
  50. @property
  51. def expired(self) -> bool:
  52. return self.last_seen_date + self.token_expiry < datetime.now()
  53. async def delete(self) -> None:
  54. await self.db.execute("DELETE FROM access_token WHERE token_id=$1", self.token_id)
  55. @classmethod
  56. async def insert(cls, user_id: UserID, token: str, ip: str) -> int:
  57. q = ("INSERT INTO access_token (user_id, token_hash, last_seen_ip, last_seen_date) "
  58. "VALUES ($1, $2, $3, current_timestamp) RETURNING token_id")
  59. hashed = hashlib.sha256(token.encode("utf-8")).digest()
  60. return await cls.db.fetchval(q, user_id, hashed, ip)