You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

95 lines
3.8 KiB

  1. # maunium-stickerpicker - A fast and simple Matrix sticker picker widget.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, List, ClassVar
  17. import random
  18. import string
  19. from attr import dataclass
  20. import asyncpg
  21. from mautrix.types import UserID
  22. from .base import Base
  23. from .pack import Pack
  24. from .access_token import AccessToken
  25. @dataclass(kw_only=True)
  26. class User(Base):
  27. token_charset: ClassVar[str] = string.ascii_letters + string.digits
  28. id: UserID
  29. widget_secret: str
  30. homeserver_url: str
  31. @classmethod
  32. def _random_token(cls) -> str:
  33. return "".join(random.choices(cls.token_charset, k=64))
  34. @classmethod
  35. def new(cls, id: UserID, homeserver_url: str) -> 'User':
  36. return User(id=id, widget_secret=cls._random_token(), homeserver_url=homeserver_url)
  37. @classmethod
  38. async def get(cls, id: UserID) -> Optional['User']:
  39. q = 'SELECT id, widget_secret, homeserver_url FROM "user" WHERE id=$1'
  40. row: asyncpg.Record = await cls.db.fetchrow(q, id)
  41. if row is None:
  42. return None
  43. return cls(**row)
  44. async def regenerate_widget_secret(self) -> None:
  45. self.widget_secret = self._random_token()
  46. await self.db.execute('UPDATE "user" SET widget_secret=$1 WHERE id=$2',
  47. self.widget_secret, self.id)
  48. async def set_homeserver_url(self, url: str) -> None:
  49. self.homeserver_url = url
  50. await self.db.execute('UPDATE "user" SET homeserver_url=$1 WHERE id=$2', url, self.id)
  51. async def new_access_token(self, ip: str) -> str:
  52. token = self._random_token()
  53. token_id = await AccessToken.insert(self.id, token, ip)
  54. return f"{token_id}:{token}"
  55. async def delete(self) -> None:
  56. await self.db.execute('DELETE FROM "user" WHERE id=$1', self.id)
  57. async def insert(self) -> None:
  58. q = 'INSERT INTO "user" (id, widget_secret, homeserver_url) VALUES ($1, $2, $3)'
  59. await self.db.execute(q, self.id, self.widget_secret, self.homeserver_url)
  60. async def get_packs(self) -> List[Pack]:
  61. res = await self.db.fetch("SELECT id, owner, title, meta FROM user_pack "
  62. "LEFT JOIN pack ON pack.id=user_pack.pack_id "
  63. 'WHERE user_id=$1 ORDER BY "order"', self.id)
  64. return [Pack(**row) for row in res]
  65. async def get_pack(self, pack_id: str) -> Optional[Pack]:
  66. row = await self.db.fetchrow("SELECT id, owner, title, meta FROM user_pack "
  67. "LEFT JOIN pack ON pack.id=user_pack.pack_id "
  68. "WHERE user_id=$1 AND pack_id=$2", self.id, pack_id)
  69. if row is None:
  70. return None
  71. return Pack(**row)
  72. async def set_packs(self, packs: List[Pack]) -> None:
  73. data = ((self.id, pack.id, order)
  74. for order, pack in enumerate(packs))
  75. columns = ["user_id", "pack_id", "order"]
  76. async with self.db.acquire() as conn, conn.transaction():
  77. await conn.execute("DELETE FROM user_pack WHERE user_id=$1", self.id)
  78. await conn.copy_records_to_table("user_pack", records=data, columns=columns)