@ -3,7 +3,8 @@
""" Contains the PaginatedQuery and PaginatedResults classes. """
from typing import Any , Iterator , List , Optional , TypeVar
from itertools import chain
from typing import Any , Iterator , List , Optional , Sequence , TypeVar
from pyramid.request import Request
from sqlalchemy import Column , func , inspect
@ -29,6 +30,8 @@ class PaginatedQuery(ModelQuery):
self . after_id : Optional [ int ] = None
self . before_id : Optional [ int ] = None
self . _anchor_table = model_cls . __table__
def __iter__ ( self ) - > Iterator [ ModelType ] :
""" Iterate over the results of the query, reversed if necessary. """
if not self . is_reversed :
@ -44,14 +47,22 @@ class PaginatedQuery(ModelQuery):
if not self . _sort_column :
raise AttributeError
# always add a final sort by the ID so keyset pagination works properly
if self . is_anchor_same_type :
# add a final sort by the ID so keyset pagination works properly
return [ self . _sort_column ] + list ( self . model_cls . __table__ . primary_key )
else :
return [ self . _sort_column ]
@property
def sorting_columns_desc ( self ) - > List [ Column ] :
""" Return descending versions of the sorting columns. """
return [ col . desc ( ) for col in self . sorting_columns ]
@property
def is_anchor_same_type ( self ) - > bool :
""" Return whether the anchor type is the same as the overall model_cls. """
return self . _anchor_table == self . model_cls . __table__
@property
def is_reversed ( self ) - > bool :
""" Return whether the query is operating " in reverse " .
@ -73,6 +84,13 @@ class PaginatedQuery(ModelQuery):
"""
return bool ( self . before_id )
def anchor_type ( self , anchor_type : str ) - > " PaginatedQuery " :
""" Set the type of the " anchor " (before/after item) (generative). """
anchor_table_name = anchor_type + " s "
self . _anchor_table = self . model_cls . metadata . tables . get ( anchor_table_name )
return self
def after_id36 ( self , id36 : str ) - > " PaginatedQuery " :
""" Restrict the query to results after an id36 (generative). """
if self . before_id :
@ -127,12 +145,21 @@ class PaginatedQuery(ModelQuery):
def _anchor_subquery ( self , anchor_id : int ) - > Any :
""" Return a subquery to get comparison values for the anchor item. """
if len ( self . model_cls . __table__ . primary_key ) > 1 :
if len ( self . _anchor_table . primary_key ) > 1 :
raise TypeError ( " Only single-col primary key tables are supported " )
id_column = list ( self . model_cls . __table__ . primary_key ) [ 0 ]
id_column = list ( self . _anchor_table . primary_key ) [ 0 ]
if self . is_anchor_same_type :
columns = self . sorting_columns
else :
columns = [
self . _anchor_table . columns . get ( column . name )
for column in self . sorting_columns
]
return (
self . request . db_session . query ( * self . sorting_columns )
self . request . db_session . query ( * columns )
. filter ( id_column == anchor_id )
. subquery ( )
)
@ -170,6 +197,7 @@ class PaginatedResults:
def __init__ ( self , query : PaginatedQuery , per_page : int ) :
""" Fetch results from a PaginatedQuery. """
self . query = query
self . per_page = per_page
# if the query had `before` or `after` restrictions, there must be a page in
@ -227,3 +255,67 @@ class PaginatedResults:
prev_id = inspect ( self . results [ 0 ] ) . identity [ 0 ]
return id_to_id36 ( prev_id )
class MixedPaginatedResults ( PaginatedResults ) :
""" Merged result from multiple PaginatedResults, consisting of different types. """
def __init__ ( self , paginated_results : Sequence [ PaginatedResults ] ) :
# pylint: disable=super-init-not-called,protected-access
""" Merge all the supplied results into a single one. """
sort_column_name = paginated_results [ 0 ] . query . _sort_column . name
if any (
[ r . query . _sort_column . name != sort_column_name for r in paginated_results ]
) :
raise ValueError ( " All results must by sorted by the same column. " )
reverse_sort = paginated_results [ 0 ] . query . sort_desc
if any ( [ r . query . sort_desc != reverse_sort for r in paginated_results ] ) :
raise ValueError ( " All results must by sorted in the same direction. " )
# merge all the results into one list and sort it
self . results = sorted (
chain . from_iterable ( paginated_results ) ,
key = lambda post : getattr ( post , sort_column_name ) ,
reverse = reverse_sort ,
)
self . per_page = min ( [ r . per_page for r in paginated_results ] )
if len ( self . results ) > self . per_page :
self . has_next_page = True
self . results = self . results [ : self . per_page ]
else :
self . has_next_page = any ( [ r . has_next_page for r in paginated_results ] )
self . has_prev_page = any ( [ r . has_prev_page for r in paginated_results ] )
@property
def next_page_after_id36 ( self ) - > str :
""" Return " after " ID36 that should be used to fetch the next page. """
if not self . has_next_page :
raise AttributeError
item = self . results [ - 1 ]
next_id = inspect ( item ) . identity [ 0 ]
next_id36 = id_to_id36 ( next_id )
type_char = item . __class__ . __name__ . lower ( ) [ 0 ]
return f " {type_char}-{next_id36} "
@property
def prev_page_before_id36 ( self ) - > str :
""" Return " before " ID36 that should be used to fetch the prev page. """
if not self . has_prev_page :
raise AttributeError
item = self . results [ 0 ]
prev_id = inspect ( item ) . identity [ 0 ]
prev_id36 = id_to_id36 ( prev_id )
type_char = item . __class__ . __name__ . lower ( ) [ 0 ]
return f " {type_char}-{prev_id36} "