diff --git a/tildes/tildes/lib/database.py b/tildes/tildes/lib/database.py index 05d255a..c6836d6 100644 --- a/tildes/tildes/lib/database.py +++ b/tildes/tildes/lib/database.py @@ -14,6 +14,7 @@ from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.orm.session import Session from sqlalchemy.types import UserDefinedType from sqlalchemy_utils import LtreeType +from sqlalchemy_utils.types.ltree import LQUERY # https://www.postgresql.org/docs/current/static/errcodes-appendix.html @@ -128,8 +129,8 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors class comparator_factory(ARRAY.comparator_factory): """Add custom comparison functions. - The ancestor_of and descendant_of functions are supported by LtreeType, so this - duplicates them here so they can be used on ArrayOfLtree too. + The ancestor_of, descendant_of, and lquery functions are supported by LtreeType, + so this duplicates them here so they can be used on ArrayOfLtree too. """ def ancestor_of(self, other): # type: ignore @@ -139,3 +140,10 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors def descendant_of(self, other): # type: ignore """Return whether the array contains any descendant of `other`.""" return self.op("<@")(other) + + def lquery(self, other): # type: ignore + """Return whether the array matches the lquery/lqueries in `other`.""" + if isinstance(other, list): + return self.op("?")(cast(other, ARRAY(LQUERY))) + else: + return self.op("~")(other) diff --git a/tildes/tildes/models/topic/topic_query.py b/tildes/tildes/models/topic/topic_query.py index 7c8ea83..46adfdf 100644 --- a/tildes/tildes/models/topic/topic_query.py +++ b/tildes/tildes/models/topic/topic_query.py @@ -134,14 +134,15 @@ class TopicQuery(PaginatedQuery): return self.filter(Topic.created_time > start_time) def has_tag(self, tag: Ltree) -> "TopicQuery": - """Restrict the topics to ones with a specific tag (generative).""" - # casting tag to string really shouldn't be necessary, but some kind of strange - # interaction seems to be happening with the ArrayOfLtree class, this will need - # some investigation - tag = str(tag) + """Restrict the topics to ones with a specific tag (generative). + + Note that this method searches for topics that have any tag that either starts + or ends with the specified tag, not only exact/full matches. + """ + queries = [f"{tag}.*", f"*.{tag}"] # pylint: disable=protected-access - return self.filter(Topic._tags.descendant_of(tag)) # type: ignore + return self.filter(Topic._tags.lquery(queries)) # type: ignore def search(self, query: str) -> "TopicQuery": """Restrict the topics to ones that match a search query (generative)."""