mirror of https://gitlab.com/tildes/tildes.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
188 lines
6.2 KiB
188 lines
6.2 KiB
# Copyright (c) 2018 Tildes contributors <code@tildes.net>
|
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
"""Constants/classes/functions related to the database."""
|
|
|
|
import enum
|
|
from typing import Any, Callable, List, Optional
|
|
|
|
from dateutil.rrule import rrule, rrulestr
|
|
from pyramid.paster import bootstrap
|
|
from sqlalchemy import cast, func
|
|
from sqlalchemy.dialects.postgresql import ARRAY
|
|
from sqlalchemy.engine.interfaces import Dialect
|
|
from sqlalchemy.orm.session import Session
|
|
from sqlalchemy.types import Text, TypeDecorator, UserDefinedType
|
|
from sqlalchemy_utils import Ltree, LtreeType
|
|
from sqlalchemy_utils.types.ltree import LQUERY
|
|
|
|
from tildes.lib.datetime import rrule_to_str
|
|
|
|
|
|
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
|
|
# SQLSTATE code for not_null_violation (class 23, integrity constraint violation).
# NOTE(review): PostgreSQL SQLSTATEs are five-character strings and psycopg2 exposes
# them as a str (e.g. exc.pgcode == "23502"), so callers comparing against this int
# presumably need to convert first — confirm how it is used.
NOT_NULL_ERROR_CODE = 23502
|
|
|
|
|
|
def get_session_from_config(config_path: str) -> Session:
    """Bootstrap the app from the config file at `config_path` and open a DB session.

    The session comes from the "db_session_factory" stored in the bootstrapped
    application's registry.
    """
    factory = bootstrap(config_path)["registry"]["db_session_factory"]
    return factory()
|
|
|
|
|
|
class LockSpaces(enum.Enum):
    """Allowed "lock space" names for PostgreSQL advisory locks.

    A member's value is passed as the first key of a two-key advisory lock, so two
    callers only contend when they use the same space and the same lock value.
    """

    # Spelled out (enum.auto() would also give 1 here) so the numeric lock key stays
    # fixed even if members are later inserted or reordered.
    GENERATE_INVITE_CODE = 1
|
|
|
|
|
|
def obtain_transaction_lock(
    session: Session, lock_space: Optional[str], lock_value: int
) -> None:
    """Take a PostgreSQL transaction-level advisory lock (pg_advisory_xact_lock).

    `lock_space` is either falsy (a single-key lock on `lock_value` is taken) or the
    case-insensitive name of a LockSpaces member (a two-key lock on the member's
    value and `lock_value` is taken). Callers only contend with each other when both
    the lock space and the lock value match.

    Raises ValueError if `lock_space` names no LockSpaces member.
    """
    if not lock_space:
        session.query(func.pg_advisory_xact_lock(lock_value)).one()
        return

    try:
        space_key = LockSpaces[lock_space.upper()].value
    except KeyError as exc:
        raise ValueError("Invalid lock space: %s" % lock_space) from exc

    session.query(func.pg_advisory_xact_lock(space_key, lock_value)).one()
|
|
|
|
|
|
class CIText(UserDefinedType):
    """Column type for PostgreSQL's case-insensitive ``citext`` text type.

    Values are passed through unchanged in both directions; only the column's DDL
    type name differs from plain text. Docs:
    https://www.postgresql.org/docs/current/static/citext.html
    """

    python_type = str

    def get_col_spec(self, **kw: Any) -> str:
        """Return the SQL type name emitted for columns of this type."""
        # pylint: disable=no-self-use,unused-argument
        return "CITEXT"

    def bind_processor(self, dialect: Dialect) -> Callable:
        """Return an identity function: bound values are sent as-is."""
        return lambda value: value

    def result_processor(self, dialect: Dialect, coltype: Any) -> Callable:
        """Return an identity function: fetched values are returned as-is."""
        return lambda value: value
|
|
|
|
|
|
class ArrayOfLtree(ARRAY):
    """Workaround type that makes ltree[] columns usable from SQLAlchemy.

    Modelled closely on the ArrayOfEnum recipe in the SQLAlchemy docs:
    http://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#using-enum-with-array
    """

    def __init__(self) -> None:
        """Set the type up as an ARRAY whose items are LtreeType."""
        super().__init__(LtreeType)

    def bind_expression(self, bindvalue: Any) -> Any:
        """Wrap the bound parameter in an explicit CAST to this array type."""
        return cast(bindvalue, self)

    def result_processor(self, dialect: Any, coltype: Any) -> Callable:
        """Build the function that converts a raw result value into a list."""
        array_processor = super().result_processor(dialect, coltype)

        def split_array_literal(raw: str) -> List[str]:
            # Expects a PostgreSQL array literal such as "{a.b,c}".
            if not raw.startswith("{") or not raw.endswith("}"):
                raise ValueError("%s is not an array value" % raw)

            elements = raw[1:-1]

            # "{}" (empty array) leaves nothing between the braces
            return elements.split(",") if elements else []

        def process(value: Optional[str]) -> Optional[List[str]]:
            if value is None:
                return None
            return array_processor(split_array_literal(value))

        return process

    class comparator_factory(ARRAY.comparator_factory):  # noqa
        """Custom comparisons matching the ones LtreeType provides for single values.

        LtreeType supports ancestor_of, descendant_of, and lquery; these duplicate
        them so the same operations can be used on ArrayOfLtree columns.
        """

        def ancestor_of(self, other):  # type: ignore
            """Return an expression: does any array element contain `other` as a descendant?"""
            return self.op("@>")(other)

        def descendant_of(self, other):  # type: ignore
            """Return an expression: is any array element a descendant of `other`?"""
            return self.op("<@")(other)

        def lquery(self, other):  # type: ignore
            """Return an expression: does any element match the lquery (or any in a list)?"""
            if isinstance(other, list):
                operator, query_type = "?", ARRAY(LQUERY)
            else:
                operator, query_type = "~", LQUERY

            return self.op(operator)(cast(other, query_type))
|
|
|
|
|
|
class RecurrenceRule(TypeDecorator):
    """Stores a dateutil rrule in the database as text.

    Values are serialized with ``rrule_to_str`` when bound and parsed back with
    ``rrulestr`` when loaded. None (SQL NULL) passes through unchanged both ways.
    """

    # pylint: disable=abstract-method

    impl = Text

    # The type takes no constructor arguments, so it is safe to include in SQLAlchemy's
    # statement cache keys (1.4+ warns and skips caching when this is unset).
    cache_ok = True

    def process_bind_param(
        self, value: Optional[rrule], dialect: Dialect
    ) -> Optional[str]:
        """Convert the rrule to a string for storage; None stays None."""
        if value is None:
            return None

        return rrule_to_str(value)

    def process_result_value(
        self, value: Optional[str], dialect: Dialect
    ) -> Optional[rrule]:
        """Parse the stored string back into an rrule; None stays None."""
        if value is None:
            return None

        return rrulestr(value)
|
|
|
|
|
|
class TagList(TypeDecorator):
    """Stores a list of tags in the database as an array of ltree.

    Spaces in tags are stored as underscores, since ltree labels can't contain spaces,
    and every underscore is turned back into a space on load. This means a tag that
    originally contained a literal underscore comes back with a space instead. None
    (SQL NULL) passes through unchanged in both directions, matching RecurrenceRule.
    """

    # pylint: disable=abstract-method

    impl = ArrayOfLtree

    # The type takes no constructor arguments, so it is safe to include in SQLAlchemy's
    # statement cache keys (1.4+ warns and skips caching when this is unset).
    cache_ok = True

    def process_bind_param(
        self, value: Optional[List[str]], dialect: Dialect
    ) -> Optional[List[Ltree]]:
        """Convert a list of tag strings to ltree[] for storage; None stays None."""
        # Without this, a NULL/None value would raise TypeError when iterated.
        if value is None:
            return None

        return [Ltree(tag.replace(" ", "_")) for tag in value]

    def process_result_value(
        self, value: Optional[List[Ltree]], dialect: Dialect
    ) -> Optional[List[str]]:
        """Convert stored ltree[] back to tag strings; None stays None."""
        # A NULL column arrives as None and previously raised TypeError here.
        if value is None:
            return None

        return [str(tag).replace("_", " ") for tag in value]
|