Browse Source

Apply Black code formatter

This commit contains only changes that were made automatically by Black
(except for some minor fixes to string un-wrapping and two
format-disabling blocks in the user and group schemas). Some manual
cleanup/adjustments will probably need to be made in a follow-up commit,
but this one contains the result of running Black on the codebase
without significant further manual tweaking.
merge-requests/26/head
Deimos 7 years ago
parent
commit
09cf3c47f4
  1. 22
      tildes/alembic/env.py
  2. 26
      tildes/alembic/versions/2512581c91b3_add_setting_to_open_links_in_new_tab.py
  3. 13
      tildes/alembic/versions/de83b8750123_add_setting_to_open_text_links_in_new_.py
  4. 35
      tildes/alembic/versions/f1ecbf24c212_added_user_tag_type_comment_notification.py
  5. 124
      tildes/alembic/versions/fab922a8bb04_update_comment_triggers_for_removals.py
  6. 21
      tildes/consumers/comment_user_mentions_generator.py
  7. 29
      tildes/consumers/topic_metadata_generator.py
  8. 57
      tildes/scripts/breached_passwords.py
  9. 51
      tildes/scripts/clean_private_data.py
  10. 50
      tildes/scripts/initialize_db.py
  11. 6
      tildes/setup.py
  12. 71
      tildes/tests/conftest.py
  13. 8
      tildes/tests/fixtures.py
  14. 64
      tildes/tests/test_comment.py
  15. 72
      tildes/tests/test_comment_user_mentions.py
  16. 14
      tildes/tests/test_datetime.py
  17. 41
      tildes/tests/test_group.py
  18. 10
      tildes/tests/test_id.py
  19. 167
      tildes/tests/test_markdown.py
  20. 22
      tildes/tests/test_markdown_field.py
  21. 36
      tildes/tests/test_messages.py
  22. 2
      tildes/tests/test_metrics.py
  23. 36
      tildes/tests/test_ratelimit.py
  24. 22
      tildes/tests/test_simplestring_field.py
  25. 58
      tildes/tests/test_string.py
  26. 36
      tildes/tests/test_title.py
  27. 44
      tildes/tests/test_topic.py
  28. 37
      tildes/tests/test_topic_permissions.py
  29. 22
      tildes/tests/test_topic_tags.py
  30. 6
      tildes/tests/test_triggers_comments.py
  31. 22
      tildes/tests/test_url.py
  32. 51
      tildes/tests/test_user.py
  33. 16
      tildes/tests/test_username.py
  34. 10
      tildes/tests/test_webassets.py
  35. 6
      tildes/tests/webtests/test_user_page.py
  36. 93
      tildes/tildes/__init__.py
  37. 6
      tildes/tildes/api.py
  38. 46
      tildes/tildes/auth.py
  39. 27
      tildes/tildes/database.py
  40. 20
      tildes/tildes/enums.py
  41. 26
      tildes/tildes/jinja.py
  42. 6
      tildes/tildes/json.py
  43. 16
      tildes/tildes/lib/amqp.py
  44. 12
      tildes/tildes/lib/cmark.py
  45. 26
      tildes/tildes/lib/database.py
  46. 30
      tildes/tildes/lib/datetime.py
  47. 3
      tildes/tildes/lib/hash.py
  48. 10
      tildes/tildes/lib/id.py
  49. 181
      tildes/tildes/lib/markdown.py
  50. 2
      tildes/tildes/lib/message.py
  51. 11
      tildes/tildes/lib/password.py
  52. 95
      tildes/tildes/lib/ratelimit.py
  53. 45
      tildes/tildes/lib/string.py
  54. 4
      tildes/tildes/lib/url.py
  55. 60
      tildes/tildes/metrics.py
  56. 97
      tildes/tildes/models/comment/comment.py
  57. 71
      tildes/tildes/models/comment/comment_notification.py
  58. 10
      tildes/tildes/models/comment/comment_notification_query.py
  59. 6
      tildes/tildes/models/comment/comment_query.py
  60. 29
      tildes/tildes/models/comment/comment_tag.py
  61. 23
      tildes/tildes/models/comment/comment_tree.py
  62. 20
      tildes/tildes/models/comment/comment_vote.py
  63. 34
      tildes/tildes/models/database_model.py
  64. 44
      tildes/tildes/models/group/group.py
  65. 6
      tildes/tildes/models/group/group_query.py
  66. 20
      tildes/tildes/models/group/group_subscription.py
  67. 135
      tildes/tildes/models/log/log.py
  68. 85
      tildes/tildes/models/message/message.py
  69. 37
      tildes/tildes/models/model_query.py
  70. 14
      tildes/tildes/models/pagination.py
  71. 174
      tildes/tildes/models/topic/topic.py
  72. 45
      tildes/tildes/models/topic/topic_query.py
  73. 26
      tildes/tildes/models/topic/topic_visit.py
  74. 20
      tildes/tildes/models/topic/topic_vote.py
  75. 70
      tildes/tildes/models/user/user.py
  76. 16
      tildes/tildes/models/user/user_group_settings.py
  77. 37
      tildes/tildes/models/user/user_invite_code.py
  78. 6
      tildes/tildes/resources/__init__.py
  79. 16
      tildes/tildes/resources/comment.py
  80. 11
      tildes/tildes/resources/group.py
  81. 11
      tildes/tildes/resources/message.py
  82. 9
      tildes/tildes/resources/topic.py
  83. 5
      tildes/tildes/resources/user.py
  84. 160
      tildes/tildes/routes.py
  85. 45
      tildes/tildes/schemas/fields.py
  86. 27
      tildes/tildes/schemas/group.py
  87. 58
      tildes/tildes/schemas/topic.py
  88. 13
      tildes/tildes/schemas/topic_listing.py
  89. 46
      tildes/tildes/schemas/user.py
  90. 2
      tildes/tildes/views/__init__.py
  91. 2
      tildes/tildes/views/api/v0/group.py
  92. 4
      tildes/tildes/views/api/v0/topic.py
  93. 2
      tildes/tildes/views/api/v0/user.py
  94. 167
      tildes/tildes/views/api/web/comment.py
  95. 21
      tildes/tildes/views/api/web/exceptions.py
  96. 41
      tildes/tildes/views/api/web/group.py
  97. 16
      tildes/tildes/views/api/web/message.py
  98. 172
      tildes/tildes/views/api/web/topic.py
  99. 139
      tildes/tildes/views/api/web/user.py
  100. 16
      tildes/tildes/views/decorators.py

22
tildes/alembic/env.py

@ -12,12 +12,7 @@ config = context.config
fileConfig(config.config_file_name) fileConfig(config.config_file_name)
# import all DatabaseModel subclasses here for autogenerate support # import all DatabaseModel subclasses here for autogenerate support
from tildes.models.comment import (
Comment,
CommentNotification,
CommentTag,
CommentVote,
)
from tildes.models.comment import Comment, CommentNotification, CommentTag, CommentVote
from tildes.models.group import Group, GroupSubscription from tildes.models.group import Group, GroupSubscription
from tildes.models.log import Log from tildes.models.log import Log
from tildes.models.message import MessageConversation, MessageReply from tildes.models.message import MessageConversation, MessageReply
@ -25,6 +20,7 @@ from tildes.models.topic import Topic, TopicVisit, TopicVote
from tildes.models.user import User, UserGroupSettings, UserInviteCode from tildes.models.user import User, UserGroupSettings, UserInviteCode
from tildes.models import DatabaseModel from tildes.models import DatabaseModel
target_metadata = DatabaseModel.metadata target_metadata = DatabaseModel.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
@ -46,8 +42,7 @@ def run_migrations_offline():
""" """
url = config.get_main_option("sqlalchemy.url") url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url, target_metadata=target_metadata, literal_binds=True)
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()
@ -62,18 +57,17 @@ def run_migrations_online():
""" """
connectable = engine_from_config( connectable = engine_from_config(
config.get_section(config.config_ini_section), config.get_section(config.config_ini_section),
prefix='sqlalchemy.',
poolclass=pool.NullPool)
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata
)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()
if context.is_offline_mode(): if context.is_offline_mode():
run_migrations_offline() run_migrations_offline()
else: else:

26
tildes/alembic/versions/2512581c91b3_add_setting_to_open_links_in_new_tab.py

@ -10,17 +10,33 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '2512581c91b3'
revision = "2512581c91b3"
down_revision = None down_revision = None
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.add_column('users', sa.Column('open_new_tab_external', sa.Boolean(), server_default='false', nullable=False))
op.add_column('users', sa.Column('open_new_tab_internal', sa.Boolean(), server_default='false', nullable=False))
op.add_column(
"users",
sa.Column(
"open_new_tab_external",
sa.Boolean(),
server_default="false",
nullable=False,
),
)
op.add_column(
"users",
sa.Column(
"open_new_tab_internal",
sa.Boolean(),
server_default="false",
nullable=False,
),
)
def downgrade(): def downgrade():
op.drop_column('users', 'open_new_tab_internal')
op.drop_column('users', 'open_new_tab_external')
op.drop_column("users", "open_new_tab_internal")
op.drop_column("users", "open_new_tab_external")

13
tildes/alembic/versions/de83b8750123_add_setting_to_open_text_links_in_new_.py

@ -10,15 +10,20 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'de83b8750123'
down_revision = '2512581c91b3'
revision = "de83b8750123"
down_revision = "2512581c91b3"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.add_column('users', sa.Column('open_new_tab_text', sa.Boolean(), server_default='false', nullable=False))
op.add_column(
"users",
sa.Column(
"open_new_tab_text", sa.Boolean(), server_default="false", nullable=False
),
)
def downgrade(): def downgrade():
op.drop_column('users', 'open_new_tab_text')
op.drop_column("users", "open_new_tab_text")

35
tildes/alembic/versions/f1ecbf24c212_added_user_tag_type_comment_notification.py

@ -9,8 +9,8 @@ from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'f1ecbf24c212'
down_revision = 'de83b8750123'
revision = "f1ecbf24c212"
down_revision = "de83b8750123"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@ -20,18 +20,18 @@ def upgrade():
connection = None connection = None
if not op.get_context().as_sql: if not op.get_context().as_sql:
connection = op.get_bind() connection = op.get_bind()
connection.execution_options(isolation_level='AUTOCOMMIT')
connection.execution_options(isolation_level="AUTOCOMMIT")
op.execute( op.execute(
"ALTER TYPE commentnotificationtype "
"ADD VALUE IF NOT EXISTS 'USER_MENTION'"
"ALTER TYPE commentnotificationtype ADD VALUE IF NOT EXISTS 'USER_MENTION'"
) )
# re-activate the transaction for any future migrations # re-activate the transaction for any future migrations
if connection is not None: if connection is not None:
connection.execution_options(isolation_level='READ_COMMITTED')
connection.execution_options(isolation_level="READ_COMMITTED")
op.execute('''
op.execute(
"""
CREATE OR REPLACE FUNCTION send_rabbitmq_message_for_comment() RETURNS TRIGGER AS $$ CREATE OR REPLACE FUNCTION send_rabbitmq_message_for_comment() RETURNS TRIGGER AS $$
DECLARE DECLARE
affected_row RECORD; affected_row RECORD;
@ -50,23 +50,28 @@ def upgrade():
RETURN NULL; RETURN NULL;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
''')
op.execute('''
"""
)
op.execute(
"""
CREATE TRIGGER send_rabbitmq_message_for_comment_insert CREATE TRIGGER send_rabbitmq_message_for_comment_insert
AFTER INSERT ON comments AFTER INSERT ON comments
FOR EACH ROW FOR EACH ROW
EXECUTE PROCEDURE send_rabbitmq_message_for_comment('created'); EXECUTE PROCEDURE send_rabbitmq_message_for_comment('created');
''')
op.execute('''
"""
)
op.execute(
"""
CREATE TRIGGER send_rabbitmq_message_for_comment_edit CREATE TRIGGER send_rabbitmq_message_for_comment_edit
AFTER UPDATE ON comments AFTER UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN (OLD.markdown IS DISTINCT FROM NEW.markdown) WHEN (OLD.markdown IS DISTINCT FROM NEW.markdown)
EXECUTE PROCEDURE send_rabbitmq_message_for_comment('edited'); EXECUTE PROCEDURE send_rabbitmq_message_for_comment('edited');
''')
"""
)
def downgrade(): def downgrade():
op.execute('DROP TRIGGER send_rabbitmq_message_for_comment_insert ON comments')
op.execute('DROP TRIGGER send_rabbitmq_message_for_comment_edit ON comments')
op.execute('DROP FUNCTION send_rabbitmq_message_for_comment')
op.execute("DROP TRIGGER send_rabbitmq_message_for_comment_insert ON comments")
op.execute("DROP TRIGGER send_rabbitmq_message_for_comment_edit ON comments")
op.execute("DROP FUNCTION send_rabbitmq_message_for_comment")

124
tildes/alembic/versions/fab922a8bb04_update_comment_triggers_for_removals.py

@ -10,8 +10,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'fab922a8bb04'
down_revision = 'f1ecbf24c212'
revision = "fab922a8bb04"
down_revision = "f1ecbf24c212"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@ -19,17 +19,20 @@ depends_on = None
def upgrade(): def upgrade():
# comment_notifications # comment_notifications
op.execute("DROP TRIGGER delete_comment_notifications_update ON comments") op.execute("DROP TRIGGER delete_comment_notifications_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_notifications_update CREATE TRIGGER delete_comment_notifications_update
AFTER UPDATE ON comments AFTER UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN ((OLD.is_deleted = false AND NEW.is_deleted = true) WHEN ((OLD.is_deleted = false AND NEW.is_deleted = true)
OR (OLD.is_removed = false AND NEW.is_removed = true)) OR (OLD.is_removed = false AND NEW.is_removed = true))
EXECUTE PROCEDURE delete_comment_notifications(); EXECUTE PROCEDURE delete_comment_notifications();
""")
"""
)
# comments # comments
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION set_comment_deleted_time() RETURNS TRIGGER AS $$ CREATE OR REPLACE FUNCTION set_comment_deleted_time() RETURNS TRIGGER AS $$
BEGIN BEGIN
IF (NEW.is_deleted = TRUE) THEN IF (NEW.is_deleted = TRUE) THEN
@ -41,17 +44,21 @@ def upgrade():
RETURN NEW; RETURN NEW;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER delete_comment_set_deleted_time_update ON comments") op.execute("DROP TRIGGER delete_comment_set_deleted_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_set_deleted_time_update CREATE TRIGGER delete_comment_set_deleted_time_update
BEFORE UPDATE ON comments BEFORE UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted) WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
EXECUTE PROCEDURE set_comment_deleted_time(); EXECUTE PROCEDURE set_comment_deleted_time();
""")
"""
)
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION set_comment_removed_time() RETURNS TRIGGER AS $$ CREATE OR REPLACE FUNCTION set_comment_removed_time() RETURNS TRIGGER AS $$
BEGIN BEGIN
IF (NEW.is_removed = TRUE) THEN IF (NEW.is_removed = TRUE) THEN
@ -63,19 +70,23 @@ def upgrade():
RETURN NEW; RETURN NEW;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
""")
op.execute("""
"""
)
op.execute(
"""
CREATE TRIGGER remove_comment_set_removed_time_update CREATE TRIGGER remove_comment_set_removed_time_update
BEFORE UPDATE ON comments BEFORE UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN (OLD.is_removed IS DISTINCT FROM NEW.is_removed) WHEN (OLD.is_removed IS DISTINCT FROM NEW.is_removed)
EXECUTE PROCEDURE set_comment_removed_time(); EXECUTE PROCEDURE set_comment_removed_time();
""")
"""
)
# topic_visits # topic_visits
op.execute("DROP TRIGGER update_topic_visits_num_comments_update ON comments") op.execute("DROP TRIGGER update_topic_visits_num_comments_update ON comments")
op.execute("DROP FUNCTION decrement_all_topic_visit_num_comments()") op.execute("DROP FUNCTION decrement_all_topic_visit_num_comments()")
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_all_topic_visit_num_comments() RETURNS TRIGGER AS $$ CREATE OR REPLACE FUNCTION update_all_topic_visit_num_comments() RETURNS TRIGGER AS $$
DECLARE DECLARE
old_visible BOOLEAN := NOT (OLD.is_deleted OR OLD.is_removed); old_visible BOOLEAN := NOT (OLD.is_deleted OR OLD.is_removed);
@ -96,18 +107,22 @@ def upgrade():
RETURN NULL; RETURN NULL;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
""")
op.execute("""
"""
)
op.execute(
"""
CREATE TRIGGER update_topic_visits_num_comments_update CREATE TRIGGER update_topic_visits_num_comments_update
AFTER UPDATE ON comments AFTER UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted) WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed)) OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed))
EXECUTE PROCEDURE update_all_topic_visit_num_comments(); EXECUTE PROCEDURE update_all_topic_visit_num_comments();
""")
"""
)
# topics # topics
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_num_comments() RETURNS TRIGGER AS $$ CREATE OR REPLACE FUNCTION update_topics_num_comments() RETURNS TRIGGER AS $$
BEGIN BEGIN
IF (TG_OP = 'INSERT') THEN IF (TG_OP = 'INSERT') THEN
@ -140,18 +155,22 @@ def upgrade():
RETURN NULL; RETURN NULL;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_num_comments_update ON comments") op.execute("DROP TRIGGER update_topics_num_comments_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_num_comments_update CREATE TRIGGER update_topics_num_comments_update
AFTER UPDATE ON comments AFTER UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted) WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed)) OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed))
EXECUTE PROCEDURE update_topics_num_comments(); EXECUTE PROCEDURE update_topics_num_comments();
""")
"""
)
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_last_activity_time() RETURNS TRIGGER AS $$ CREATE OR REPLACE FUNCTION update_topics_last_activity_time() RETURNS TRIGGER AS $$
DECLARE DECLARE
most_recent_comment RECORD; most_recent_comment RECORD;
@ -182,31 +201,37 @@ def upgrade():
RETURN NULL; RETURN NULL;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_last_activity_time_update ON comments") op.execute("DROP TRIGGER update_topics_last_activity_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_last_activity_time_update CREATE TRIGGER update_topics_last_activity_time_update
AFTER UPDATE ON comments AFTER UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted) WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed)) OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed))
EXECUTE PROCEDURE update_topics_last_activity_time(); EXECUTE PROCEDURE update_topics_last_activity_time();
""")
"""
)
def downgrade(): def downgrade():
# comment_notifications # comment_notifications
op.execute("DROP TRIGGER delete_comment_notifications_update ON comments") op.execute("DROP TRIGGER delete_comment_notifications_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_notifications_update CREATE TRIGGER delete_comment_notifications_update
AFTER UPDATE ON comments AFTER UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN (OLD.is_deleted = false AND NEW.is_deleted = true) WHEN (OLD.is_deleted = false AND NEW.is_deleted = true)
EXECUTE PROCEDURE delete_comment_notifications(); EXECUTE PROCEDURE delete_comment_notifications();
""")
"""
)
# comments # comments
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION set_comment_deleted_time() RETURNS TRIGGER AS $$ CREATE OR REPLACE FUNCTION set_comment_deleted_time() RETURNS TRIGGER AS $$
BEGIN BEGIN
NEW.deleted_time := current_timestamp; NEW.deleted_time := current_timestamp;
@ -214,15 +239,18 @@ def downgrade():
RETURN NEW; RETURN NEW;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER delete_comment_set_deleted_time_update ON comments") op.execute("DROP TRIGGER delete_comment_set_deleted_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_set_deleted_time_update CREATE TRIGGER delete_comment_set_deleted_time_update
BEFORE UPDATE ON comments BEFORE UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN (OLD.is_deleted = false AND NEW.is_deleted = true) WHEN (OLD.is_deleted = false AND NEW.is_deleted = true)
EXECUTE PROCEDURE set_comment_deleted_time(); EXECUTE PROCEDURE set_comment_deleted_time();
""")
"""
)
op.execute("DROP TRIGGER remove_comment_set_removed_time_update ON comments") op.execute("DROP TRIGGER remove_comment_set_removed_time_update ON comments")
op.execute("DROP FUNCTION set_comment_removed_time()") op.execute("DROP FUNCTION set_comment_removed_time()")
@ -230,7 +258,8 @@ def downgrade():
# topic_visits # topic_visits
op.execute("DROP TRIGGER update_topic_visits_num_comments_update ON comments") op.execute("DROP TRIGGER update_topic_visits_num_comments_update ON comments")
op.execute("DROP FUNCTION update_all_topic_visit_num_comments()") op.execute("DROP FUNCTION update_all_topic_visit_num_comments()")
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION decrement_all_topic_visit_num_comments() RETURNS TRIGGER AS $$ CREATE OR REPLACE FUNCTION decrement_all_topic_visit_num_comments() RETURNS TRIGGER AS $$
BEGIN BEGIN
UPDATE topic_visits UPDATE topic_visits
@ -241,17 +270,21 @@ def downgrade():
RETURN NULL; RETURN NULL;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
""")
op.execute("""
"""
)
op.execute(
"""
CREATE TRIGGER update_topic_visits_num_comments_update CREATE TRIGGER update_topic_visits_num_comments_update
AFTER UPDATE ON comments AFTER UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN (OLD.is_deleted = false AND NEW.is_deleted = true) WHEN (OLD.is_deleted = false AND NEW.is_deleted = true)
EXECUTE PROCEDURE decrement_all_topic_visit_num_comments(); EXECUTE PROCEDURE decrement_all_topic_visit_num_comments();
""")
"""
)
# topics # topics
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_num_comments() RETURNS TRIGGER AS $$ CREATE OR REPLACE FUNCTION update_topics_num_comments() RETURNS TRIGGER AS $$
BEGIN BEGIN
IF (TG_OP = 'INSERT' AND NEW.is_deleted = FALSE) THEN IF (TG_OP = 'INSERT' AND NEW.is_deleted = FALSE) THEN
@ -277,17 +310,21 @@ def downgrade():
RETURN NULL; RETURN NULL;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_num_comments_update ON comments") op.execute("DROP TRIGGER update_topics_num_comments_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_num_comments_update CREATE TRIGGER update_topics_num_comments_update
AFTER UPDATE ON comments AFTER UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted) WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
EXECUTE PROCEDURE update_topics_num_comments(); EXECUTE PROCEDURE update_topics_num_comments();
""")
"""
)
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_last_activity_time() RETURNS TRIGGER AS $$ CREATE OR REPLACE FUNCTION update_topics_last_activity_time() RETURNS TRIGGER AS $$
DECLARE DECLARE
most_recent_comment RECORD; most_recent_comment RECORD;
@ -317,12 +354,15 @@ def downgrade():
RETURN NULL; RETURN NULL;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_last_activity_time_update ON comments") op.execute("DROP TRIGGER update_topics_last_activity_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_last_activity_time_update CREATE TRIGGER update_topics_last_activity_time_update
AFTER UPDATE ON comments AFTER UPDATE ON comments
FOR EACH ROW FOR EACH ROW
WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted) WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
EXECUTE PROCEDURE update_topics_last_activity_time(); EXECUTE PROCEDURE update_topics_last_activity_time();
""")
"""
)

21
tildes/consumers/comment_user_mentions_generator.py

@ -13,7 +13,7 @@ class CommentUserMentionGenerator(PgsqlQueueConsumer):
"""Process a delivered message.""" """Process a delivered message."""
comment = ( comment = (
self.db_session.query(Comment) self.db_session.query(Comment)
.filter_by(comment_id=msg.body['comment_id'])
.filter_by(comment_id=msg.body["comment_id"])
.one() .one()
) )
@ -22,15 +22,16 @@ class CommentUserMentionGenerator(PgsqlQueueConsumer):
return return
new_mentions = CommentNotification.get_mentions_for_comment( new_mentions = CommentNotification.get_mentions_for_comment(
self.db_session, comment)
self.db_session, comment
)
if msg.delivery_info['routing_key'] == 'comment.created':
if msg.delivery_info["routing_key"] == "comment.created":
for user_mention in new_mentions: for user_mention in new_mentions:
self.db_session.add(user_mention) self.db_session.add(user_mention)
elif msg.delivery_info['routing_key'] == 'comment.edited':
to_delete, to_add = (
CommentNotification.prevent_duplicate_notifications(
self.db_session, comment, new_mentions))
elif msg.delivery_info["routing_key"] == "comment.edited":
to_delete, to_add = CommentNotification.prevent_duplicate_notifications(
self.db_session, comment, new_mentions
)
for user_mention in to_delete: for user_mention in to_delete:
self.db_session.delete(user_mention) self.db_session.delete(user_mention)
@ -39,8 +40,8 @@ class CommentUserMentionGenerator(PgsqlQueueConsumer):
self.db_session.add(user_mention) self.db_session.add(user_mention)
if __name__ == '__main__':
if __name__ == "__main__":
CommentUserMentionGenerator( CommentUserMentionGenerator(
queue_name='comment_user_mentions_generator.q',
routing_keys=['comment.created', 'comment.edited'],
queue_name="comment_user_mentions_generator.q",
routing_keys=["comment.created", "comment.edited"],
).consume_queue() ).consume_queue()

29
tildes/consumers/topic_metadata_generator.py

@ -26,9 +26,7 @@ class TopicMetadataGenerator(PgsqlQueueConsumer):
def run(self, msg: Message) -> None: def run(self, msg: Message) -> None:
"""Process a delivered message.""" """Process a delivered message."""
topic = ( topic = (
self.db_session.query(Topic)
.filter_by(topic_id=msg.body['topic_id'])
.one()
self.db_session.query(Topic).filter_by(topic_id=msg.body["topic_id"]).one()
) )
if topic.is_text_type: if topic.is_text_type:
@ -42,22 +40,19 @@ class TopicMetadataGenerator(PgsqlQueueConsumer):
html_tree = HTMLParser().parseFragment(topic.rendered_html) html_tree = HTMLParser().parseFragment(topic.rendered_html)
# extract the text from all of the HTML elements # extract the text from all of the HTML elements
extracted_text = ''.join(
[element_text for element_text in html_tree.itertext()])
extracted_text = "".join(
[element_text for element_text in html_tree.itertext()]
)
# sanitize unicode, remove leading/trailing whitespace, etc. # sanitize unicode, remove leading/trailing whitespace, etc.
extracted_text = simplify_string(extracted_text) extracted_text = simplify_string(extracted_text)
# create a short excerpt by truncating the simplified string # create a short excerpt by truncating the simplified string
excerpt = truncate_string(
extracted_text,
length=200,
truncate_at_chars=' ',
)
excerpt = truncate_string(extracted_text, length=200, truncate_at_chars=" ")
topic.content_metadata = { topic.content_metadata = {
'word_count': word_count(extracted_text),
'excerpt': excerpt,
"word_count": word_count(extracted_text),
"excerpt": excerpt,
} }
def _generate_link_metadata(self, topic: Topic) -> None: def _generate_link_metadata(self, topic: Topic) -> None:
@ -68,13 +63,11 @@ class TopicMetadataGenerator(PgsqlQueueConsumer):
parsed_domain = get_domain_from_url(topic.link) parsed_domain = get_domain_from_url(topic.link)
domain = self.public_suffix_list.get_public_suffix(parsed_domain) domain = self.public_suffix_list.get_public_suffix(parsed_domain)
topic.content_metadata = {
'domain': domain,
}
topic.content_metadata = {"domain": domain}
if __name__ == '__main__':
if __name__ == "__main__":
TopicMetadataGenerator( TopicMetadataGenerator(
queue_name='topic_metadata_generator.q',
routing_keys=['topic.created', 'topic.edited'],
queue_name="topic_metadata_generator.q",
routing_keys=["topic.created", "topic.edited"],
).consume_queue() ).consume_queue()

57
tildes/scripts/breached_passwords.py

@ -46,11 +46,11 @@ def generate_redis_protocol(*elements: Any) -> str:
Based on the example Ruby code from Based on the example Ruby code from
https://redis.io/topics/mass-insert#generating-redis-protocol https://redis.io/topics/mass-insert#generating-redis-protocol
""" """
command = f'*{len(elements)}\r\n'
command = f"*{len(elements)}\r\n"
for element in elements: for element in elements:
element = str(element) element = str(element)
command += f'${len(element)}\r\n{element}\r\n'
command += f"${len(element)}\r\n{element}\r\n"
return command return command
@ -65,27 +65,27 @@ def validate_init_error_rate(ctx: Any, param: Any, value: Any) -> float:
"""Validate the --error-rate arg for the init command.""" """Validate the --error-rate arg for the init command."""
# pylint: disable=unused-argument # pylint: disable=unused-argument
if not 0 < value < 1: if not 0 < value < 1:
raise click.BadParameter('error rate must be a float between 0 and 1')
raise click.BadParameter("error rate must be a float between 0 and 1")
return value return value
@cli.command(help='Initialize a new empty bloom filter')
@cli.command(help="Initialize a new empty bloom filter")
@click.option( @click.option(
'--estimate',
"--estimate",
required=True, required=True,
type=int, type=int,
help='Expected number of passwords that will be added',
help="Expected number of passwords that will be added",
) )
@click.option( @click.option(
'--error-rate',
"--error-rate",
default=0.01, default=0.01,
show_default=True, show_default=True,
help='Bloom filter desired false positive ratio',
help="Bloom filter desired false positive ratio",
callback=validate_init_error_rate, callback=validate_init_error_rate,
) )
@click.confirmation_option( @click.confirmation_option(
prompt='Are you sure you want to clear any existing bloom filter?',
prompt="Are you sure you want to clear any existing bloom filter?"
) )
def init(estimate: int, error_rate: float) -> None: def init(estimate: int, error_rate: float) -> None:
"""Initialize a new bloom filter (destroying any existing one). """Initialize a new bloom filter (destroying any existing one).
@ -102,22 +102,16 @@ def init(estimate: int, error_rate: float) -> None:
REDIS.delete(BREACHED_PASSWORDS_BF_KEY) REDIS.delete(BREACHED_PASSWORDS_BF_KEY)
# BF.RESERVE {key} {error_rate} {size} # BF.RESERVE {key} {error_rate} {size}
REDIS.execute_command(
'BF.RESERVE',
BREACHED_PASSWORDS_BF_KEY,
error_rate,
estimate,
)
REDIS.execute_command("BF.RESERVE", BREACHED_PASSWORDS_BF_KEY, error_rate, estimate)
click.echo( click.echo(
'Initialized bloom filter with expected size of {:,} and false '
'positive rate of {}%'
.format(estimate, error_rate * 100)
"Initialized bloom filter with expected size of {:,} and false "
"positive rate of {}%".format(estimate, error_rate * 100)
) )
@cli.command(help='Add hashes from a file to the bloom filter')
@click.argument('filename', type=click.Path(exists=True, dir_okay=False))
@cli.command(help="Add hashes from a file to the bloom filter")
@click.argument("filename", type=click.Path(exists=True, dir_okay=False))
def addhashes(filename: str) -> None: def addhashes(filename: str) -> None:
"""Add all hashes from a file to the bloom filter. """Add all hashes from a file to the bloom filter.
@ -127,26 +121,26 @@ def addhashes(filename: str) -> None:
""" """
# make sure the key exists and is a bloom filter # make sure the key exists and is a bloom filter
try: try:
REDIS.execute_command('BF.DEBUG', BREACHED_PASSWORDS_BF_KEY)
REDIS.execute_command("BF.DEBUG", BREACHED_PASSWORDS_BF_KEY)
except ResponseError: except ResponseError:
click.echo('Bloom filter is not set up properly - run init first.')
click.echo("Bloom filter is not set up properly - run init first.")
raise click.Abort raise click.Abort
# call wc to count the number of lines in the file for the progress bar # call wc to count the number of lines in the file for the progress bar
click.echo('Determining hash count...')
result = subprocess.run(['wc', '-l', filename], stdout=subprocess.PIPE)
line_count = int(result.stdout.split(b' ')[0])
click.echo("Determining hash count...")
result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
line_count = int(result.stdout.split(b" ")[0])
progress_bar: Any = click.progressbar(length=line_count) progress_bar: Any = click.progressbar(length=line_count)
update_interval = 100_000 update_interval = 100_000
click.echo('Adding {:,} hashes to bloom filter...'.format(line_count))
click.echo("Adding {:,} hashes to bloom filter...".format(line_count))
redis_pipe = subprocess.Popen( redis_pipe = subprocess.Popen(
['redis-cli', '-s', BREACHED_PASSWORDS_REDIS_SOCKET, '--pipe'],
["redis-cli", "-s", BREACHED_PASSWORDS_REDIS_SOCKET, "--pipe"],
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
encoding='utf-8',
encoding="utf-8",
) )
for count, line in enumerate(open(filename), start=1): for count, line in enumerate(open(filename), start=1):
@ -155,10 +149,9 @@ def addhashes(filename: str) -> None:
# the Pwned Passwords hash lists now have a frequency count for each # the Pwned Passwords hash lists now have a frequency count for each
# hash, which is separated from the hash with a colon, so we need to # hash, which is separated from the hash with a colon, so we need to
# handle that if it's present # handle that if it's present
hashval = hashval.split(':')[0]
hashval = hashval.split(":")[0]
command = generate_redis_protocol(
'BF.ADD', BREACHED_PASSWORDS_BF_KEY, hashval)
command = generate_redis_protocol("BF.ADD", BREACHED_PASSWORDS_BF_KEY, hashval)
redis_pipe.stdin.write(command) redis_pipe.stdin.write(command)
if count % update_interval == 0: if count % update_interval == 0:
@ -173,5 +166,5 @@ def addhashes(filename: str) -> None:
progress_bar.render_finish() progress_bar.render_finish()
if __name__ == '__main__':
if __name__ == "__main__":
cli() cli()

51
tildes/scripts/clean_private_data.py

@ -33,22 +33,17 @@ def clean_all_data(config_path: str) -> None:
cleaner.clean_all() cleaner.clean_all()
class DataCleaner():
class DataCleaner:
"""Container class for all methods related to cleaning up old data.""" """Container class for all methods related to cleaning up old data."""
def __init__(
self,
db_session: Session,
retention_period: timedelta,
) -> None:
def __init__(self, db_session: Session, retention_period: timedelta) -> None:
"""Create a new DataCleaner.""" """Create a new DataCleaner."""
self.db_session = db_session self.db_session = db_session
self.retention_cutoff = datetime.now() - retention_period self.retention_cutoff = datetime.now() - retention_period
def clean_all(self) -> None: def clean_all(self) -> None:
"""Call all the cleanup functions.""" """Call all the cleanup functions."""
logging.info(
f'Cleaning up all data (retention cutoff {self.retention_cutoff})')
logging.info(f"Cleaning up all data (retention cutoff {self.retention_cutoff})")
self.delete_old_log_entries() self.delete_old_log_entries()
self.delete_old_topic_visits() self.delete_old_topic_visits()
@ -68,7 +63,7 @@ class DataCleaner():
.delete(synchronize_session=False) .delete(synchronize_session=False)
) )
self.db_session.commit() self.db_session.commit()
logging.info(f'Deleted {deleted} old log entries.')
logging.info(f"Deleted {deleted} old log entries.")
def delete_old_topic_visits(self) -> None: def delete_old_topic_visits(self) -> None:
"""Delete all topic visits older than the retention cutoff.""" """Delete all topic visits older than the retention cutoff."""
@ -78,7 +73,7 @@ class DataCleaner():
.delete(synchronize_session=False) .delete(synchronize_session=False)
) )
self.db_session.commit() self.db_session.commit()
logging.info(f'Deleted {deleted} old topic visits.')
logging.info(f"Deleted {deleted} old topic visits.")
def clean_old_deleted_comments(self) -> None: def clean_old_deleted_comments(self) -> None:
"""Clean the data of old deleted comments. """Clean the data of old deleted comments.
@ -92,14 +87,13 @@ class DataCleaner():
Comment.deleted_time <= self.retention_cutoff, # type: ignore Comment.deleted_time <= self.retention_cutoff, # type: ignore
Comment.user_id != 0, Comment.user_id != 0,
) )
.update({
'user_id': 0,
'markdown': '',
'rendered_html': '',
}, synchronize_session=False)
.update(
{"user_id": 0, "markdown": "", "rendered_html": ""},
synchronize_session=False,
)
) )
self.db_session.commit() self.db_session.commit()
logging.info(f'Cleaned {updated} old deleted comments.')
logging.info(f"Cleaned {updated} old deleted comments.")
def clean_old_deleted_topics(self) -> None: def clean_old_deleted_topics(self) -> None:
"""Clean the data of old deleted topics. """Clean the data of old deleted topics.
@ -113,16 +107,19 @@ class DataCleaner():
Topic.deleted_time <= self.retention_cutoff, # type: ignore Topic.deleted_time <= self.retention_cutoff, # type: ignore
Topic.user_id != 0, Topic.user_id != 0,
) )
.update({
'user_id': 0,
'title': '',
'topic_type': 'TEXT',
'markdown': None,
'rendered_html': None,
'link': None,
'content_metadata': None,
'_tags': [],
}, synchronize_session=False)
.update(
{
"user_id": 0,
"title": "",
"topic_type": "TEXT",
"markdown": None,
"rendered_html": None,
"link": None,
"content_metadata": None,
"_tags": [],
},
synchronize_session=False,
)
) )
self.db_session.commit() self.db_session.commit()
logging.info(f'Cleaned {updated} old deleted topics.')
logging.info(f"Cleaned {updated} old deleted topics.")

50
tildes/scripts/initialize_db.py

@ -15,27 +15,24 @@ from tildes.models.log import Log
from tildes.models.user import User from tildes.models.user import User
def initialize_db(
config_path: str,
alembic_config_path: Optional[str] = None,
) -> None:
def initialize_db(config_path: str, alembic_config_path: Optional[str] = None) -> None:
"""Load the app config and create the database tables.""" """Load the app config and create the database tables."""
db_session = get_session_from_config(config_path) db_session = get_session_from_config(config_path)
engine = db_session.bind engine = db_session.bind
create_tables(engine) create_tables(engine)
run_sql_scripts_in_dir('sql/init/', engine)
run_sql_scripts_in_dir("sql/init/", engine)
# if an Alembic config file wasn't specified, assume it's alembic.ini in # if an Alembic config file wasn't specified, assume it's alembic.ini in
# the same directory # the same directory
if not alembic_config_path: if not alembic_config_path:
path = os.path.split(config_path)[0] path = os.path.split(config_path)[0]
alembic_config_path = os.path.join(path, 'alembic.ini')
alembic_config_path = os.path.join(path, "alembic.ini")
# mark current Alembic revision in db so migrations start from this point # mark current Alembic revision in db so migrations start from this point
alembic_cfg = Config(alembic_config_path) alembic_cfg = Config(alembic_config_path)
command.stamp(alembic_cfg, 'head')
command.stamp(alembic_cfg, "head")
def create_tables(connectable: Connectable) -> None: def create_tables(connectable: Connectable) -> None:
@ -44,7 +41,8 @@ def create_tables(connectable: Connectable) -> None:
excluded_tables = Log.INHERITED_TABLES excluded_tables = Log.INHERITED_TABLES
tables = [ tables = [
table for table in DatabaseModel.metadata.tables.values()
table
for table in DatabaseModel.metadata.tables.values()
if table.name not in excluded_tables if table.name not in excluded_tables
] ]
DatabaseModel.metadata.create_all(connectable, tables=tables) DatabaseModel.metadata.create_all(connectable, tables=tables)
@ -53,29 +51,31 @@ def create_tables(connectable: Connectable) -> None:
def run_sql_scripts_in_dir(path: str, engine: Engine) -> None: def run_sql_scripts_in_dir(path: str, engine: Engine) -> None:
"""Run all sql scripts in a directory.""" """Run all sql scripts in a directory."""
for root, _, files in os.walk(path): for root, _, files in os.walk(path):
sql_files = [
filename for filename in files
if filename.endswith('.sql')
]
sql_files = [filename for filename in files if filename.endswith(".sql")]
for sql_file in sql_files: for sql_file in sql_files:
subprocess.call([
'psql',
'-U', engine.url.username,
'-f', os.path.join(root, sql_file),
engine.url.database,
])
subprocess.call(
[
"psql",
"-U",
engine.url.username,
"-f",
os.path.join(root, sql_file),
engine.url.database,
]
)
def insert_dev_data(config_path: str) -> None: def insert_dev_data(config_path: str) -> None:
"""Load the app config and insert some "starter" data for a dev version.""" """Load the app config and insert some "starter" data for a dev version."""
session = get_session_from_config(config_path) session = get_session_from_config(config_path)
session.add_all([
User('TestUser', 'password'),
Group(
'testing',
'An automatically created group to use for testing purposes',
),
])
session.add_all(
[
User("TestUser", "password"),
Group(
"testing", "An automatically created group to use for testing purposes"
),
]
)
session.commit() session.commit()

6
tildes/setup.py

@ -4,11 +4,11 @@ from setuptools import find_packages, setup
setup( setup(
name='tildes',
version='0.1',
name="tildes",
version="0.1",
packages=find_packages(), packages=find_packages(),
entry_points=""" entry_points="""
[paste.app_factory] [paste.app_factory]
main = tildes:main main = tildes:main
"""
""",
) )

71
tildes/tests/conftest.py

@ -18,7 +18,7 @@ from tildes.models.user import User
# include the fixtures defined in fixtures.py # include the fixtures defined in fixtures.py
pytest_plugins = ['tests.fixtures']
pytest_plugins = ["tests.fixtures"]
class NestedSessionWrapper(Session): class NestedSessionWrapper(Session):
@ -40,25 +40,25 @@ class NestedSessionWrapper(Session):
super().rollback() super().rollback()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def pyramid_config(): def pyramid_config():
"""Set up the Pyramid environment.""" """Set up the Pyramid environment."""
settings = get_appsettings('development.ini')
settings = get_appsettings("development.ini")
config = testing.setUp(settings=settings) config = testing.setUp(settings=settings)
config.include('tildes.auth')
config.include("tildes.auth")
yield config yield config
testing.tearDown() testing.tearDown()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def overall_db_session(pyramid_config): def overall_db_session(pyramid_config):
"""Handle setup and teardown of test database for testing session.""" """Handle setup and teardown of test database for testing session."""
# read the database url from the pyramid INI file, and replace the db name # read the database url from the pyramid INI file, and replace the db name
sqlalchemy_url = pyramid_config.registry.settings['sqlalchemy.url']
sqlalchemy_url = pyramid_config.registry.settings["sqlalchemy.url"]
parsed_url = make_url(sqlalchemy_url) parsed_url = make_url(sqlalchemy_url)
parsed_url.database = 'tildes_test'
parsed_url.database = "tildes_test"
engine = create_engine(parsed_url) engine = create_engine(parsed_url)
session_factory = sessionmaker(bind=engine) session_factory = sessionmaker(bind=engine)
@ -69,12 +69,9 @@ def overall_db_session(pyramid_config):
# SQL init scripts need to be executed "manually" instead of using psql # SQL init scripts need to be executed "manually" instead of using psql
# like the normal database init process does, since the tables only exist # like the normal database init process does, since the tables only exist
# inside this transaction # inside this transaction
init_scripts_dir = 'sql/init/'
init_scripts_dir = "sql/init/"
for root, _, files in os.walk(init_scripts_dir): for root, _, files in os.walk(init_scripts_dir):
sql_files = [
filename for filename in files
if filename.endswith('.sql')
]
sql_files = [filename for filename in files if filename.endswith(".sql")]
for sql_file in sql_files: for sql_file in sql_files:
with open(os.path.join(root, sql_file)) as current_file: with open(os.path.join(root, sql_file)) as current_file:
session.execute(current_file.read()) session.execute(current_file.read())
@ -90,7 +87,7 @@ def overall_db_session(pyramid_config):
session.rollback() session.rollback()
@fixture(scope='session')
@fixture(scope="session")
def sdb(overall_db_session): def sdb(overall_db_session):
"""Testing-session-level db session with a nested transaction.""" """Testing-session-level db session with a nested transaction."""
overall_db_session.begin_nested() overall_db_session.begin_nested()
@ -100,7 +97,7 @@ def sdb(overall_db_session):
overall_db_session.rollback_all_nested() overall_db_session.rollback_all_nested()
@fixture(scope='function')
@fixture(scope="function")
def db(overall_db_session): def db(overall_db_session):
"""Function-level db session with a nested transaction.""" """Function-level db session with a nested transaction."""
overall_db_session.begin_nested() overall_db_session.begin_nested()
@ -110,25 +107,23 @@ def db(overall_db_session):
overall_db_session.rollback_all_nested() overall_db_session.rollback_all_nested()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def overall_redis_session(): def overall_redis_session():
"""Create a session-level connection to a temporary redis server.""" """Create a session-level connection to a temporary redis server."""
# list of redis modules that need to be loaded (would be much nicer to do # list of redis modules that need to be loaded (would be much nicer to do
# this automatically somehow, maybe reading from the real redis.conf?) # this automatically somehow, maybe reading from the real redis.conf?)
redis_modules = [
'/opt/redis-cell/libredis_cell.so',
]
redis_modules = ["/opt/redis-cell/libredis_cell.so"]
with RedisServer() as temp_redis_server: with RedisServer() as temp_redis_server:
redis = StrictRedis(**temp_redis_server.dsn()) redis = StrictRedis(**temp_redis_server.dsn())
for module in redis_modules: for module in redis_modules:
redis.execute_command('MODULE LOAD', module)
redis.execute_command("MODULE LOAD", module)
yield redis yield redis
@fixture(scope='function')
@fixture(scope="function")
def redis(overall_redis_session): def redis(overall_redis_session):
"""Create a function-level redis connection that wipes the db after use.""" """Create a function-level redis connection that wipes the db after use."""
yield overall_redis_session yield overall_redis_session
@ -136,47 +131,47 @@ def redis(overall_redis_session):
overall_redis_session.flushdb() overall_redis_session.flushdb()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def session_user(sdb): def session_user(sdb):
"""Create a user named 'SessionUser' in the db for test session.""" """Create a user named 'SessionUser' in the db for test session."""
# note that some tests may depend on this username/password having these # note that some tests may depend on this username/password having these
# specific values, so make sure to search for and update those tests if you # specific values, so make sure to search for and update those tests if you
# change the username or password for any reason # change the username or password for any reason
user = User('SessionUser', 'session user password')
user = User("SessionUser", "session user password")
sdb.add(user) sdb.add(user)
sdb.commit() sdb.commit()
yield user yield user
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def session_user2(sdb): def session_user2(sdb):
"""Create a second user named 'OtherUser' in the db for test session. """Create a second user named 'OtherUser' in the db for test session.
This is useful for cases where two different users are needed, such as This is useful for cases where two different users are needed, such as
when testing private messages. when testing private messages.
""" """
user = User('OtherUser', 'other user password')
user = User("OtherUser", "other user password")
sdb.add(user) sdb.add(user)
sdb.commit() sdb.commit()
yield user yield user
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def session_group(sdb): def session_group(sdb):
"""Create a group named 'sessiongroup' in the db for test session.""" """Create a group named 'sessiongroup' in the db for test session."""
group = Group('sessiongroup')
group = Group("sessiongroup")
sdb.add(group) sdb.add(group)
sdb.commit() sdb.commit()
yield group yield group
@fixture(scope='session')
@fixture(scope="session")
def base_app(overall_redis_session, sdb): def base_app(overall_redis_session, sdb):
"""Configure a base WSGI app that webtest can create TestApps based on.""" """Configure a base WSGI app that webtest can create TestApps based on."""
testing_app = get_app('development.ini')
testing_app = get_app("development.ini")
# replace the redis connection used by the redis-sessions library with a # replace the redis connection used by the redis-sessions library with a
# connection to the temporary server for this test session # connection to the temporary server for this test session
@ -185,38 +180,38 @@ def base_app(overall_redis_session, sdb):
def redis_factory(request): def redis_factory(request):
# pylint: disable=unused-argument # pylint: disable=unused-argument
return overall_redis_session return overall_redis_session
testing_app.app.registry['redis_connection_factory'] = redis_factory
testing_app.app.registry["redis_connection_factory"] = redis_factory
# replace the session factory function with one that will return the # replace the session factory function with one that will return the
# testing db session (inside a nested transaction) # testing db session (inside a nested transaction)
def session_factory(): def session_factory():
return sdb return sdb
testing_app.app.registry['db_session_factory'] = session_factory
testing_app.app.registry["db_session_factory"] = session_factory
yield testing_app yield testing_app
@fixture(scope='session')
@fixture(scope="session")
def webtest(base_app): def webtest(base_app):
"""Create a webtest TestApp and log in as the SessionUser account in it.""" """Create a webtest TestApp and log in as the SessionUser account in it."""
# create the TestApp - note that specifying wsgi.url_scheme is necessary # create the TestApp - note that specifying wsgi.url_scheme is necessary
# so that the secure cookies from the session library will work # so that the secure cookies from the session library will work
app = TestApp( app = TestApp(
base_app,
extra_environ={'wsgi.url_scheme': 'https'},
cookiejar=CookieJar(),
base_app, extra_environ={"wsgi.url_scheme": "https"}, cookiejar=CookieJar()
) )
# fetch the login page, fill in the form, and submit it (sets the cookie) # fetch the login page, fill in the form, and submit it (sets the cookie)
login_page = app.get('/login')
login_page.form['username'] = 'SessionUser'
login_page.form['password'] = 'session user password'
login_page = app.get("/login")
login_page.form["username"] = "SessionUser"
login_page.form["password"] = "session user password"
login_page.form.submit() login_page.form.submit()
yield app yield app
@fixture(scope='session')
@fixture(scope="session")
def webtest_loggedout(base_app): def webtest_loggedout(base_app):
"""Create a logged-out webtest TestApp (no cookies retained).""" """Create a logged-out webtest TestApp (no cookies retained)."""
yield TestApp(base_app) yield TestApp(base_app)

8
tildes/tests/fixtures.py

@ -8,7 +8,8 @@ from tildes.models.topic import Topic
def text_topic(db, session_group, session_user): def text_topic(db, session_group, session_user):
"""Create a text topic, delete it as teardown (including comments).""" """Create a text topic, delete it as teardown (including comments)."""
new_topic = Topic.create_text_topic( new_topic = Topic.create_text_topic(
session_group, session_user, 'A Text Topic', 'the text')
session_group, session_user, "A Text Topic", "the text"
)
db.add(new_topic) db.add(new_topic)
db.commit() db.commit()
@ -23,7 +24,8 @@ def text_topic(db, session_group, session_user):
def link_topic(db, session_group, session_user): def link_topic(db, session_group, session_user):
"""Create a link topic, delete it as teardown (including comments).""" """Create a link topic, delete it as teardown (including comments)."""
new_topic = Topic.create_link_topic( new_topic = Topic.create_link_topic(
session_group, session_user, 'A Link Topic', 'http://example.com')
session_group, session_user, "A Link Topic", "http://example.com"
)
db.add(new_topic) db.add(new_topic)
db.commit() db.commit()
@ -43,7 +45,7 @@ def topic(text_topic):
@fixture @fixture
def comment(db, session_user, topic): def comment(db, session_user, topic):
"""Create a comment in the database, delete it as teardown.""" """Create a comment in the database, delete it as teardown."""
new_comment = Comment(topic, session_user, 'A comment')
new_comment = Comment(topic, session_user, "A comment")
db.add(new_comment) db.add(new_comment)
db.commit() db.commit()

64
tildes/tests/test_comment.py

@ -1,81 +1,73 @@
from datetime import timedelta from datetime import timedelta
from freezegun import freeze_time from freezegun import freeze_time
from pyramid.security import (
Authenticated,
Everyone,
principals_allowed_by_permission,
)
from pyramid.security import Authenticated, Everyone, principals_allowed_by_permission
from tildes.enums import CommentSortOption from tildes.enums import CommentSortOption
from tildes.lib.datetime import utc_now from tildes.lib.datetime import utc_now
from tildes.models.comment import (
Comment,
CommentTree,
EDIT_GRACE_PERIOD,
)
from tildes.models.comment import Comment, CommentTree, EDIT_GRACE_PERIOD
from tildes.schemas.comment import CommentSchema from tildes.schemas.comment import CommentSchema
from tildes.schemas.fields import Markdown from tildes.schemas.fields import Markdown
def test_comment_creation_validates_schema(mocker, session_user, topic): def test_comment_creation_validates_schema(mocker, session_user, topic):
"""Ensure that comment creation goes through schema validation.""" """Ensure that comment creation goes through schema validation."""
mocker.spy(CommentSchema, 'load')
mocker.spy(CommentSchema, "load")
Comment(topic, session_user, 'A test comment')
Comment(topic, session_user, "A test comment")
call_args = CommentSchema.load.call_args[0] call_args = CommentSchema.load.call_args[0]
assert {'markdown': 'A test comment'} in call_args
assert {"markdown": "A test comment"} in call_args
def test_comment_creation_uses_markdown_field(mocker, session_user, topic): def test_comment_creation_uses_markdown_field(mocker, session_user, topic):
"""Ensure the Markdown field class is validating new comments.""" """Ensure the Markdown field class is validating new comments."""
mocker.spy(Markdown, '_validate')
mocker.spy(Markdown, "_validate")
Comment(topic, session_user, 'A test comment')
Comment(topic, session_user, "A test comment")
assert Markdown._validate.called assert Markdown._validate.called
def test_comment_edit_uses_markdown_field(mocker, comment): def test_comment_edit_uses_markdown_field(mocker, comment):
"""Ensure editing a comment is validated by the Markdown field class.""" """Ensure editing a comment is validated by the Markdown field class."""
mocker.spy(Markdown, '_validate')
mocker.spy(Markdown, "_validate")
comment.markdown = 'Some new text after edit'
comment.markdown = "Some new text after edit"
assert Markdown._validate.called assert Markdown._validate.called
def test_edit_markdown_updates_html(comment): def test_edit_markdown_updates_html(comment):
"""Ensure editing a comment works and the markdown and HTML update.""" """Ensure editing a comment works and the markdown and HTML update."""
comment.markdown = 'Updated comment'
assert 'Updated' in comment.markdown
assert 'Updated' in comment.rendered_html
comment.markdown = "Updated comment"
assert "Updated" in comment.markdown
assert "Updated" in comment.rendered_html
def test_comment_viewing_permission(comment): def test_comment_viewing_permission(comment):
"""Ensure that anyone can view a comment by default.""" """Ensure that anyone can view a comment by default."""
assert Everyone in principals_allowed_by_permission(comment, 'view')
assert Everyone in principals_allowed_by_permission(comment, "view")
def test_comment_editing_permission(comment): def test_comment_editing_permission(comment):
"""Ensure that only the comment's author can edit it.""" """Ensure that only the comment's author can edit it."""
principals = principals_allowed_by_permission(comment, 'edit')
principals = principals_allowed_by_permission(comment, "edit")
assert principals == {comment.user_id} assert principals == {comment.user_id}
def test_comment_deleting_permission(comment): def test_comment_deleting_permission(comment):
"""Ensure that only the comment's author can delete it.""" """Ensure that only the comment's author can delete it."""
principals = principals_allowed_by_permission(comment, 'delete')
principals = principals_allowed_by_permission(comment, "delete")
assert principals == {comment.user_id} assert principals == {comment.user_id}
def test_comment_replying_permission(comment): def test_comment_replying_permission(comment):
"""Ensure that any authenticated user can reply to a comment.""" """Ensure that any authenticated user can reply to a comment."""
assert Authenticated in principals_allowed_by_permission(comment, 'reply')
assert Authenticated in principals_allowed_by_permission(comment, "reply")
def test_comment_reply_locked_thread_permission(comment): def test_comment_reply_locked_thread_permission(comment):
"""Ensure that only admins can reply in locked threads.""" """Ensure that only admins can reply in locked threads."""
comment.topic.is_locked = True comment.topic.is_locked = True
assert principals_allowed_by_permission(comment, 'reply') == {'admin'}
assert principals_allowed_by_permission(comment, "reply") == {"admin"}
def test_deleted_comment_permissions_removed(comment): def test_deleted_comment_permissions_removed(comment):
@ -90,8 +82,8 @@ def test_deleted_comment_permissions_removed(comment):
def test_removed_comment_view_permission(comment): def test_removed_comment_view_permission(comment):
"""Ensure a removed comment can only be viewed by its author and admins.""" """Ensure a removed comment can only be viewed by its author and admins."""
comment.is_removed = True comment.is_removed = True
principals = principals_allowed_by_permission(comment, 'view')
assert principals == {'admin', comment.user_id}
principals = principals_allowed_by_permission(comment, "view")
assert principals == {"admin", comment.user_id}
def test_edit_grace_period(comment): def test_edit_grace_period(comment):
@ -100,7 +92,7 @@ def test_edit_grace_period(comment):
edit_time = comment.created_time + EDIT_GRACE_PERIOD - one_sec edit_time = comment.created_time + EDIT_GRACE_PERIOD - one_sec
with freeze_time(edit_time): with freeze_time(edit_time):
comment.markdown = 'some new markdown'
comment.markdown = "some new markdown"
assert not comment.last_edited_time assert not comment.last_edited_time
@ -111,7 +103,7 @@ def test_edit_after_grace_period(comment):
edit_time = comment.created_time + EDIT_GRACE_PERIOD + one_sec edit_time = comment.created_time + EDIT_GRACE_PERIOD + one_sec
with freeze_time(edit_time): with freeze_time(edit_time):
comment.markdown = 'some new markdown'
comment.markdown = "some new markdown"
assert comment.last_edited_time == utc_now() assert comment.last_edited_time == utc_now()
@ -123,7 +115,7 @@ def test_multiple_edits_update_time(comment):
for minutes in range(0, 4): for minutes in range(0, 4):
edit_time = initial_time + timedelta(minutes=minutes) edit_time = initial_time + timedelta(minutes=minutes)
with freeze_time(edit_time): with freeze_time(edit_time):
comment.markdown = f'edit #{minutes}'
comment.markdown = f"edit #{minutes}"
assert comment.last_edited_time == utc_now() assert comment.last_edited_time == utc_now()
@ -134,8 +126,8 @@ def test_comment_tree(db, topic, session_user):
sort = CommentSortOption.POSTED sort = CommentSortOption.POSTED
# add two root comments # add two root comments
root = Comment(topic, session_user, 'root')
root2 = Comment(topic, session_user, 'root2')
root = Comment(topic, session_user, "root")
root2 = Comment(topic, session_user, "root2")
all_comments.extend([root, root2]) all_comments.extend([root, root2])
db.add_all(all_comments) db.add_all(all_comments)
db.commit() db.commit()
@ -151,8 +143,8 @@ def test_comment_tree(db, topic, session_user):
assert tree == [root] assert tree == [root]
# add two replies to the remaining root comment # add two replies to the remaining root comment
child = Comment(topic, session_user, '1', parent_comment=root)
child2 = Comment(topic, session_user, '2', parent_comment=root)
child = Comment(topic, session_user, "1", parent_comment=root)
child2 = Comment(topic, session_user, "2", parent_comment=root)
all_comments.extend([child, child2]) all_comments.extend([child, child2])
db.add_all(all_comments) db.add_all(all_comments)
db.commit() db.commit()
@ -165,8 +157,8 @@ def test_comment_tree(db, topic, session_user):
assert child2.replies == [] assert child2.replies == []
# add two more replies to the second depth-1 comment # add two more replies to the second depth-1 comment
subchild = Comment(topic, session_user, '2a', parent_comment=child2)
subchild2 = Comment(topic, session_user, '2b', parent_comment=child2)
subchild = Comment(topic, session_user, "2a", parent_comment=child2)
subchild2 = Comment(topic, session_user, "2b", parent_comment=child2)
all_comments.extend([subchild, subchild2]) all_comments.extend([subchild, subchild2])
db.add_all(all_comments) db.add_all(all_comments)
db.commit() db.commit()

72
tildes/tests/test_comment_user_mentions.py

@ -3,10 +3,7 @@ from pytest import fixture
from sqlalchemy import and_ from sqlalchemy import and_
from tildes.enums import CommentNotificationType from tildes.enums import CommentNotificationType
from tildes.models.comment import (
Comment,
CommentNotification,
)
from tildes.models.comment import Comment, CommentNotification
from tildes.models.topic import Topic from tildes.models.topic import Topic
from tildes.models.user import User from tildes.models.user import User
@ -15,8 +12,8 @@ from tildes.models.user import User
def user_list(db): def user_list(db):
"""Create several users.""" """Create several users."""
users = [] users = []
for name in ['foo', 'bar', 'baz']:
user = User(name, 'password')
for name in ["foo", "bar", "baz"]:
user = User(name, "password")
users.append(user) users.append(user)
db.add(user) db.add(user)
db.commit() db.commit()
@ -30,44 +27,40 @@ def user_list(db):
def test_get_mentions_for_comment(db, user_list, comment): def test_get_mentions_for_comment(db, user_list, comment):
"""Test that notifications are generated and returned.""" """Test that notifications are generated and returned."""
comment.markdown = '@foo @bar. @baz!'
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
comment.markdown = "@foo @bar. @baz!"
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 3 assert len(mentions) == 3
for index, user in enumerate(user_list): for index, user in enumerate(user_list):
assert mentions[index].user == user assert mentions[index].user == user
def test_mention_filtering_parent_comment(
mocker, db, topic, user_list):
def test_mention_filtering_parent_comment(mocker, db, topic, user_list):
"""Test notification filtering for parent comments.""" """Test notification filtering for parent comments."""
parent_comment = Comment(topic, user_list[0], 'Comment content.')
parent_comment = Comment(topic, user_list[0], "Comment content.")
parent_comment.user_id = user_list[0].user_id parent_comment.user_id = user_list[0].user_id
comment = mocker.Mock( comment = mocker.Mock(
user_id=user_list[1].user_id, user_id=user_list[1].user_id,
markdown=f'@{user_list[0].username}',
markdown=f"@{user_list[0].username}",
parent_comment=parent_comment, parent_comment=parent_comment,
) )
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert not mentions assert not mentions
def test_mention_filtering_self_mention(db, user_list, topic): def test_mention_filtering_self_mention(db, user_list, topic):
"""Test notification filtering for self-mentions.""" """Test notification filtering for self-mentions."""
comment = Comment(topic, user_list[0], f'@{user_list[0]}')
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
comment = Comment(topic, user_list[0], f"@{user_list[0]}")
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert not mentions assert not mentions
def test_mention_filtering_top_level(db, user_list, session_group): def test_mention_filtering_top_level(db, user_list, session_group):
"""Test notification filtering for top-level comments.""" """Test notification filtering for top-level comments."""
topic = Topic.create_text_topic( topic = Topic.create_text_topic(
session_group, user_list[0], 'Some title', 'some text')
comment = Comment(topic, user_list[1], f'@{user_list[0].username}')
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
session_group, user_list[0], "Some title", "some text"
)
comment = Comment(topic, user_list[1], f"@{user_list[0].username}")
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert not mentions assert not mentions
@ -82,36 +75,35 @@ def test_prevent_duplicate_notifications(db, user_list, topic):
4. The comment is deleted. 4. The comment is deleted.
""" """
# 1 # 1
comment = Comment(topic, user_list[0], f'@{user_list[1].username}')
comment = Comment(topic, user_list[0], f"@{user_list[1].username}")
db.add(comment) db.add(comment)
db.commit() db.commit()
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 1 assert len(mentions) == 1
assert mentions[0].user == user_list[1] assert mentions[0].user == user_list[1]
db.add_all(mentions) db.add_all(mentions)
db.commit() db.commit()
# 2 # 2
comment.markdown = f'@{user_list[2].username}'
comment.markdown = f"@{user_list[2].username}"
db.commit() db.commit()
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 1 assert len(mentions) == 1
to_delete, to_add = CommentNotification.prevent_duplicate_notifications( to_delete, to_add = CommentNotification.prevent_duplicate_notifications(
db, comment, mentions)
db, comment, mentions
)
assert len(to_delete) == 1 assert len(to_delete) == 1
assert mentions == to_add assert mentions == to_add
assert to_delete[0].user.username == user_list[1].username assert to_delete[0].user.username == user_list[1].username
# 3 # 3
comment.markdown = f'@{user_list[1].username} @{user_list[2].username}'
comment.markdown = f"@{user_list[1].username} @{user_list[2].username}"
db.commit() db.commit()
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 2 assert len(mentions) == 2
to_delete, to_add = CommentNotification.prevent_duplicate_notifications( to_delete, to_add = CommentNotification.prevent_duplicate_notifications(
db, comment, mentions)
db, comment, mentions
)
assert not to_delete assert not to_delete
assert len(to_add) == 1 assert len(to_add) == 1
@ -120,9 +112,13 @@ def test_prevent_duplicate_notifications(db, user_list, topic):
db.commit() db.commit()
notifications = ( notifications = (
db.query(CommentNotification.user_id) db.query(CommentNotification.user_id)
.filter(and_(
CommentNotification.comment_id == comment.comment_id,
CommentNotification.notification_type ==
CommentNotificationType.USER_MENTION,
)).all())
.filter(
and_(
CommentNotification.comment_id == comment.comment_id,
CommentNotification.notification_type
== CommentNotificationType.USER_MENTION,
)
)
.all()
)
assert not notifications assert not notifications

14
tildes/tests/test_datetime.py

@ -20,40 +20,40 @@ def test_utc_now_accurate():
def test_descriptive_timedelta_basic(): def test_descriptive_timedelta_basic():
"""Ensure a simple descriptive timedelta works correctly.""" """Ensure a simple descriptive timedelta works correctly."""
test_time = utc_now() - timedelta(hours=3) test_time = utc_now() - timedelta(hours=3)
assert descriptive_timedelta(test_time) == '3 hours ago'
assert descriptive_timedelta(test_time) == "3 hours ago"
def test_more_precise_longer_descriptive_timedelta(): def test_more_precise_longer_descriptive_timedelta():
"""Ensure a longer time period gets the extra precision level.""" """Ensure a longer time period gets the extra precision level."""
test_time = utc_now() - timedelta(days=2, hours=5) test_time = utc_now() - timedelta(days=2, hours=5)
assert descriptive_timedelta(test_time) == '2 days, 5 hours ago'
assert descriptive_timedelta(test_time) == "2 days, 5 hours ago"
def test_no_small_precision_descriptive_timedelta(): def test_no_small_precision_descriptive_timedelta():
"""Ensure the extra precision doesn't apply to small units.""" """Ensure the extra precision doesn't apply to small units."""
test_time = utc_now() - timedelta(days=6, minutes=10) test_time = utc_now() - timedelta(days=6, minutes=10)
assert descriptive_timedelta(test_time) == '6 days ago'
assert descriptive_timedelta(test_time) == "6 days ago"
def test_single_precision_below_an_hour(): def test_single_precision_below_an_hour():
"""Ensure times under an hour only have one precision level.""" """Ensure times under an hour only have one precision level."""
test_time = utc_now() - timedelta(minutes=59, seconds=59) test_time = utc_now() - timedelta(minutes=59, seconds=59)
assert descriptive_timedelta(test_time) == '59 minutes ago'
assert descriptive_timedelta(test_time) == "59 minutes ago"
def test_more_precision_above_an_hour(): def test_more_precision_above_an_hour():
"""Ensure the second precision level gets added just above an hour.""" """Ensure the second precision level gets added just above an hour."""
test_time = utc_now() - timedelta(hours=1, minutes=1) test_time = utc_now() - timedelta(hours=1, minutes=1)
assert descriptive_timedelta(test_time) == '1 hour, 1 minute ago'
assert descriptive_timedelta(test_time) == "1 hour, 1 minute ago"
def test_subsecond_descriptive_timedelta(): def test_subsecond_descriptive_timedelta():
"""Ensure time less than a second returns the special phrase.""" """Ensure time less than a second returns the special phrase."""
test_time = utc_now() - timedelta(microseconds=100) test_time = utc_now() - timedelta(microseconds=100)
assert descriptive_timedelta(test_time) == 'a moment ago'
assert descriptive_timedelta(test_time) == "a moment ago"
def test_above_second_descriptive_timedelta(): def test_above_second_descriptive_timedelta():
"""Ensure it starts describing time in seconds above 1 second.""" """Ensure it starts describing time in seconds above 1 second."""
test_time = utc_now() - timedelta(seconds=1, microseconds=100) test_time = utc_now() - timedelta(seconds=1, microseconds=100)
assert descriptive_timedelta(test_time) == '1 second ago'
assert descriptive_timedelta(test_time) == "1 second ago"

41
tildes/tests/test_group.py

@ -3,46 +3,43 @@ from sqlalchemy.exc import IntegrityError
from tildes.models.group import Group from tildes.models.group import Group
from tildes.schemas.fields import Ltree, SimpleString from tildes.schemas.fields import Ltree, SimpleString
from tildes.schemas.group import (
GroupSchema,
is_valid_group_path,
)
from tildes.schemas.group import GroupSchema, is_valid_group_path
def test_empty_path_invalid(): def test_empty_path_invalid():
"""Ensure empty group path is invalid.""" """Ensure empty group path is invalid."""
assert not is_valid_group_path('')
assert not is_valid_group_path("")
def test_typical_path_valid(): def test_typical_path_valid():
"""Ensure a "normal-looking" group path is valid.""" """Ensure a "normal-looking" group path is valid."""
assert is_valid_group_path('games.video.nintendo_3ds')
assert is_valid_group_path("games.video.nintendo_3ds")
def test_start_with_underscore(): def test_start_with_underscore():
"""Ensure you can't start a path with an underscore.""" """Ensure you can't start a path with an underscore."""
assert not is_valid_group_path('_x.y.z')
assert not is_valid_group_path("_x.y.z")
def test_middle_element_start_with_underscore(): def test_middle_element_start_with_underscore():
"""Ensure a middle path element can't start with an underscore.""" """Ensure a middle path element can't start with an underscore."""
assert not is_valid_group_path('x._y.z')
assert not is_valid_group_path("x._y.z")
def test_end_with_underscore(): def test_end_with_underscore():
"""Ensure you can't end a path with an underscore.""" """Ensure you can't end a path with an underscore."""
assert not is_valid_group_path('x.y.z_')
assert not is_valid_group_path("x.y.z_")
def test_middle_element_end_with_underscore(): def test_middle_element_end_with_underscore():
"""Ensure a middle path element can't end with an underscore.""" """Ensure a middle path element can't end with an underscore."""
assert not is_valid_group_path('x.y_.z')
assert not is_valid_group_path("x.y_.z")
def test_uppercase_letters_invalid(): def test_uppercase_letters_invalid():
"""Ensure a group path can't contain uppercase chars.""" """Ensure a group path can't contain uppercase chars."""
assert is_valid_group_path('comp.lang.c')
assert not is_valid_group_path('comp.lang.C')
assert is_valid_group_path("comp.lang.c")
assert not is_valid_group_path("comp.lang.C")
def test_paths_with_invalid_characters(): def test_paths_with_invalid_characters():
@ -50,34 +47,34 @@ def test_paths_with_invalid_characters():
invalid_chars = ' ~!@#$%^&*()+={}[]|\\:;"<>,?/' invalid_chars = ' ~!@#$%^&*()+={}[]|\\:;"<>,?/'
for char in invalid_chars: for char in invalid_chars:
path = f'abc{char}xyz'
path = f"abc{char}xyz"
assert not is_valid_group_path(path) assert not is_valid_group_path(path)
def test_paths_with_unicode_characters(): def test_paths_with_unicode_characters():
"""Ensure that paths can't use unicode chars (not comprehensive).""" """Ensure that paths can't use unicode chars (not comprehensive)."""
for path in ('games.pokémon', 'ポケモン', 'bites.møøse'):
for path in ("games.pokémon", "ポケモン", "bites.møøse"):
assert not is_valid_group_path(path) assert not is_valid_group_path(path)
def test_creation_validates_schema(mocker): def test_creation_validates_schema(mocker):
"""Ensure that group creation goes through expected validation.""" """Ensure that group creation goes through expected validation."""
mocker.spy(GroupSchema, 'load')
mocker.spy(Ltree, '_validate')
mocker.spy(SimpleString, '_validate')
mocker.spy(GroupSchema, "load")
mocker.spy(Ltree, "_validate")
mocker.spy(SimpleString, "_validate")
Group('testing', 'with a short description')
Group("testing", "with a short description")
assert GroupSchema.load.called assert GroupSchema.load.called
assert Ltree._validate.call_args[0][1] == 'testing'
assert SimpleString._validate.call_args[0][1] == 'with a short description'
assert Ltree._validate.call_args[0][1] == "testing"
assert SimpleString._validate.call_args[0][1] == "with a short description"
def test_duplicate_group(db): def test_duplicate_group(db):
"""Ensure groups with duplicate paths can't be created.""" """Ensure groups with duplicate paths can't be created."""
original = Group('twins')
original = Group("twins")
db.add(original) db.add(original)
duplicate = Group('twins')
duplicate = Group("twins")
db.add(duplicate) db.add(duplicate)
with raises(IntegrityError): with raises(IntegrityError):

10
tildes/tests/test_id.py

@ -7,12 +7,12 @@ from tildes.lib.id import id_to_id36, id36_to_id
def test_id_to_id36(): def test_id_to_id36():
"""Make sure an ID->ID36 conversion is correct.""" """Make sure an ID->ID36 conversion is correct."""
assert id_to_id36(571049189) == '9fzkdh'
assert id_to_id36(571049189) == "9fzkdh"
def test_id36_to_id(): def test_id36_to_id():
"""Make sure an ID36->ID conversion is correct.""" """Make sure an ID36->ID conversion is correct."""
assert id36_to_id('x48l4z') == 2002502915
assert id36_to_id("x48l4z") == 2002502915
def test_reversed_conversion_from_id(): def test_reversed_conversion_from_id():
@ -23,7 +23,7 @@ def test_reversed_conversion_from_id():
def test_reversed_conversion_from_id36(): def test_reversed_conversion_from_id36():
"""Make sure an ID36->ID->ID36 conversion returns to original value.""" """Make sure an ID36->ID->ID36 conversion returns to original value."""
original = 'h2l4pe'
original = "h2l4pe"
assert id_to_id36(id36_to_id(original)) == original assert id_to_id36(id36_to_id(original)) == original
@ -36,7 +36,7 @@ def test_zero_id_conversion_blocked():
def test_zero_id36_conversion_blocked(): def test_zero_id36_conversion_blocked():
"""Ensure the ID36 conversion function doesn't accept zero.""" """Ensure the ID36 conversion function doesn't accept zero."""
with raises(ValueError): with raises(ValueError):
id36_to_id('0')
id36_to_id("0")
def test_negative_id_conversion_blocked(): def test_negative_id_conversion_blocked():
@ -48,4 +48,4 @@ def test_negative_id_conversion_blocked():
def test_negative_id36_conversion_blocked(): def test_negative_id36_conversion_blocked():
"""Ensure the ID36 conversion function doesn't accept negative numbers.""" """Ensure the ID36 conversion function doesn't accept negative numbers."""
with raises(ValueError): with raises(ValueError):
id36_to_id('-1')
id36_to_id("-1")

167
tildes/tests/test_markdown.py

@ -3,29 +3,29 @@ from tildes.lib.markdown import convert_markdown_to_safe_html
def test_script_tag_escaped(): def test_script_tag_escaped():
"""Ensure that a <script> tag can't get through.""" """Ensure that a <script> tag can't get through."""
markdown = '<script>alert()</script>'
markdown = "<script>alert()</script>"
sanitized = convert_markdown_to_safe_html(markdown) sanitized = convert_markdown_to_safe_html(markdown)
assert '<script>' not in sanitized
assert "<script>" not in sanitized
def test_basic_markdown_unescaped(): def test_basic_markdown_unescaped():
"""Test that some common markdown comes through without escaping.""" """Test that some common markdown comes through without escaping."""
markdown = ( markdown = (
"# Here's a header.\n\n" "# Here's a header.\n\n"
'This chunk of text has **some bold** and *some italics* in it.\n\n'
'A separator will be below this paragraph.\n\n'
'---\n\n'
'* An unordered list item\n'
'* Another list item\n\n'
'> This should be a quote.\n\n'
' And a code block\n\n'
'Also some `inline code` and [a link](http://example.com).\n\n'
'And a manual break \nbetween lines.\n\n'
"This chunk of text has **some bold** and *some italics* in it.\n\n"
"A separator will be below this paragraph.\n\n"
"---\n\n"
"* An unordered list item\n"
"* Another list item\n\n"
"> This should be a quote.\n\n"
" And a code block\n\n"
"Also some `inline code` and [a link](http://example.com).\n\n"
"And a manual break \nbetween lines.\n\n"
) )
sanitized = convert_markdown_to_safe_html(markdown) sanitized = convert_markdown_to_safe_html(markdown)
assert '&lt;' not in sanitized
assert "&lt;" not in sanitized
def test_strikethrough(): def test_strikethrough():
@ -33,23 +33,23 @@ def test_strikethrough():
markdown = "This ~should not~ should work" markdown = "This ~should not~ should work"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<del>' in processed
assert '<a' not in processed
assert "<del>" in processed
assert "<a" not in processed
def test_table(): def test_table():
"""Ensure table markdown works.""" """Ensure table markdown works."""
markdown = ( markdown = (
'|Header 1|Header 2|Header 3|\n'
'|--------|-------:|:------:|\n'
'|1 - 1 |1 - 2 |1 - 3 |\n'
'|2 - 1|2 - 2|2 - 3|\n'
"|Header 1|Header 2|Header 3|\n"
"|--------|-------:|:------:|\n"
"|1 - 1 |1 - 2 |1 - 3 |\n"
"|2 - 1|2 - 2|2 - 3|\n"
) )
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<table>' in processed
assert processed.count('<tr') == 3
assert processed.count('<td') == 6
assert "<table>" in processed
assert processed.count("<tr") == 3
assert processed.count("<td") == 6
assert 'align="right"' in processed assert 'align="right"' in processed
assert 'align="center"' in processed assert 'align="center"' in processed
@ -57,35 +57,35 @@ def test_table():
def test_deliberate_ordered_list(): def test_deliberate_ordered_list():
"""Ensure a "deliberate" ordered list works.""" """Ensure a "deliberate" ordered list works."""
markdown = ( markdown = (
'My first line of text.\n\n'
'1. I want\n'
'2. An ordered\n'
'3. List here\n\n'
'A final line.'
"My first line of text.\n\n"
"1. I want\n"
"2. An ordered\n"
"3. List here\n\n"
"A final line."
) )
html = convert_markdown_to_safe_html(markdown) html = convert_markdown_to_safe_html(markdown)
assert '<ol>' in html
assert "<ol>" in html
def test_accidental_ordered_list(): def test_accidental_ordered_list():
"""Ensure a common "accidental" ordered list gets escaped.""" """Ensure a common "accidental" ordered list gets escaped."""
markdown = ( markdown = (
'What year did this happen?\n\n'
'1975. It was a long time ago.\n\n'
'But I remember it like it was yesterday.'
"What year did this happen?\n\n"
"1975. It was a long time ago.\n\n"
"But I remember it like it was yesterday."
) )
html = convert_markdown_to_safe_html(markdown) html = convert_markdown_to_safe_html(markdown)
assert '<ol' not in html
assert "<ol" not in html
def test_existing_newline_not_doubled(): def test_existing_newline_not_doubled():
"""Ensure that the standard markdown line break doesn't result in two.""" """Ensure that the standard markdown line break doesn't result in two."""
markdown = 'A deliberate line \nbreak'
markdown = "A deliberate line \nbreak"
html = convert_markdown_to_safe_html(markdown) html = convert_markdown_to_safe_html(markdown)
assert html.count('<br') == 1
assert html.count("<br") == 1
def test_newline_creates_br(): def test_newline_creates_br():
@ -93,36 +93,31 @@ def test_newline_creates_br():
markdown = "This wouldn't\nnormally work" markdown = "This wouldn't\nnormally work"
html = convert_markdown_to_safe_html(markdown) html = convert_markdown_to_safe_html(markdown)
assert '<br>' in html
assert "<br>" in html
def test_multiple_newlines(): def test_multiple_newlines():
"""Ensure markdown with multiple newlines has expected result.""" """Ensure markdown with multiple newlines has expected result."""
lines = ["One.", "Two.", "Three.", "Four.", "Five."] lines = ["One.", "Two.", "Three.", "Four.", "Five."]
markdown = '\n'.join(lines)
markdown = "\n".join(lines)
html = convert_markdown_to_safe_html(markdown) html = convert_markdown_to_safe_html(markdown)
assert html.count('<br') == len(lines) - 1
assert html.count("<br") == len(lines) - 1
assert all(line in html for line in lines) assert all(line in html for line in lines)
def test_newline_in_code_block(): def test_newline_in_code_block():
"""Ensure newlines in code blocks don't add a <br>.""" """Ensure newlines in code blocks don't add a <br>."""
markdown = (
'```\n'
'def testing_for_newlines():\n'
' pass\n'
'```\n'
)
markdown = "```\ndef testing_for_newlines():\n pass\n```\n"
html = convert_markdown_to_safe_html(markdown) html = convert_markdown_to_safe_html(markdown)
assert '<br' not in html
assert "<br" not in html
def test_http_link_linkified(): def test_http_link_linkified():
"""Ensure that writing an http url results in a link.""" """Ensure that writing an http url results in a link."""
markdown = 'I like http://example.com as an example.'
markdown = "I like http://example.com as an example."
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com">' in processed assert '<a href="http://example.com">' in processed
@ -130,7 +125,7 @@ def test_http_link_linkified():
def test_https_link_linkified(): def test_https_link_linkified():
"""Ensure that writing an https url results in a link.""" """Ensure that writing an https url results in a link."""
markdown = 'Also, https://example.com should work.'
markdown = "Also, https://example.com should work."
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="https://example.com">' in processed assert '<a href="https://example.com">' in processed
@ -138,7 +133,7 @@ def test_https_link_linkified():
def test_bare_domain_linkified(): def test_bare_domain_linkified():
"""Ensure that a bare domain results in a link.""" """Ensure that a bare domain results in a link."""
markdown = 'I can just write example.com too.'
markdown = "I can just write example.com too."
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com">' in processed assert '<a href="http://example.com">' in processed
@ -146,7 +141,7 @@ def test_bare_domain_linkified():
def test_link_with_path_linkified(): def test_link_with_path_linkified():
"""Ensure a link with a path results in a link.""" """Ensure a link with a path results in a link."""
markdown = 'So http://example.com/a/b_c_d/e too?'
markdown = "So http://example.com/a/b_c_d/e too?"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com/a/b_c_d/e">' in processed assert '<a href="http://example.com/a/b_c_d/e">' in processed
@ -154,7 +149,7 @@ def test_link_with_path_linkified():
def test_link_with_query_string_linkified(): def test_link_with_query_string_linkified():
"""Ensure a link with a query string results in a link.""" """Ensure a link with a query string results in a link."""
markdown = 'Also http://example.com?something=true works?'
markdown = "Also http://example.com?something=true works?"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com?something=true">' in processed assert '<a href="http://example.com?something=true">' in processed
@ -162,21 +157,21 @@ def test_link_with_query_string_linkified():
def test_email_address_not_linkified(): def test_email_address_not_linkified():
"""Ensure that an email address does not get linkified.""" """Ensure that an email address does not get linkified."""
markdown = 'Please contact somebody@example.com about that.'
markdown = "Please contact somebody@example.com about that."
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_other_protocol_urls_not_linkified(): def test_other_protocol_urls_not_linkified():
"""Ensure some other protocols don't linkify (not comprehensive).""" """Ensure some other protocols don't linkify (not comprehensive)."""
protocols = ('data', 'ftp', 'irc', 'mailto', 'news', 'ssh', 'xmpp')
protocols = ("data", "ftp", "irc", "mailto", "news", "ssh", "xmpp")
for protocol in protocols: for protocol in protocols:
markdown = f'Testing {protocol}://example.com for linking'
markdown = f"Testing {protocol}://example.com for linking"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_html_attr_whitelist_violation(): def test_html_attr_whitelist_violation():
@ -187,23 +182,23 @@ def test_html_attr_whitelist_violation():
) )
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert processed == '<p>test link</p>\n'
assert processed == "<p>test link</p>\n"
def test_a_href_protocol_violation(): def test_a_href_protocol_violation():
"""Ensure link to other protocols removes the link (not comprehensive).""" """Ensure link to other protocols removes the link (not comprehensive)."""
protocols = ('data', 'ftp', 'irc', 'mailto', 'news', 'ssh', 'xmpp')
protocols = ("data", "ftp", "irc", "mailto", "news", "ssh", "xmpp")
for protocol in protocols: for protocol in protocols:
markdown = f'Testing [a link]({protocol}://example.com) for linking'
markdown = f"Testing [a link]({protocol}://example.com) for linking"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert 'href' not in processed
assert "href" not in processed
def test_group_reference_linkified(): def test_group_reference_linkified():
"""Ensure a simple group reference gets linkified.""" """Ensure a simple group reference gets linkified."""
markdown = 'Yeah, I saw that in ~books.fantasy yesterday.'
markdown = "Yeah, I saw that in ~books.fantasy yesterday."
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/~books.fantasy">' in processed assert '<a href="/~books.fantasy">' in processed
@ -212,14 +207,14 @@ def test_group_reference_linkified():
def test_multiple_group_references_linkified(): def test_multiple_group_references_linkified():
"""Ensure multiple group references are all linkified.""" """Ensure multiple group references are all linkified."""
markdown = ( markdown = (
'I like to keep an eye on:\n\n'
'* ~music.metal\n'
'* ~music.metal.progressive\n'
'* ~music.post_rock\n'
"I like to keep an eye on:\n\n"
"* ~music.metal\n"
"* ~music.metal.progressive\n"
"* ~music.post_rock\n"
) )
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert processed.count('<a') == 3
assert processed.count("<a") == 3
def test_invalid_group_reference_not_linkified(): def test_invalid_group_reference_not_linkified():
@ -230,20 +225,20 @@ def test_invalid_group_reference_not_linkified():
) )
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_approximately_tilde_not_linkified(): def test_approximately_tilde_not_linkified():
"""Ensure a tilde in front of a number doesn't linkify.""" """Ensure a tilde in front of a number doesn't linkify."""
markdown = 'Mix in ~2 cups of flour and ~1.5 tbsp of sugar.'
markdown = "Mix in ~2 cups of flour and ~1.5 tbsp of sugar."
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_uppercase_group_ref_links_correctly(): def test_uppercase_group_ref_links_correctly():
"""Ensure using uppercase in a group ref works but links correctly.""" """Ensure using uppercase in a group ref works but links correctly."""
markdown = 'That was in ~Music.Metal.Progressive'
markdown = "That was in ~Music.Metal.Progressive"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/~music.metal.progressive' in processed assert '<a href="/~music.metal.progressive' in processed
@ -260,29 +255,29 @@ def test_existing_link_group_ref_not_replaced():
def test_group_ref_inside_link_not_replaced(): def test_group_ref_inside_link_not_replaced():
"""Ensure a group ref inside a longer link doesn't get re-linked.""" """Ensure a group ref inside a longer link doesn't get re-linked."""
markdown = 'Found [this band from a ~music.punk post](http://whitelung.ca)'
markdown = "Found [this band from a ~music.punk post](http://whitelung.ca)"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert processed.count('<a') == 1
assert processed.count("<a") == 1
assert 'href="/~music.punk"' not in processed assert 'href="/~music.punk"' not in processed
def test_group_ref_inside_pre_ignored(): def test_group_ref_inside_pre_ignored():
"""Ensure a group ref inside a <pre> tag doesn't get linked.""" """Ensure a group ref inside a <pre> tag doesn't get linked."""
markdown = ( markdown = (
'```\n'
'# This is a code block\n'
'# I found this code on ~comp.lang.python\n'
'```\n'
"```\n"
"# This is a code block\n"
"# I found this code on ~comp.lang.python\n"
"```\n"
) )
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_group_ref_inside_other_tags_linkified(): def test_group_ref_inside_other_tags_linkified():
"""Ensure a group ref inside non-ignored tags gets linked.""" """Ensure a group ref inside non-ignored tags gets linked."""
markdown = '> Here is **a ~group.reference inside** other stuff'
markdown = "> Here is **a ~group.reference inside** other stuff"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/~group.reference">' in processed assert '<a href="/~group.reference">' in processed
@ -290,7 +285,7 @@ def test_group_ref_inside_other_tags_linkified():
def test_username_reference_linkified(): def test_username_reference_linkified():
"""Ensure a basic username reference gets linkified.""" """Ensure a basic username reference gets linkified."""
markdown = 'Hey @SomeUser, what do you think of this?'
markdown = "Hey @SomeUser, what do you think of this?"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">@SomeUser</a>' in processed assert '<a href="/user/SomeUser">@SomeUser</a>' in processed
@ -298,7 +293,7 @@ def test_username_reference_linkified():
def test_u_style_username_ref_linked(): def test_u_style_username_ref_linked():
"""Ensure a /u/username reference gets linkified.""" """Ensure a /u/username reference gets linkified."""
markdown = 'Hey /u/SomeUser, what do you think of this?'
markdown = "Hey /u/SomeUser, what do you think of this?"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">/u/SomeUser</a>' in processed assert '<a href="/user/SomeUser">/u/SomeUser</a>' in processed
@ -306,7 +301,7 @@ def test_u_style_username_ref_linked():
def test_u_alt_style_username_ref_linked(): def test_u_alt_style_username_ref_linked():
"""Ensure a u/username reference gets linkified.""" """Ensure a u/username reference gets linkified."""
markdown = 'Hey u/SomeUser, what do you think of this?'
markdown = "Hey u/SomeUser, what do you think of this?"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">u/SomeUser</a>' in processed assert '<a href="/user/SomeUser">u/SomeUser</a>' in processed
@ -314,15 +309,15 @@ def test_u_alt_style_username_ref_linked():
def test_accidental_u_alt_style_not_linked(): def test_accidental_u_alt_style_not_linked():
"""Ensure an "accidental" u/ usage won't get linked.""" """Ensure an "accidental" u/ usage won't get linked."""
markdown = 'I think those are caribou/reindeer.'
markdown = "I think those are caribou/reindeer."
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_username_and_group_refs_linked(): def test_username_and_group_refs_linked():
"""Ensure username and group references together get linkified.""" """Ensure username and group references together get linkified."""
markdown = '@SomeUser makes the best posts in ~some.group for sure'
markdown = "@SomeUser makes the best posts in ~some.group for sure"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">@SomeUser</a>' in processed assert '<a href="/user/SomeUser">@SomeUser</a>' in processed
@ -334,16 +329,12 @@ def test_invalid_username_not_linkified():
markdown = "You can't register a username like @_underscores_" markdown = "You can't register a username like @_underscores_"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_username_ref_inside_pre_ignored(): def test_username_ref_inside_pre_ignored():
"""Ensure a username ref inside a <pre> tag doesn't get linked.""" """Ensure a username ref inside a <pre> tag doesn't get linked."""
markdown = (
'```\n'
'# Code blatantly stolen from @HelpfulGuy on StackOverflow\n'
'```\n'
)
markdown = "```\n# Code blatantly stolen from @HelpfulGuy on StackOverflow\n```\n"
processed = convert_markdown_to_safe_html(markdown) processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed

22
tildes/tests/test_markdown_field.py

@ -12,53 +12,53 @@ class MarkdownFieldTestSchema(Schema):
def validate_string(string): def validate_string(string):
"""Validate a string against a standard Markdown field.""" """Validate a string against a standard Markdown field."""
MarkdownFieldTestSchema(strict=True).validate({'markdown': string})
MarkdownFieldTestSchema(strict=True).validate({"markdown": string})
def test_normal_text_validates(): def test_normal_text_validates():
"""Ensure some "normal-looking" markdown validates.""" """Ensure some "normal-looking" markdown validates."""
validate_string( validate_string(
"Here's some markdown.\n\n" "Here's some markdown.\n\n"
'It has **a bit of bold**, [a link](http://example.com)\n'
'> And `some code` in a blockquote'
"It has **a bit of bold**, [a link](http://example.com)\n"
"> And `some code` in a blockquote"
) )
def test_changing_max_length(): def test_changing_max_length():
"""Ensure changing the max_length argument works.""" """Ensure changing the max_length argument works."""
test_string = 'Just some text to try'
test_string = "Just some text to try"
# should normally validate # should normally validate
assert Markdown()._validate(test_string) is None assert Markdown()._validate(test_string) is None
# but fails if you set a too-short max_length # but fails if you set a too-short max_length
with raises(ValidationError): with raises(ValidationError):
Markdown(max_length=len(test_string)-1)._validate(test_string)
Markdown(max_length=len(test_string) - 1)._validate(test_string)
def test_extremely_long_string(): def test_extremely_long_string():
"""Ensure an extremely long string fails validation.""" """Ensure an extremely long string fails validation."""
with raises(ValidationError): with raises(ValidationError):
validate_string('A' * 100_000)
validate_string("A" * 100_000)
def test_empty_string(): def test_empty_string():
"""Ensure an empty string fails validation.""" """Ensure an empty string fails validation."""
with raises(ValidationError): with raises(ValidationError):
validate_string('')
validate_string("")
def test_all_whitespace_string(): def test_all_whitespace_string():
"""Ensure a string that's all whitespace chars fails validation.""" """Ensure a string that's all whitespace chars fails validation."""
with raises(ValidationError): with raises(ValidationError):
validate_string(' \n \n\r\n \t ')
validate_string(" \n \n\r\n \t ")
def test_carriage_returns_stripped(): def test_carriage_returns_stripped():
"""Ensure loading a value strips out carriage returns from the string.""" """Ensure loading a value strips out carriage returns from the string."""
test_string = 'some\r\nreturns\r\nin\nhere'
test_string = "some\r\nreturns\r\nin\nhere"
schema = MarkdownFieldTestSchema(strict=True) schema = MarkdownFieldTestSchema(strict=True)
result = schema.load({'markdown': test_string})
result = schema.load({"markdown": test_string})
assert '\r' not in result.data['markdown']
assert "\r" not in result.data["markdown"]

36
tildes/tests/test_messages.py

@ -4,17 +4,15 @@ from pytest import fixture, raises
from tildes.models.message import MessageConversation, MessageReply from tildes.models.message import MessageConversation, MessageReply
from tildes.models.user import User from tildes.models.user import User
from tildes.schemas.fields import Markdown, SimpleString from tildes.schemas.fields import Markdown, SimpleString
from tildes.schemas.message import (
MessageConversationSchema,
MessageReplySchema,
)
from tildes.schemas.message import MessageConversationSchema, MessageReplySchema
@fixture @fixture
def conversation(db, session_user, session_user2): def conversation(db, session_user, session_user2):
"""Create a message conversation and delete it as teardown.""" """Create a message conversation and delete it as teardown."""
new_conversation = MessageConversation( new_conversation = MessageConversation(
session_user, session_user2, 'Subject', 'Message')
session_user, session_user2, "Subject", "Message"
)
db.add(new_conversation) db.add(new_conversation)
db.commit() db.commit()
@ -30,31 +28,31 @@ def conversation(db, session_user, session_user2):
def test_message_conversation_validation(mocker, session_user, session_user2): def test_message_conversation_validation(mocker, session_user, session_user2):
"""Ensure a new message conversation goes through expected validation.""" """Ensure a new message conversation goes through expected validation."""
mocker.spy(MessageConversationSchema, 'load')
mocker.spy(SimpleString, '_validate')
mocker.spy(Markdown, '_validate')
mocker.spy(MessageConversationSchema, "load")
mocker.spy(SimpleString, "_validate")
mocker.spy(Markdown, "_validate")
MessageConversation(session_user, session_user2, 'Subject', 'Message')
MessageConversation(session_user, session_user2, "Subject", "Message")
assert MessageConversationSchema.load.called assert MessageConversationSchema.load.called
assert SimpleString._validate.call_args[0][1] == 'Subject'
assert Markdown._validate.call_args[0][1] == 'Message'
assert SimpleString._validate.call_args[0][1] == "Subject"
assert Markdown._validate.call_args[0][1] == "Message"
def test_message_reply_validation(mocker, conversation, session_user2): def test_message_reply_validation(mocker, conversation, session_user2):
"""Ensure a new message reply goes through expected validation.""" """Ensure a new message reply goes through expected validation."""
mocker.spy(MessageReplySchema, 'load')
mocker.spy(Markdown, '_validate')
mocker.spy(MessageReplySchema, "load")
mocker.spy(Markdown, "_validate")
MessageReply(conversation, session_user2, 'A new reply')
MessageReply(conversation, session_user2, "A new reply")
assert MessageReplySchema.load.called assert MessageReplySchema.load.called
assert Markdown._validate.call_args[0][1] == 'A new reply'
assert Markdown._validate.call_args[0][1] == "A new reply"
def test_conversation_viewing_permission(conversation): def test_conversation_viewing_permission(conversation):
"""Ensure only the two involved users can view a message conversation.""" """Ensure only the two involved users can view a message conversation."""
principals = principals_allowed_by_permission(conversation, 'view')
principals = principals_allowed_by_permission(conversation, "view")
users = {conversation.sender.user_id, conversation.recipient.user_id} users = {conversation.sender.user_id, conversation.recipient.user_id}
assert principals == users assert principals == users
@ -70,7 +68,7 @@ def test_conversation_other_user(conversation):
def test_conversation_other_user_invalid(conversation): def test_conversation_other_user_invalid(conversation):
"""Ensure that "other user" method fails if the user isn't involved.""" """Ensure that "other user" method fails if the user isn't involved."""
new_user = User('SomeOutsider', 'super amazing password')
new_user = User("SomeOutsider", "super amazing password")
with raises(ValueError): with raises(ValueError):
assert conversation.other_user(new_user) assert conversation.other_user(new_user)
@ -82,7 +80,7 @@ def test_replies_affect_num_replies(conversation, db):
# add replies and ensure each one increases the count # add replies and ensure each one increases the count
for num in range(5): for num in range(5):
new_reply = MessageReply(conversation, conversation.recipient, 'hi')
new_reply = MessageReply(conversation, conversation.recipient, "hi")
db.add(new_reply) db.add(new_reply)
db.commit() db.commit()
db.refresh(conversation) db.refresh(conversation)
@ -94,7 +92,7 @@ def test_replies_update_activity_time(conversation, db):
assert conversation.last_activity_time == conversation.created_time assert conversation.last_activity_time == conversation.created_time
for _ in range(5): for _ in range(5):
new_reply = MessageReply(conversation, conversation.recipient, 'hi')
new_reply = MessageReply(conversation, conversation.recipient, "hi")
db.add(new_reply) db.add(new_reply)
db.commit() db.commit()

2
tildes/tests/test_metrics.py

@ -9,4 +9,4 @@ def test_all_metric_names_prefixed():
# this is ugly, but seems to be the "generic" way to get the name # this is ugly, but seems to be the "generic" way to get the name
metric_name = metric.describe()[0].name metric_name = metric.describe()[0].name
assert metric_name.startswith('tildes_')
assert metric_name.startswith("tildes_")

36
tildes/tests/test_ratelimit.py

@ -24,13 +24,12 @@ def test_all_rate_limited_action_names_unique():
def test_action_with_all_types_disabled(): def test_action_with_all_types_disabled():
"""Ensure RateLimitedAction can't have both by_user and by_ip disabled.""" """Ensure RateLimitedAction can't have both by_user and by_ip disabled."""
with raises(ValueError): with raises(ValueError):
RateLimitedAction(
'test', timedelta(hours=1), 5, by_user=False, by_ip=False)
RateLimitedAction("test", timedelta(hours=1), 5, by_user=False, by_ip=False)
def test_check_by_user_id_disabled(): def test_check_by_user_id_disabled():
"""Ensure non-by_user RateLimitedAction can't be checked by user_id.""" """Ensure non-by_user RateLimitedAction can't be checked by user_id."""
action = RateLimitedAction('test', timedelta(hours=1), 5, by_user=False)
action = RateLimitedAction("test", timedelta(hours=1), 5, by_user=False)
with raises(RateLimitError): with raises(RateLimitError):
action.check_for_user_id(1) action.check_for_user_id(1)
@ -38,10 +37,10 @@ def test_check_by_user_id_disabled():
def test_check_by_ip_disabled(): def test_check_by_ip_disabled():
"""Ensure non-by_ip RateLimitedAction can't be checked by ip.""" """Ensure non-by_ip RateLimitedAction can't be checked by ip."""
action = RateLimitedAction('test', timedelta(hours=1), 5, by_ip=False)
action = RateLimitedAction("test", timedelta(hours=1), 5, by_ip=False)
with raises(RateLimitError): with raises(RateLimitError):
action.check_for_ip('123.123.123.123')
action.check_for_ip("123.123.123.123")
def test_simple_rate_limiting_by_user_id(redis): def test_simple_rate_limiting_by_user_id(redis):
@ -51,7 +50,8 @@ def test_simple_rate_limiting_by_user_id(redis):
# define an action with max_burst equal to the full limit # define an action with max_burst equal to the full limit
action = RateLimitedAction( action = RateLimitedAction(
'testaction', timedelta(hours=1), limit, max_burst=limit, redis=redis)
"testaction", timedelta(hours=1), limit, max_burst=limit, redis=redis
)
# run the action the full number of times, should all be allowed # run the action the full number of times, should all be allowed
for _ in range(limit): for _ in range(limit):
@ -68,7 +68,7 @@ def test_different_user_ids_limited_separately(redis):
limit = 5 limit = 5
user_id = 1 user_id = 1
action = RateLimitedAction('test', timedelta(hours=1), limit, redis=redis)
action = RateLimitedAction("test", timedelta(hours=1), limit, redis=redis)
# check the action for the first user_id until it's blocked # check the action for the first user_id until it's blocked
result = action.check_for_user_id(user_id) result = action.check_for_user_id(user_id)
@ -84,7 +84,7 @@ def test_max_burst_defaults_to_half(redis):
limit = 10 limit = 10
user_id = 1 user_id = 1
action = RateLimitedAction('test', timedelta(days=1), limit, redis=redis)
action = RateLimitedAction("test", timedelta(days=1), limit, redis=redis)
# see how many times we can do the action until it gets blocked # see how many times we can do the action until it gets blocked
count = 0 count = 0
@ -107,7 +107,8 @@ def test_time_until_retry(redis):
# create an action with no burst allowed, which will force the actions to # create an action with no burst allowed, which will force the actions to
# be spaced "evenly" across the limit # be spaced "evenly" across the limit
action = RateLimitedAction( action = RateLimitedAction(
'test', period=period, limit=limit, max_burst=1, redis=redis)
"test", period=period, limit=limit, max_burst=1, redis=redis
)
# first usage should be fine # first usage should be fine
result = action.check_for_user_id(user_id) result = action.check_for_user_id(user_id)
@ -126,7 +127,8 @@ def test_remaining_limit(redis):
# create an action allowing the full limit as a burst # create an action allowing the full limit as a burst
action = RateLimitedAction( action = RateLimitedAction(
'test', timedelta(days=1), limit, max_burst=limit, redis=redis)
"test", timedelta(days=1), limit, max_burst=limit, redis=redis
)
for count in range(1, limit + 1): for count in range(1, limit + 1):
result = action.check_for_user_id(user_id) result = action.check_for_user_id(user_id)
@ -136,11 +138,12 @@ def test_remaining_limit(redis):
def test_simple_rate_limiting_by_ip(redis): def test_simple_rate_limiting_by_ip(redis):
"""Ensure simple rate-limiting by IP address is working.""" """Ensure simple rate-limiting by IP address is working."""
limit = 5 limit = 5
ip = '123.123.123.123'
ip = "123.123.123.123"
# define an action with max_burst equal to the full limit # define an action with max_burst equal to the full limit
action = RateLimitedAction( action = RateLimitedAction(
'testaction', timedelta(hours=1), limit, max_burst=limit, redis=redis)
"testaction", timedelta(hours=1), limit, max_burst=limit, redis=redis
)
# run the action the full number of times, should all be allowed # run the action the full number of times, should all be allowed
for _ in range(limit): for _ in range(limit):
@ -154,9 +157,9 @@ def test_simple_rate_limiting_by_ip(redis):
def test_check_for_ip_invalid_address(): def test_check_for_ip_invalid_address():
"""Ensure RateLimitedAction.check_for_ip can't take an invalid IP.""" """Ensure RateLimitedAction.check_for_ip can't take an invalid IP."""
ip = '123.456.789.123'
ip = "123.456.789.123"
action = RateLimitedAction('testaction', timedelta(hours=1), 10)
action = RateLimitedAction("testaction", timedelta(hours=1), 10)
with raises(ValueError): with raises(ValueError):
action.check_for_ip(ip) action.check_for_ip(ip)
@ -164,9 +167,9 @@ def test_check_for_ip_invalid_address():
def test_reset_for_ip_invalid_address(): def test_reset_for_ip_invalid_address():
"""Ensure RateLimitedAction.reset_for_ip can't take an invalid IP.""" """Ensure RateLimitedAction.reset_for_ip can't take an invalid IP."""
ip = '123.456.789.123'
ip = "123.456.789.123"
action = RateLimitedAction('testaction', timedelta(hours=1), 10)
action = RateLimitedAction("testaction", timedelta(hours=1), 10)
with raises(ValueError): with raises(ValueError):
action.reset_for_ip(ip) action.reset_for_ip(ip)
@ -224,6 +227,7 @@ def test_merged_results():
def test_merged_all_allowed(): def test_merged_all_allowed():
"""Ensure a merged result from all allowed results is also allowed.""" """Ensure a merged result from all allowed results is also allowed."""
def random_allowed_result(): def random_allowed_result():
"""Return a RateLimitResult with is_allowed=True, otherwise random.""" """Return a RateLimitResult with is_allowed=True, otherwise random."""
return RateLimitResult( return RateLimitResult(

22
tildes/tests/test_simplestring_field.py

@ -17,39 +17,39 @@ def process_string(string):
ValidationError if an invalid string is attempted. ValidationError if an invalid string is attempted.
""" """
schema = SimpleStringTestSchema(strict=True) schema = SimpleStringTestSchema(strict=True)
result = schema.load({'subject': string})
result = schema.load({"subject": string})
return result.data['subject']
return result.data["subject"]
def test_changing_max_length(): def test_changing_max_length():
"""Ensure changing the max_length argument works.""" """Ensure changing the max_length argument works."""
test_string = 'Just some text to try'
test_string = "Just some text to try"
# should normally validate # should normally validate
assert SimpleString()._validate(test_string) is None assert SimpleString()._validate(test_string) is None
# but fails if you set a too-short max_length # but fails if you set a too-short max_length
with raises(ValidationError): with raises(ValidationError):
SimpleString(max_length=len(test_string)-1)._validate(test_string)
SimpleString(max_length=len(test_string) - 1)._validate(test_string)
def test_long_string(): def test_long_string():
"""Ensure a long string fails validation.""" """Ensure a long string fails validation."""
with raises(ValidationError): with raises(ValidationError):
process_string('A' * 10_000)
process_string("A" * 10_000)
def test_empty_string(): def test_empty_string():
"""Ensure an empty string fails validation.""" """Ensure an empty string fails validation."""
with raises(ValidationError): with raises(ValidationError):
process_string('')
process_string("")
def test_all_whitespace_string(): def test_all_whitespace_string():
"""Ensure a string that's entirely whitespace fails validation.""" """Ensure a string that's entirely whitespace fails validation."""
with raises(ValidationError): with raises(ValidationError):
process_string('\n \t \r\n ')
process_string("\n \t \r\n ")
def test_normal_string_untouched(): def test_normal_string_untouched():
@ -76,11 +76,11 @@ def test_control_chars_removed():
def test_leading_trailing_spaces_removed(): def test_leading_trailing_spaces_removed():
"""Ensure leading/trailing spaces are removed from the string.""" """Ensure leading/trailing spaces are removed from the string."""
original = ' Centered! '
assert process_string(original) == 'Centered!'
original = " Centered! "
assert process_string(original) == "Centered!"
def test_consecutive_spaces_collapsed(): def test_consecutive_spaces_collapsed():
"""Ensure runs of consecutive spaces are "collapsed" inside the string.""" """Ensure runs of consecutive spaces are "collapsed" inside the string."""
original = 'I wanted to space this out'
assert process_string(original) == 'I wanted to space this out'
original = "I wanted to space this out"
assert process_string(original) == "I wanted to space this out"

58
tildes/tests/test_string.py

@ -8,71 +8,69 @@ from tildes.lib.string import (
def test_simple_truncate(): def test_simple_truncate():
"""Ensure a simple truncation by length works correctly.""" """Ensure a simple truncation by length works correctly."""
truncated = truncate_string('123456789', 5, overflow_str=None)
assert truncated == '12345'
truncated = truncate_string("123456789", 5, overflow_str=None)
assert truncated == "12345"
def test_simple_truncate_with_overflow(): def test_simple_truncate_with_overflow():
"""Ensure a simple truncation by length with an overflow string works.""" """Ensure a simple truncation by length with an overflow string works."""
truncated = truncate_string('123456789', 5)
assert truncated == '12...'
truncated = truncate_string("123456789", 5)
assert truncated == "12..."
def test_truncate_same_length(): def test_truncate_same_length():
"""Ensure truncation doesn't happen if the string is the desired length.""" """Ensure truncation doesn't happen if the string is the desired length."""
original = '123456789'
original = "123456789"
assert truncate_string(original, len(original)) == original assert truncate_string(original, len(original)) == original
def test_truncate_at_char(): def test_truncate_at_char():
"""Ensure truncation at a particular character works.""" """Ensure truncation at a particular character works."""
original = 'asdf zxcv'
assert truncate_string_at_char(original, ' ') == 'asdf'
original = "asdf zxcv"
assert truncate_string_at_char(original, " ") == "asdf"
def test_truncate_at_last_char(): def test_truncate_at_last_char():
"""Ensure truncation happens at the last occurrence of the character.""" """Ensure truncation happens at the last occurrence of the character."""
original = 'as df zx cv'
assert truncate_string_at_char(original, ' ') == 'as df zx'
original = "as df zx cv"
assert truncate_string_at_char(original, " ") == "as df zx"
def test_truncate_at_nonexistent_char(): def test_truncate_at_nonexistent_char():
"""Ensure truncation-at-character doesn't apply if char isn't present.""" """Ensure truncation-at-character doesn't apply if char isn't present."""
original = 'asdfzxcv'
assert truncate_string_at_char(original, ' ') == original
original = "asdfzxcv"
assert truncate_string_at_char(original, " ") == original
def test_truncate_at_multiple_chars(): def test_truncate_at_multiple_chars():
"""Ensure truncation with multiple characters uses the rightmost one.""" """Ensure truncation with multiple characters uses the rightmost one."""
original = 'as-df=zx_cv'
assert truncate_string_at_char(original, '-=') == 'as-df'
original = "as-df=zx_cv"
assert truncate_string_at_char(original, "-=") == "as-df"
def test_truncate_length_and_char(): def test_truncate_length_and_char():
"""Ensure combined length+char truncation works as expected.""" """Ensure combined length+char truncation works as expected."""
original = '12345-67890-12345'
truncated = truncate_string(
original, 8, truncate_at_chars='-', overflow_str=None)
assert truncated == '12345'
original = "12345-67890-12345"
truncated = truncate_string(original, 8, truncate_at_chars="-", overflow_str=None)
assert truncated == "12345"
def test_truncate_length_and_nonexistent_char(): def test_truncate_length_and_nonexistent_char():
"""Ensure length+char truncation works if the char isn't present.""" """Ensure length+char truncation works if the char isn't present."""
original = '1234567890-12345'
truncated = truncate_string(
original, 8, truncate_at_chars='-', overflow_str=None)
assert truncated == '12345678'
original = "1234567890-12345"
truncated = truncate_string(original, 8, truncate_at_chars="-", overflow_str=None)
assert truncated == "12345678"
def test_simple_url_slug_conversion(): def test_simple_url_slug_conversion():
"""Ensure that a simple url slug conversion works as expected.""" """Ensure that a simple url slug conversion works as expected."""
assert convert_to_url_slug("A Simple Test") == 'a_simple_test'
assert convert_to_url_slug("A Simple Test") == "a_simple_test"
def test_url_slug_with_punctuation(): def test_url_slug_with_punctuation():
"""Ensure url slug conversion with punctuation works as expected.""" """Ensure url slug conversion with punctuation works as expected."""
original = "Here's a string. It has (some) punctuation!" original = "Here's a string. It has (some) punctuation!"
expected = 'heres_a_string_it_has_some_punctuation'
expected = "heres_a_string_it_has_some_punctuation"
assert convert_to_url_slug(original) == expected assert convert_to_url_slug(original) == expected
@ -86,13 +84,13 @@ def test_url_slug_with_apostrophes():
def test_url_slug_truncation(): def test_url_slug_truncation():
"""Ensure a simple url slug truncates as expected.""" """Ensure a simple url slug truncates as expected."""
original = "Here's another string to truncate." original = "Here's another string to truncate."
assert convert_to_url_slug(original, 15) == 'heres_another'
assert convert_to_url_slug(original, 15) == "heres_another"
def test_multibyte_url_slug(): def test_multibyte_url_slug():
"""Ensure converting/truncating a slug with encoded characters works.""" """Ensure converting/truncating a slug with encoded characters works."""
original = 'Python ist eine üblicherweise höhere Programmiersprache'
expected = 'python_ist_eine_%C3%BCblicherweise'
original = "Python ist eine üblicherweise höhere Programmiersprache"
expected = "python_ist_eine_%C3%BCblicherweise"
assert convert_to_url_slug(original, 45) == expected assert convert_to_url_slug(original, 45) == expected
@ -101,7 +99,7 @@ def test_multibyte_conservative_truncation():
# this string has a comma as the 6th char which will be converted to an # this string has a comma as the 6th char which will be converted to an
# underscore, so if truncation amount isn't restricted, it would result in # underscore, so if truncation amount isn't restricted, it would result in
# a 46-char slug instead of the full 100. # a 46-char slug instead of the full 100.
original = 'パイソンは、汎用のプログラミング言語である'
original = "パイソンは、汎用のプログラミング言語である"
assert len(convert_to_url_slug(original, 100)) == 100 assert len(convert_to_url_slug(original, 100)) == 100
@ -109,14 +107,14 @@ def test_multibyte_whole_character_truncation():
"""Ensure truncation happens at the edge of a multibyte character.""" """Ensure truncation happens at the edge of a multibyte character."""
# each of these characters url-encodes to 3 bytes = 9 characters each, so # each of these characters url-encodes to 3 bytes = 9 characters each, so
# only the first character should be included for all lengths from 9 - 17 # only the first character should be included for all lengths from 9 - 17
original = 'コード'
original = "コード"
for limit in range(9, 18): for limit in range(9, 18):
assert convert_to_url_slug(original, limit) == '%E3%82%B3'
assert convert_to_url_slug(original, limit) == "%E3%82%B3"
def test_simple_word_count(): def test_simple_word_count():
"""Ensure word-counting a simple string works as expected.""" """Ensure word-counting a simple string works as expected."""
string = 'Here is a simple string of words, nothing fancy.'
string = "Here is a simple string of words, nothing fancy."
assert word_count(string) == 9 assert word_count(string) == 9

36
tildes/tests/test_title.py

@ -7,57 +7,57 @@ from tildes.schemas.topic import TITLE_MAX_LENGTH, TopicSchema
@fixture @fixture
def title_schema(): def title_schema():
"""Fixture for generating a title-only TopicSchema.""" """Fixture for generating a title-only TopicSchema."""
return TopicSchema(only=('title',))
return TopicSchema(only=("title",))
def test_typical_title_valid(title_schema): def test_typical_title_valid(title_schema):
"""Test a "normal-looking" title to make sure it's valid.""" """Test a "normal-looking" title to make sure it's valid."""
title = "[Something] Here's an article that I'm sure 100 people will like." title = "[Something] Here's an article that I'm sure 100 people will like."
assert title_schema.validate({'title': title}) == {}
assert title_schema.validate({"title": title}) == {}
def test_too_long_title_invalid(title_schema): def test_too_long_title_invalid(title_schema):
"""Ensure a too-long title is invalid.""" """Ensure a too-long title is invalid."""
title = 'x' * (TITLE_MAX_LENGTH + 1)
title = "x" * (TITLE_MAX_LENGTH + 1)
with raises(ValidationError): with raises(ValidationError):
title_schema.validate({'title': title})
title_schema.validate({"title": title})
def test_empty_title_invalid(title_schema): def test_empty_title_invalid(title_schema):
"""Ensure an empty title is invalid.""" """Ensure an empty title is invalid."""
with raises(ValidationError): with raises(ValidationError):
title_schema.validate({'title': ''})
title_schema.validate({"title": ""})
def test_whitespace_only_title_invalid(title_schema): def test_whitespace_only_title_invalid(title_schema):
"""Ensure a whitespace-only title is invalid.""" """Ensure a whitespace-only title is invalid."""
with raises(ValidationError): with raises(ValidationError):
title_schema.validate({'title': ' \n '})
title_schema.validate({"title": " \n "})
def test_whitespace_trimmed(title_schema): def test_whitespace_trimmed(title_schema):
"""Ensure leading/trailing whitespace on a title is removed.""" """Ensure leading/trailing whitespace on a title is removed."""
title = ' actual title '
result = title_schema.load({'title': title})
assert result.data['title'] == 'actual title'
title = " actual title "
result = title_schema.load({"title": title})
assert result.data["title"] == "actual title"
def test_consecutive_whitespace_removed(title_schema): def test_consecutive_whitespace_removed(title_schema):
"""Ensure consecutive whitespace in a title is compressed.""" """Ensure consecutive whitespace in a title is compressed."""
title = 'sure are \n a lot of spaces'
result = title_schema.load({'title': title})
assert result.data['title'] == 'sure are a lot of spaces'
title = "sure are \n a lot of spaces"
result = title_schema.load({"title": title})
assert result.data["title"] == "sure are a lot of spaces"
def test_unicode_spaces_normalized(title_schema): def test_unicode_spaces_normalized(title_schema):
"""Test that some unicode space characters are converted to normal ones.""" """Test that some unicode space characters are converted to normal ones."""
title = 'some\u2009weird\u00a0spaces\u205fin\u00a0here'
result = title_schema.load({'title': title})
assert result.data['title'] == 'some weird spaces in here'
title = "some\u2009weird\u00a0spaces\u205fin\u00a0here"
result = title_schema.load({"title": title})
assert result.data["title"] == "some weird spaces in here"
def test_unicode_control_chars_removed(title_schema): def test_unicode_control_chars_removed(title_schema):
"""Test that some unicode control characters are stripped from titles.""" """Test that some unicode control characters are stripped from titles."""
title = 'nothing\u0000strange\u0085going\u009con\u007fhere'
result = title_schema.load({'title': title})
assert result.data['title'] == 'nothingstrangegoingonhere'
title = "nothing\u0000strange\u0085going\u009con\u007fhere"
result = title_schema.load({"title": title})
assert result.data["title"] == "nothingstrangegoingonhere"

44
tildes/tests/test_topic.py

@ -12,41 +12,37 @@ from tildes.schemas.topic import TopicSchema
def test_text_creation_validations(mocker, session_user, session_group): def test_text_creation_validations(mocker, session_user, session_group):
"""Ensure that text topic creation goes through expected validation.""" """Ensure that text topic creation goes through expected validation."""
mocker.spy(TopicSchema, 'load')
mocker.spy(Markdown, '_validate')
mocker.spy(SimpleString, '_validate')
mocker.spy(TopicSchema, "load")
mocker.spy(Markdown, "_validate")
mocker.spy(SimpleString, "_validate")
Topic.create_text_topic(
session_group, session_user, 'a title', 'the text')
Topic.create_text_topic(session_group, session_user, "a title", "the text")
assert TopicSchema.load.called assert TopicSchema.load.called
assert SimpleString._validate.call_args[0][1] == 'a title'
assert Markdown._validate.call_args[0][1] == 'the text'
assert SimpleString._validate.call_args[0][1] == "a title"
assert Markdown._validate.call_args[0][1] == "the text"
def test_link_creation_validations(mocker, session_user, session_group): def test_link_creation_validations(mocker, session_user, session_group):
"""Ensure that link topic creation goes through expected validation.""" """Ensure that link topic creation goes through expected validation."""
mocker.spy(TopicSchema, 'load')
mocker.spy(SimpleString, '_validate')
mocker.spy(URL, '_validate')
mocker.spy(TopicSchema, "load")
mocker.spy(SimpleString, "_validate")
mocker.spy(URL, "_validate")
Topic.create_link_topic( Topic.create_link_topic(
session_group,
session_user,
'the title',
'http://example.com',
session_group, session_user, "the title", "http://example.com"
) )
assert TopicSchema.load.called assert TopicSchema.load.called
assert SimpleString._validate.call_args[0][1] == 'the title'
assert URL._validate.call_args[0][1] == 'http://example.com'
assert SimpleString._validate.call_args[0][1] == "the title"
assert URL._validate.call_args[0][1] == "http://example.com"
def test_text_topic_edit_uses_markdown_field(mocker, text_topic): def test_text_topic_edit_uses_markdown_field(mocker, text_topic):
"""Ensure editing a text topic is validated by the Markdown field class.""" """Ensure editing a text topic is validated by the Markdown field class."""
mocker.spy(Markdown, '_validate')
mocker.spy(Markdown, "_validate")
text_topic.markdown = 'Some new text after edit'
text_topic.markdown = "Some new text after edit"
assert Markdown._validate.called assert Markdown._validate.called
@ -82,19 +78,19 @@ def test_link_domain_errors_on_text_topic(text_topic):
def test_link_domain_on_link_topic(link_topic): def test_link_domain_on_link_topic(link_topic):
"""Ensure getting the domain of a link topic works.""" """Ensure getting the domain of a link topic works."""
assert link_topic.link_domain == 'example.com'
assert link_topic.link_domain == "example.com"
def test_edit_markdown_errors_on_link_topic(link_topic): def test_edit_markdown_errors_on_link_topic(link_topic):
"""Ensure trying to edit the markdown of a link topic is an error.""" """Ensure trying to edit the markdown of a link topic is an error."""
with raises(AttributeError): with raises(AttributeError):
link_topic.markdown = 'Some new markdown'
link_topic.markdown = "Some new markdown"
def test_edit_markdown_on_text_topic(text_topic): def test_edit_markdown_on_text_topic(text_topic):
"""Ensure editing the markdown of a text topic works and updates html.""" """Ensure editing the markdown of a text topic works and updates html."""
original_html = text_topic.rendered_html original_html = text_topic.rendered_html
text_topic.markdown = 'Some new markdown'
text_topic.markdown = "Some new markdown"
assert text_topic.rendered_html != original_html assert text_topic.rendered_html != original_html
@ -104,7 +100,7 @@ def test_edit_grace_period(text_topic):
edit_time = text_topic.created_time + EDIT_GRACE_PERIOD - one_sec edit_time = text_topic.created_time + EDIT_GRACE_PERIOD - one_sec
with freeze_time(edit_time): with freeze_time(edit_time):
text_topic.markdown = 'some new markdown'
text_topic.markdown = "some new markdown"
assert not text_topic.last_edited_time assert not text_topic.last_edited_time
@ -115,7 +111,7 @@ def test_edit_after_grace_period(text_topic):
edit_time = text_topic.created_time + EDIT_GRACE_PERIOD + one_sec edit_time = text_topic.created_time + EDIT_GRACE_PERIOD + one_sec
with freeze_time(edit_time): with freeze_time(edit_time):
text_topic.markdown = 'some new markdown'
text_topic.markdown = "some new markdown"
assert text_topic.last_edited_time == utc_now() assert text_topic.last_edited_time == utc_now()
@ -127,7 +123,7 @@ def test_multiple_edits_update_time(text_topic):
for minutes in range(0, 4): for minutes in range(0, 4):
edit_time = initial_time + timedelta(minutes=minutes) edit_time = initial_time + timedelta(minutes=minutes)
with freeze_time(edit_time): with freeze_time(edit_time):
text_topic.markdown = f'edit #{minutes}'
text_topic.markdown = f"edit #{minutes}"
assert text_topic.last_edited_time == utc_now() assert text_topic.last_edited_time == utc_now()

37
tildes/tests/test_topic_permissions.py

@ -1,13 +1,9 @@
from pyramid.security import (
Authenticated,
Everyone,
principals_allowed_by_permission,
)
from pyramid.security import Authenticated, Everyone, principals_allowed_by_permission
def test_topic_viewing_permission(text_topic): def test_topic_viewing_permission(text_topic):
"""Ensure that anyone can view a topic by default.""" """Ensure that anyone can view a topic by default."""
principals = principals_allowed_by_permission(text_topic, 'view')
principals = principals_allowed_by_permission(text_topic, "view")
assert Everyone in principals assert Everyone in principals
@ -15,71 +11,70 @@ def test_deleted_topic_permissions_removed(topic):
"""Ensure that deleted topics lose all permissions except "view".""" """Ensure that deleted topics lose all permissions except "view"."""
topic.is_deleted = True topic.is_deleted = True
assert principals_allowed_by_permission(topic, 'view') == {Everyone}
assert principals_allowed_by_permission(topic, "view") == {Everyone}
all_permissions = [
perm for (_, _, perm) in topic.__acl__() if perm != 'view']
all_permissions = [perm for (_, _, perm) in topic.__acl__() if perm != "view"]
for permission in all_permissions: for permission in all_permissions:
assert not principals_allowed_by_permission(topic, permission) assert not principals_allowed_by_permission(topic, permission)
def test_text_topic_editing_permission(text_topic): def test_text_topic_editing_permission(text_topic):
"""Ensure a text topic's owner (and nobody else) is able to edit it.""" """Ensure a text topic's owner (and nobody else) is able to edit it."""
principals = principals_allowed_by_permission(text_topic, 'edit')
principals = principals_allowed_by_permission(text_topic, "edit")
assert principals == {text_topic.user.user_id} assert principals == {text_topic.user.user_id}
def test_link_topic_editing_permission(link_topic): def test_link_topic_editing_permission(link_topic):
"""Ensure that nobody has edit permission on a link topic.""" """Ensure that nobody has edit permission on a link topic."""
principals = principals_allowed_by_permission(link_topic, 'edit')
principals = principals_allowed_by_permission(link_topic, "edit")
assert not principals assert not principals
def test_topic_deleting_permission(text_topic): def test_topic_deleting_permission(text_topic):
"""Ensure that the topic's owner (and nobody else) is able to delete it.""" """Ensure that the topic's owner (and nobody else) is able to delete it."""
principals = principals_allowed_by_permission(text_topic, 'delete')
principals = principals_allowed_by_permission(text_topic, "delete")
assert principals == {text_topic.user.user_id} assert principals == {text_topic.user.user_id}
def test_topic_view_author_permission(text_topic): def test_topic_view_author_permission(text_topic):
"""Ensure anyone can view a topic's author normally.""" """Ensure anyone can view a topic's author normally."""
principals = principals_allowed_by_permission(text_topic, 'view_author')
principals = principals_allowed_by_permission(text_topic, "view_author")
assert Everyone in principals assert Everyone in principals
def test_removed_topic_view_author_permission(topic): def test_removed_topic_view_author_permission(topic):
"""Ensure only admins and the author can view a removed topic's author.""" """Ensure only admins and the author can view a removed topic's author."""
topic.is_removed = True topic.is_removed = True
principals = principals_allowed_by_permission(topic, 'view_author')
assert principals == {'admin', topic.user_id}
principals = principals_allowed_by_permission(topic, "view_author")
assert principals == {"admin", topic.user_id}
def test_topic_view_content_permission(text_topic): def test_topic_view_content_permission(text_topic):
"""Ensure anyone can view a topic's content normally.""" """Ensure anyone can view a topic's content normally."""
principals = principals_allowed_by_permission(text_topic, 'view_content')
principals = principals_allowed_by_permission(text_topic, "view_content")
assert Everyone in principals assert Everyone in principals
def test_removed_topic_view_content_permission(topic): def test_removed_topic_view_content_permission(topic):
"""Ensure only admins and the author can view a removed topic's content.""" """Ensure only admins and the author can view a removed topic's content."""
topic.is_removed = True topic.is_removed = True
principals = principals_allowed_by_permission(topic, 'view_content')
assert principals == {'admin', topic.user_id}
principals = principals_allowed_by_permission(topic, "view_content")
assert principals == {"admin", topic.user_id}
def test_topic_comment_permission(text_topic): def test_topic_comment_permission(text_topic):
"""Ensure authed users have comment perms on a topic by default.""" """Ensure authed users have comment perms on a topic by default."""
principals = principals_allowed_by_permission(text_topic, 'comment')
principals = principals_allowed_by_permission(text_topic, "comment")
assert Authenticated in principals assert Authenticated in principals
def test_locked_topic_comment_permission(topic): def test_locked_topic_comment_permission(topic):
"""Ensure only admins can post (top-level) comments on locked topics.""" """Ensure only admins can post (top-level) comments on locked topics."""
topic.is_locked = True topic.is_locked = True
assert principals_allowed_by_permission(topic, 'comment') == {'admin'}
assert principals_allowed_by_permission(topic, "comment") == {"admin"}
def test_removed_topic_comment_permission(topic): def test_removed_topic_comment_permission(topic):
"""Ensure only admins can post (top-level) comments on removed topics.""" """Ensure only admins can post (top-level) comments on removed topics."""
topic.is_removed = True topic.is_removed = True
assert principals_allowed_by_permission(topic, 'comment') == {'admin'}
assert principals_allowed_by_permission(topic, "comment") == {"admin"}

22
tildes/tests/test_topic_tags.py

@ -1,34 +1,34 @@
def test_tags_whitespace_stripped(text_topic): def test_tags_whitespace_stripped(text_topic):
"""Ensure excess whitespace around tags gets stripped.""" """Ensure excess whitespace around tags gets stripped."""
text_topic.tags = [' one', 'two ', ' three ']
assert text_topic.tags == ['one', 'two', 'three']
text_topic.tags = [" one", "two ", " three "]
assert text_topic.tags == ["one", "two", "three"]
def test_tag_space_replacement(text_topic): def test_tag_space_replacement(text_topic):
"""Ensure spaces in tags are converted to underscores internally.""" """Ensure spaces in tags are converted to underscores internally."""
text_topic.tags = ['one two', 'three four five']
assert text_topic._tags == ['one_two', 'three_four_five']
text_topic.tags = ["one two", "three four five"]
assert text_topic._tags == ["one_two", "three_four_five"]
def test_tag_consecutive_spaces(text_topic): def test_tag_consecutive_spaces(text_topic):
"""Ensure consecutive spaces/underscores in tags are removed.""" """Ensure consecutive spaces/underscores in tags are removed."""
text_topic.tags = ["one two", "three four", "five __ six"] text_topic.tags = ["one two", "three four", "five __ six"]
assert text_topic.tags == ['one two', 'three four', 'five six']
assert text_topic.tags == ["one two", "three four", "five six"]
def test_duplicate_tags_removed(text_topic): def test_duplicate_tags_removed(text_topic):
"""Ensure duplicate tags are removed (case-insensitive).""" """Ensure duplicate tags are removed (case-insensitive)."""
text_topic.tags = ['one', 'one', 'One', 'ONE', 'two', 'TWO']
assert text_topic.tags == ['one', 'two']
text_topic.tags = ["one", "one", "One", "ONE", "two", "TWO"]
assert text_topic.tags == ["one", "two"]
def test_empty_tags_removed(text_topic): def test_empty_tags_removed(text_topic):
"""Ensure empty tags are removed.""" """Ensure empty tags are removed."""
text_topic.tags = ['', ' ', '_', 'one']
assert text_topic.tags == ['one']
text_topic.tags = ["", " ", "_", "one"]
assert text_topic.tags == ["one"]
def test_tags_lowercased(text_topic): def test_tags_lowercased(text_topic):
"""Ensure tags get converted to lowercase.""" """Ensure tags get converted to lowercase."""
text_topic.tags = ['ONE', 'Two', 'thRee']
assert text_topic.tags == ['one', 'two', 'three']
text_topic.tags = ["ONE", "Two", "thRee"]
assert text_topic.tags == ["one", "two", "three"]

6
tildes/tests/test_triggers_comments.py

@ -8,7 +8,7 @@ def test_comments_affect_topic_num_comments(session_user, topic, db):
# Insert some comments, ensure each one increments the count # Insert some comments, ensure each one increments the count
comments = [] comments = []
for num in range(0, 5): for num in range(0, 5):
new_comment = Comment(topic, session_user, 'comment')
new_comment = Comment(topic, session_user, "comment")
comments.append(new_comment) comments.append(new_comment)
db.add(new_comment) db.add(new_comment)
db.commit() db.commit()
@ -62,8 +62,8 @@ def test_remove_sets_removed_time(db, comment):
def test_remove_delete_single_decrement(db, topic, session_user): def test_remove_delete_single_decrement(db, topic, session_user):
"""Ensure that remove+delete doesn't double-decrement num_comments.""" """Ensure that remove+delete doesn't double-decrement num_comments."""
# add 2 comments # add 2 comments
comment1 = Comment(topic, session_user, 'Comment 1')
comment2 = Comment(topic, session_user, 'Comment 2')
comment1 = Comment(topic, session_user, "Comment 1")
comment2 = Comment(topic, session_user, "Comment 2")
db.add_all([comment1, comment2]) db.add_all([comment1, comment2])
db.commit() db.commit()
db.refresh(topic) db.refresh(topic)

22
tildes/tests/test_url.py

@ -5,13 +5,13 @@ from tildes.lib.url import get_domain_from_url
def test_simple_get_domain(): def test_simple_get_domain():
"""Ensure getting the domain from a normal URL works.""" """Ensure getting the domain from a normal URL works."""
url = 'http://example.com/some/path?query=param&query2=val2'
assert get_domain_from_url(url) == 'example.com'
url = "http://example.com/some/path?query=param&query2=val2"
assert get_domain_from_url(url) == "example.com"
def test_get_domain_non_url(): def test_get_domain_non_url():
"""Ensure attempting to get the domain for a non-url is an error.""" """Ensure attempting to get the domain for a non-url is an error."""
url = 'this is not a url'
url = "this is not a url"
with raises(ValueError): with raises(ValueError):
get_domain_from_url(url) get_domain_from_url(url)
@ -19,27 +19,27 @@ def test_get_domain_non_url():
def test_get_domain_no_scheme(): def test_get_domain_no_scheme():
"""Ensure getting domain on a url with no scheme is an error.""" """Ensure getting domain on a url with no scheme is an error."""
with raises(ValueError): with raises(ValueError):
get_domain_from_url('example.com/something')
get_domain_from_url("example.com/something")
def test_get_domain_explicit_no_scheme(): def test_get_domain_explicit_no_scheme():
"""Ensure getting domain works if url is explicit about lack of scheme.""" """Ensure getting domain works if url is explicit about lack of scheme."""
assert get_domain_from_url('//example.com/something') == 'example.com'
assert get_domain_from_url("//example.com/something") == "example.com"
def test_get_domain_strip_www(): def test_get_domain_strip_www():
"""Ensure stripping the "www." from the domain works as expected.""" """Ensure stripping the "www." from the domain works as expected."""
url = 'http://www.example.com/a/path/to/something'
assert get_domain_from_url(url) == 'example.com'
url = "http://www.example.com/a/path/to/something"
assert get_domain_from_url(url) == "example.com"
def test_get_domain_no_strip_www(): def test_get_domain_no_strip_www():
"""Ensure stripping the "www." can be disabled.""" """Ensure stripping the "www." can be disabled."""
url = 'http://www.example.com/a/path/to/something'
assert get_domain_from_url(url, strip_www=False) == 'www.example.com'
url = "http://www.example.com/a/path/to/something"
assert get_domain_from_url(url, strip_www=False) == "www.example.com"
def test_get_domain_subdomain_not_stripped(): def test_get_domain_subdomain_not_stripped():
"""Ensure a non-www subdomain isn't stripped.""" """Ensure a non-www subdomain isn't stripped."""
url = 'http://something.example.com/path/x/y/z'
assert get_domain_from_url(url) == 'something.example.com'
url = "http://something.example.com/path/x/y/z"
assert get_domain_from_url(url) == "something.example.com"

51
tildes/tests/test_user.py

@ -8,55 +8,55 @@ from tildes.schemas.user import PASSWORD_MIN_LENGTH, UserSchema
def test_creation_validates_schema(mocker): def test_creation_validates_schema(mocker):
"""Ensure that model creation goes through schema validation.""" """Ensure that model creation goes through schema validation."""
mocker.spy(UserSchema, 'validate')
User('testing', 'testpassword')
mocker.spy(UserSchema, "validate")
User("testing", "testpassword")
call_args = [call[0] for call in UserSchema.validate.call_args_list] call_args = [call[0] for call in UserSchema.validate.call_args_list]
expected_args = {'username': 'testing', 'password': 'testpassword'}
expected_args = {"username": "testing", "password": "testpassword"}
assert any(expected_args in call for call in call_args) assert any(expected_args in call for call in call_args)
def test_too_short_password(): def test_too_short_password():
"""Ensure a new user can't be created with a too-short password.""" """Ensure a new user can't be created with a too-short password."""
password = 'x' * (PASSWORD_MIN_LENGTH - 1)
password = "x" * (PASSWORD_MIN_LENGTH - 1)
with raises(ValidationError): with raises(ValidationError):
User('ShortPasswordGuy', password)
User("ShortPasswordGuy", password)
def test_matching_password_and_username(): def test_matching_password_and_username():
"""Ensure a new user can't be created with same username and password.""" """Ensure a new user can't be created with same username and password."""
with raises(ValidationError): with raises(ValidationError):
User('UnimaginativePassword', 'UnimaginativePassword')
User("UnimaginativePassword", "UnimaginativePassword")
def test_username_and_password_differ_in_casing(): def test_username_and_password_differ_in_casing():
"""Ensure a user can't be created with name/pass the same except case.""" """Ensure a user can't be created with name/pass the same except case."""
with raises(ValidationError): with raises(ValidationError):
User('NobodyWillGuess', 'nobodywillguess')
User("NobodyWillGuess", "nobodywillguess")
def test_username_contained_in_password(): def test_username_contained_in_password():
"""Ensure a user can't be created with the username in the password.""" """Ensure a user can't be created with the username in the password."""
with raises(ValidationError): with raises(ValidationError):
User('MyUsername', 'iputmyusernameinmypassword')
User("MyUsername", "iputmyusernameinmypassword")
def test_password_contained_in_username(): def test_password_contained_in_username():
"""Ensure a user can't be created with the password in the username.""" """Ensure a user can't be created with the password in the username."""
with raises(ValidationError): with raises(ValidationError):
User('PasswordIsVeryGood', 'VeryGood')
User("PasswordIsVeryGood", "VeryGood")
def test_user_password_check(): def test_user_password_check():
"""Ensure checking the password for a new user works correctly.""" """Ensure checking the password for a new user works correctly."""
new_user = User('myusername', 'mypassword')
assert new_user.is_correct_password('mypassword')
new_user = User("myusername", "mypassword")
assert new_user.is_correct_password("mypassword")
def test_duplicate_username(db): def test_duplicate_username(db):
"""Ensure two users with the same name can't be created.""" """Ensure two users with the same name can't be created."""
original = User('Inimitable', 'securepassword')
original = User("Inimitable", "securepassword")
db.add(original) db.add(original)
duplicate = User('Inimitable', 'adifferentpassword')
duplicate = User("Inimitable", "adifferentpassword")
db.add(duplicate) db.add(duplicate)
with raises(IntegrityError): with raises(IntegrityError):
@ -65,10 +65,10 @@ def test_duplicate_username(db):
def test_duplicate_username_case_insensitive(db): def test_duplicate_username_case_insensitive(db):
"""Ensure usernames only differing in casing can't be created.""" """Ensure usernames only differing in casing can't be created."""
test_username = 'test_user'
original = User(test_username.lower(), 'hackproof')
test_username = "test_user"
original = User(test_username.lower(), "hackproof")
db.add(original) db.add(original)
duplicate = User(test_username.upper(), 'sosecure')
duplicate = User(test_username.upper(), "sosecure")
db.add(duplicate) db.add(duplicate)
with raises(IntegrityError): with raises(IntegrityError):
@ -77,20 +77,20 @@ def test_duplicate_username_case_insensitive(db):
def test_change_password(): def test_change_password():
"""Ensure changing a user password works as expected.""" """Ensure changing a user password works as expected."""
new_user = User('A_New_User', 'lovesexsecretgod')
new_user = User("A_New_User", "lovesexsecretgod")
new_user.change_password('lovesexsecretgod', 'lovesexsecretgod1')
new_user.change_password("lovesexsecretgod", "lovesexsecretgod1")
# the old one shouldn't work # the old one shouldn't work
assert not new_user.is_correct_password('lovesexsecretgod')
assert not new_user.is_correct_password("lovesexsecretgod")
# the new one should # the new one should
assert new_user.is_correct_password('lovesexsecretgod1')
assert new_user.is_correct_password("lovesexsecretgod1")
def test_change_password_to_same(session_user): def test_change_password_to_same(session_user):
"""Ensure users can't "change" to the same password.""" """Ensure users can't "change" to the same password."""
password = 'session user password'
password = "session user password"
with raises(ValueError): with raises(ValueError):
session_user.change_password(password, password) session_user.change_password(password, password)
@ -98,18 +98,17 @@ def test_change_password_to_same(session_user):
def test_change_password_wrong_old_one(session_user): def test_change_password_wrong_old_one(session_user):
"""Ensure changing password doesn't work if the old one is wrong.""" """Ensure changing password doesn't work if the old one is wrong."""
with raises(ValueError): with raises(ValueError):
session_user.change_password('definitely not right', 'some new one')
session_user.change_password("definitely not right", "some new one")
def test_change_password_too_short(session_user): def test_change_password_too_short(session_user):
"""Ensure users can't change password to a too-short one.""" """Ensure users can't change password to a too-short one."""
new_password = 'x' * (PASSWORD_MIN_LENGTH - 1)
new_password = "x" * (PASSWORD_MIN_LENGTH - 1)
with raises(ValidationError): with raises(ValidationError):
session_user.change_password('session user password', new_password)
session_user.change_password("session user password", new_password)
def test_change_password_to_username(session_user): def test_change_password_to_username(session_user):
"""Ensure users can't change password to the same as their username.""" """Ensure users can't change password to the same as their username."""
with raises(ValidationError): with raises(ValidationError):
session_user.change_password(
'session user password', session_user.username)
session_user.change_password("session user password", session_user.username)

16
tildes/tests/test_username.py

@ -10,7 +10,7 @@ from tildes.schemas.user import (
def test_too_short_invalid(): def test_too_short_invalid():
"""Ensure too-short username is invalid.""" """Ensure too-short username is invalid."""
length = USERNAME_MIN_LENGTH - 1 length = USERNAME_MIN_LENGTH - 1
username = 'x' * length
username = "x" * length
assert not is_valid_username(username) assert not is_valid_username(username)
@ -18,7 +18,7 @@ def test_too_short_invalid():
def test_too_long_invalid(): def test_too_long_invalid():
"""Ensure too-long username is invalid.""" """Ensure too-long username is invalid."""
length = USERNAME_MAX_LENGTH + 1 length = USERNAME_MAX_LENGTH + 1
username = 'x' * length
username = "x" * length
assert not is_valid_username(username) assert not is_valid_username(username)
@ -26,22 +26,22 @@ def test_too_long_invalid():
def test_valid_length_range(): def test_valid_length_range():
"""Ensure the entire range of valid lengths work.""" """Ensure the entire range of valid lengths work."""
for length in range(USERNAME_MIN_LENGTH, USERNAME_MAX_LENGTH + 1): for length in range(USERNAME_MIN_LENGTH, USERNAME_MAX_LENGTH + 1):
username = 'x' * length
username = "x" * length
assert is_valid_username(username) assert is_valid_username(username)
def test_consecutive_spacer_chars_invalid(): def test_consecutive_spacer_chars_invalid():
"""Ensure that a username with consecutive "spacer chars" is invalid.""" """Ensure that a username with consecutive "spacer chars" is invalid."""
spacer_chars = '_-'
spacer_chars = "_-"
for char1, char2 in product(spacer_chars, spacer_chars): for char1, char2 in product(spacer_chars, spacer_chars):
username = f'abc{char1}{char2}xyz'
username = f"abc{char1}{char2}xyz"
assert not is_valid_username(username) assert not is_valid_username(username)
def test_typical_username_valid(): def test_typical_username_valid():
"""Ensure a "normal-looking" username is considered valid.""" """Ensure a "normal-looking" username is considered valid."""
assert is_valid_username('someTypical_user-85')
assert is_valid_username("someTypical_user-85")
def test_invalid_characters(): def test_invalid_characters():
@ -49,11 +49,11 @@ def test_invalid_characters():
invalid_chars = ' ~!@#$%^&*()+={}[]|\\:;"<>,.?/' invalid_chars = ' ~!@#$%^&*()+={}[]|\\:;"<>,.?/'
for char in invalid_chars: for char in invalid_chars:
username = f'abc{char}xyz'
username = f"abc{char}xyz"
assert not is_valid_username(username) assert not is_valid_username(username)
def test_unicode_characters(): def test_unicode_characters():
"""Ensure that unicode chars can't be included (not comprehensive).""" """Ensure that unicode chars can't be included (not comprehensive)."""
for username in ('pokémon', 'ポケモン', 'møøse'):
for username in ("pokémon", "ポケモン", "møøse"):
assert not is_valid_username(username) assert not is_valid_username(username)

10
tildes/tests/test_webassets.py

@ -1,22 +1,22 @@
from webassets.loaders import YAMLLoader from webassets.loaders import YAMLLoader
WEBASSETS_ENV = YAMLLoader('webassets.yaml').load_environment()
WEBASSETS_ENV = YAMLLoader("webassets.yaml").load_environment()
def test_scripts_file_first_in_bundle(): def test_scripts_file_first_in_bundle():
"""Ensure that the main scripts.js file will be at the top.""" """Ensure that the main scripts.js file will be at the top."""
js_bundle = WEBASSETS_ENV['javascript']
js_bundle = WEBASSETS_ENV["javascript"]
first_filename = js_bundle.resolve_contents()[0][0] first_filename = js_bundle.resolve_contents()[0][0]
assert first_filename == 'js/scripts.js'
assert first_filename == "js/scripts.js"
def test_styles_file_last_in_bundle(): def test_styles_file_last_in_bundle():
"""Ensure that the main styles.css file will be at the bottom.""" """Ensure that the main styles.css file will be at the bottom."""
css_bundle = WEBASSETS_ENV['css']
css_bundle = WEBASSETS_ENV["css"]
last_filename = css_bundle.resolve_contents()[-1][0] last_filename = css_bundle.resolve_contents()[-1][0]
assert last_filename == 'css/styles.css'
assert last_filename == "css/styles.css"

6
tildes/tests/webtests/test_user_page.py

@ -6,8 +6,10 @@ def test_loggedout_username_leak(webtest_loggedout, session_user):
particular username exists or not. particular username exists or not.
""" """
existing_user = webtest_loggedout.get( existing_user = webtest_loggedout.get(
'/user/' + session_user.username, expect_errors=True)
"/user/" + session_user.username, expect_errors=True
)
nonexistent_user = webtest_loggedout.get( nonexistent_user = webtest_loggedout.get(
'/user/thisdoesntexist', expect_errors=True)
"/user/thisdoesntexist", expect_errors=True
)
assert existing_user.status == nonexistent_user.status assert existing_user.status == nonexistent_user.status

93
tildes/tildes/__init__.py

@ -16,51 +16,47 @@ def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware:
"""Configure and return a Pyramid WSGI application.""" """Configure and return a Pyramid WSGI application."""
config = Configurator(settings=settings) config = Configurator(settings=settings)
config.include('cornice')
config.include('pyramid_session_redis')
config.include('pyramid_webassets')
config.include("cornice")
config.include("pyramid_session_redis")
config.include("pyramid_webassets")
# include database first so the session and querying are available # include database first so the session and querying are available
config.include('tildes.database')
config.include('tildes.auth')
config.include('tildes.jinja')
config.include('tildes.json')
config.include('tildes.routes')
config.include("tildes.database")
config.include("tildes.auth")
config.include("tildes.jinja")
config.include("tildes.json")
config.include("tildes.routes")
config.add_webasset('javascript', Bundle(output='js/tildes.js'))
config.add_webasset(
'javascript-third-party', Bundle(output='js/third_party.js'))
config.add_webasset('css', Bundle(output='css/tildes.css'))
config.add_webasset('site-icons-css', Bundle(output='css/site-icons.css'))
config.add_webasset("javascript", Bundle(output="js/tildes.js"))
config.add_webasset("javascript-third-party", Bundle(output="js/third_party.js"))
config.add_webasset("css", Bundle(output="css/tildes.css"))
config.add_webasset("site-icons-css", Bundle(output="css/site-icons.css"))
config.scan('tildes.views')
config.scan("tildes.views")
config.add_tween('tildes.http_method_tween_factory')
config.add_tween("tildes.http_method_tween_factory")
config.add_request_method(
is_safe_request_method, 'is_safe_method', reify=True)
config.add_request_method(is_safe_request_method, "is_safe_method", reify=True)
# Add the request.redis request method to access a redis connection. This # Add the request.redis request method to access a redis connection. This
# is done in a bit of a strange way to support being overridden in tests. # is done in a bit of a strange way to support being overridden in tests.
config.registry['redis_connection_factory'] = get_redis_connection
config.registry["redis_connection_factory"] = get_redis_connection
# pylint: disable=unnecessary-lambda # pylint: disable=unnecessary-lambda
config.add_request_method( config.add_request_method(
lambda request: config.registry['redis_connection_factory'](request),
'redis',
lambda request: config.registry["redis_connection_factory"](request),
"redis",
reify=True, reify=True,
) )
# pylint: enable=unnecessary-lambda # pylint: enable=unnecessary-lambda
config.add_request_method(check_rate_limit, 'check_rate_limit')
config.add_request_method(check_rate_limit, "check_rate_limit")
config.add_request_method(
current_listing_base_url, 'current_listing_base_url')
config.add_request_method(
current_listing_normal_url, 'current_listing_normal_url')
config.add_request_method(current_listing_base_url, "current_listing_base_url")
config.add_request_method(current_listing_normal_url, "current_listing_normal_url")
app = config.make_wsgi_app() app = config.make_wsgi_app()
force_port = global_config.get('prefixmiddleware_force_port')
force_port = global_config.get("prefixmiddleware_force_port")
if force_port: if force_port:
prefixed_app = PrefixMiddleware(app, force_port=force_port) prefixed_app = PrefixMiddleware(app, force_port=force_port)
else: else:
@ -69,19 +65,17 @@ def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware:
return prefixed_app return prefixed_app
def http_method_tween_factory(
handler: Callable,
registry: Registry,
) -> Callable:
def http_method_tween_factory(handler: Callable, registry: Registry) -> Callable:
# pylint: disable=unused-argument # pylint: disable=unused-argument
"""Return a tween function that can override the request's HTTP method.""" """Return a tween function that can override the request's HTTP method."""
def method_override_tween(request: Request) -> Request: def method_override_tween(request: Request) -> Request:
"""Override HTTP method with one specified in header.""" """Override HTTP method with one specified in header."""
valid_overrides_by_method = {'POST': ['DELETE', 'PATCH', 'PUT']}
valid_overrides_by_method = {"POST": ["DELETE", "PATCH", "PUT"]}
original_method = request.method.upper() original_method = request.method.upper()
valid_overrides = valid_overrides_by_method.get(original_method, []) valid_overrides = valid_overrides_by_method.get(original_method, [])
override = request.headers.get('X-HTTP-Method-Override', '').upper()
override = request.headers.get("X-HTTP-Method-Override", "").upper()
if override in valid_overrides: if override in valid_overrides:
request.method = override request.method = override
@ -93,13 +87,13 @@ def http_method_tween_factory(
def get_redis_connection(request: Request) -> StrictRedis: def get_redis_connection(request: Request) -> StrictRedis:
"""Return a StrictRedis connection to the Redis server.""" """Return a StrictRedis connection to the Redis server."""
socket = request.registry.settings['redis.unix_socket_path']
socket = request.registry.settings["redis.unix_socket_path"]
return StrictRedis(unix_socket_path=socket) return StrictRedis(unix_socket_path=socket)
def is_safe_request_method(request: Request) -> bool: def is_safe_request_method(request: Request) -> bool:
"""Return whether the request method is "safe" (is GET or HEAD).""" """Return whether the request method is "safe" (is GET or HEAD)."""
return request.method in {'GET', 'HEAD'}
return request.method in {"GET", "HEAD"}
def check_rate_limit(request: Request, action_name: str) -> RateLimitResult: def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
@ -107,7 +101,7 @@ def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
try: try:
action = RATE_LIMITED_ACTIONS[action_name] action = RATE_LIMITED_ACTIONS[action_name]
except KeyError: except KeyError:
raise ValueError('Invalid action name: %s' % action_name)
raise ValueError("Invalid action name: %s" % action_name)
action.redis = request.redis action.redis = request.redis
@ -127,8 +121,7 @@ def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
def current_listing_base_url( def current_listing_base_url(
request: Request,
query: Optional[Dict[str, Any]] = None,
request: Request, query: Optional[Dict[str, Any]] = None
) -> str: ) -> str:
"""Return the "base" url for the current listing route. """Return the "base" url for the current listing route.
@ -137,14 +130,12 @@ def current_listing_base_url(
The `query` argument allows adding query variables to the generated url. The `query` argument allows adding query variables to the generated url.
""" """
if request.matched_route.name not in ('home', 'group', 'user'):
raise AttributeError('Current route is not supported.')
if request.matched_route.name not in ("home", "group", "user"):
raise AttributeError("Current route is not supported.")
base_view_vars = (
'order', 'period', 'per_page', 'tag', 'type', 'unfiltered')
base_view_vars = ("order", "period", "per_page", "tag", "type", "unfiltered")
query_vars = { query_vars = {
key: val for key, val in request.GET.copy().items()
if key in base_view_vars
key: val for key, val in request.GET.copy().items() if key in base_view_vars
} }
if query: if query:
query_vars.update(query) query_vars.update(query)
@ -152,12 +143,11 @@ def current_listing_base_url(
url = request.current_route_url(_query=query_vars) url = request.current_route_url(_query=query_vars)
# Pyramid seems to %-encode tilde characters unnecessarily, fix that # Pyramid seems to %-encode tilde characters unnecessarily, fix that
return url.replace('%7E', '~')
return url.replace("%7E", "~")
def current_listing_normal_url( def current_listing_normal_url(
request: Request,
query: Optional[Dict[str, Any]] = None,
request: Request, query: Optional[Dict[str, Any]] = None
) -> str: ) -> str:
"""Return the "normal" url for the current listing route. """Return the "normal" url for the current listing route.
@ -166,13 +156,12 @@ def current_listing_normal_url(
The `query` argument allows adding query variables to the generated url. The `query` argument allows adding query variables to the generated url.
""" """
if request.matched_route.name not in ('home', 'group', 'user'):
raise AttributeError('Current route is not supported.')
if request.matched_route.name not in ("home", "group", "user"):
raise AttributeError("Current route is not supported.")
normal_view_vars = ('order', 'period', 'per_page')
normal_view_vars = ("order", "period", "per_page")
query_vars = { query_vars = {
key: val for key, val in request.GET.copy().items()
if key in normal_view_vars
key: val for key, val in request.GET.copy().items() if key in normal_view_vars
} }
if query: if query:
query_vars.update(query) query_vars.update(query)
@ -180,4 +169,4 @@ def current_listing_normal_url(
url = request.current_route_url(_query=query_vars) url = request.current_route_url(_query=query_vars)
# Pyramid seems to %-encode tilde characters unnecessarily, fix that # Pyramid seems to %-encode tilde characters unnecessarily, fix that
return url.replace('%7E', '~')
return url.replace("%7E", "~")

6
tildes/tildes/api.py

@ -9,8 +9,8 @@ import venusian
class APIv0(Service): class APIv0(Service):
"""Service wrapper class for v0 of the API.""" """Service wrapper class for v0 of the API."""
name_prefix = 'apiv0_'
base_path = '/api/v0'
name_prefix = "apiv0_"
base_path = "/api/v0"
def __init__(self, name: str, path: str, **kwargs: Any) -> None: def __init__(self, name: str, path: str, **kwargs: Any) -> None:
"""Create a new service.""" """Create a new service."""
@ -28,4 +28,4 @@ class APIv0(Service):
# TEMP: disable API until I can fix the private-fields issue # TEMP: disable API until I can fix the private-fields issue
# config.add_cornice_service(self) # config.add_cornice_service(self)
info = venusian.attach(self, callback, category='pyramid')
info = venusian.attach(self, callback, category="pyramid")

46
tildes/tildes/auth.py

@ -7,13 +7,7 @@ from pyramid.authorization import ACLAuthorizationPolicy
from pyramid.config import Configurator from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPFound from pyramid.httpexceptions import HTTPFound
from pyramid.request import Request from pyramid.request import Request
from pyramid.security import (
ACLDenied,
ACLPermitsResult,
Allow,
Authenticated,
Everyone,
)
from pyramid.security import ACLDenied, ACLPermitsResult, Allow, Authenticated, Everyone
from tildes.models.user import User from tildes.models.user import User
@ -27,7 +21,7 @@ class DefaultRootFactory:
an __acl__ defined, they will not "fall back" to this one. an __acl__ defined, they will not "fall back" to this one.
""" """
__acl__ = ((Allow, Everyone, 'view'),)
__acl__ = ((Allow, Everyone, "view"),)
def __init__(self, request: Request) -> None: def __init__(self, request: Request) -> None:
"""Root factory constructor - must take a request argument.""" """Root factory constructor - must take a request argument."""
@ -40,10 +34,7 @@ def get_authenticated_user(request: Request) -> Optional[User]:
if not user_id: if not user_id:
return None return None
query = (
request.query(User)
.filter_by(user_id=user_id)
)
query = request.query(User).filter_by(user_id=user_id)
return query.one_or_none() return query.one_or_none()
@ -60,15 +51,15 @@ def auth_callback(user_id: int, request: Request) -> Optional[Sequence[str]]:
# if the user is banned, log them out - is there a better place to do this? # if the user is banned, log them out - is there a better place to do this?
if request.user.is_banned: if request.user.is_banned:
request.session.invalidate() request.session.invalidate()
raise HTTPFound('/')
raise HTTPFound("/")
if user_id != request.user.user_id: if user_id != request.user.user_id:
raise AssertionError('auth_callback called with different user_id')
raise AssertionError("auth_callback called with different user_id")
principals = [] principals = []
if request.user.is_admin: if request.user.is_admin:
principals.append('admin')
principals.append("admin")
return principals return principals
@ -76,7 +67,7 @@ def auth_callback(user_id: int, request: Request) -> Optional[Sequence[str]]:
def includeme(config: Configurator) -> None: def includeme(config: Configurator) -> None:
"""Config updates related to authentication/authorization.""" """Config updates related to authentication/authorization."""
# make all views require "view" permission unless specifically overridden # make all views require "view" permission unless specifically overridden
config.set_default_permission('view')
config.set_default_permission("view")
# replace the default root factory with a custom one to more easily support # replace the default root factory with a custom one to more easily support
# the default permission # the default permission
@ -89,32 +80,30 @@ def includeme(config: Configurator) -> None:
config.set_authorization_policy(AuthorizedOnlyPolicy()) config.set_authorization_policy(AuthorizedOnlyPolicy())
config.set_authentication_policy( config.set_authentication_policy(
SessionAuthenticationPolicy(callback=auth_callback))
SessionAuthenticationPolicy(callback=auth_callback)
)
# enable CSRF checking globally by default # enable CSRF checking globally by default
config.set_default_csrf_options(require_csrf=True) config.set_default_csrf_options(require_csrf=True)
# make the logged-in User object available as request.user # make the logged-in User object available as request.user
config.add_request_method(get_authenticated_user, 'user', reify=True)
config.add_request_method(get_authenticated_user, "user", reify=True)
# add has_any_permission method for easily checking multiple permissions # add has_any_permission method for easily checking multiple permissions
config.add_request_method(has_any_permission, 'has_any_permission')
config.add_request_method(has_any_permission, "has_any_permission")
class AuthorizedOnlyPolicy(ACLAuthorizationPolicy): class AuthorizedOnlyPolicy(ACLAuthorizationPolicy):
"""ACLAuthorizationPolicy override that always denies logged-out users.""" """ACLAuthorizationPolicy override that always denies logged-out users."""
def permits( def permits(
self,
context: Any,
principals: Sequence[Any],
permission: str,
self, context: Any, principals: Sequence[Any], permission: str
) -> ACLPermitsResult: ) -> ACLPermitsResult:
"""Deny logged-out users, otherwise pass up to normal policy.""" """Deny logged-out users, otherwise pass up to normal policy."""
if Authenticated not in principals: if Authenticated not in principals:
return ACLDenied( return ACLDenied(
'<authorized only>',
'<no ACLs checked yet>',
"<authorized only>",
"<no ACLs checked yet>",
permission, permission,
principals, principals,
context, context,
@ -124,12 +113,9 @@ class AuthorizedOnlyPolicy(ACLAuthorizationPolicy):
def has_any_permission( def has_any_permission(
request: Request,
permissions: Sequence[str],
context: Any,
request: Request, permissions: Sequence[str], context: Any
) -> bool: ) -> bool:
"""Return whether the user has any of the permissions on the item.""" """Return whether the user has any of the permissions on the item."""
return any( return any(
request.has_permission(permission, context)
for permission in permissions
request.has_permission(permission, context) for permission in permissions
) )

27
tildes/tildes/database.py

@ -28,10 +28,7 @@ def obtain_lock(request: Request, lock_space: str, lock_value: int) -> None:
obtain_transaction_lock(request.db_session, lock_space, lock_value) obtain_transaction_lock(request.db_session, lock_space, lock_value)
def query_factory(
request: Request,
model_cls: Type[DatabaseModel],
) -> ModelQuery:
def query_factory(request: Request, model_cls: Type[DatabaseModel]) -> ModelQuery:
"""Return a ModelQuery or subclass depending on model_cls specified.""" """Return a ModelQuery or subclass depending on model_cls specified."""
if model_cls == Comment: if model_cls == Comment:
return CommentQuery(request) return CommentQuery(request)
@ -46,8 +43,7 @@ def query_factory(
def get_tm_session( def get_tm_session(
session_factory: Callable,
transaction_manager: ThreadTransactionManager,
session_factory: Callable, transaction_manager: ThreadTransactionManager
) -> Session: ) -> Session:
"""Return a db session being managed by the transaction manager.""" """Return a db session being managed by the transaction manager."""
db_session = session_factory() db_session = session_factory()
@ -74,26 +70,27 @@ def includeme(config: Configurator) -> None:
# transaction if the response code starts with 4 or 5. The main benefit of # transaction if the response code starts with 4 or 5. The main benefit of
# this is to avoid aborting on exceptions that don't actually indicate a # this is to avoid aborting on exceptions that don't actually indicate a
# problem, such as a HTTPFound 302 redirect. # problem, such as a HTTPFound 302 redirect.
settings['tm.commit_veto'] = 'pyramid_tm.default_commit_veto'
settings["tm.commit_veto"] = "pyramid_tm.default_commit_veto"
config.include('pyramid_tm')
config.include("pyramid_tm")
# disable SQLAlchemy connection pooling since pgbouncer will handle it # disable SQLAlchemy connection pooling since pgbouncer will handle it
settings['sqlalchemy.poolclass'] = NullPool
settings["sqlalchemy.poolclass"] = NullPool
engine = engine_from_config(settings, 'sqlalchemy.')
engine = engine_from_config(settings, "sqlalchemy.")
session_factory = sessionmaker(bind=engine, expire_on_commit=False) session_factory = sessionmaker(bind=engine, expire_on_commit=False)
config.registry['db_session_factory'] = session_factory
config.registry["db_session_factory"] = session_factory
# attach the session to each request as request.db_session # attach the session to each request as request.db_session
config.add_request_method( config.add_request_method(
lambda request: get_tm_session( lambda request: get_tm_session(
config.registry['db_session_factory'], request.tm),
'db_session',
config.registry["db_session_factory"], request.tm
),
"db_session",
reify=True, reify=True,
) )
config.add_request_method(query_factory, 'query')
config.add_request_method(query_factory, "query")
config.add_request_method(obtain_lock, 'obtain_lock')
config.add_request_method(obtain_lock, "obtain_lock")

20
tildes/tildes/enums.py

@ -21,12 +21,12 @@ class CommentSortOption(enum.Enum):
@property @property
def description(self) -> str: def description(self) -> str:
"""Describe this sort option.""" """Describe this sort option."""
if self.name == 'NEWEST':
return 'newest first'
elif self.name == 'POSTED':
return 'order posted'
if self.name == "NEWEST":
return "newest first"
elif self.name == "POSTED":
return "order posted"
return 'most {}'.format(self.name.lower()) # noqa
return "most {}".format(self.name.lower()) # noqa
class CommentTagOption(enum.Enum): class CommentTagOption(enum.Enum):
@ -72,12 +72,12 @@ class TopicSortOption(enum.Enum):
using that sort in descending order means that topics with the most using that sort in descending order means that topics with the most
votes will be listed first. votes will be listed first.
""" """
if self.name == 'NEW':
return 'newest'
elif self.name == 'ACTIVITY':
return 'activity'
if self.name == "NEW":
return "newest"
elif self.name == "ACTIVITY":
return "activity"
return 'most {}'.format(self.name.lower()) # noqa
return "most {}".format(self.name.lower()) # noqa
class TopicType(enum.Enum): class TopicType(enum.Enum):

26
tildes/tildes/jinja.py

@ -29,28 +29,26 @@ def includeme(config: Configurator) -> None:
"""Configure Jinja2 template renderer.""" """Configure Jinja2 template renderer."""
settings = config.get_settings() settings = config.get_settings()
settings['jinja2.lstrip_blocks'] = True
settings['jinja2.trim_blocks'] = True
settings['jinja2.undefined'] = 'strict'
settings["jinja2.lstrip_blocks"] = True
settings["jinja2.trim_blocks"] = True
settings["jinja2.undefined"] = "strict"
# add custom jinja filters # add custom jinja filters
settings['jinja2.filters'] = {
'ago': descriptive_timedelta,
}
settings["jinja2.filters"] = {"ago": descriptive_timedelta}
# add custom jinja tests # add custom jinja tests
settings['jinja2.tests'] = {
'comment': is_comment,
'group': is_group,
'topic': is_topic,
settings["jinja2.tests"] = {
"comment": is_comment,
"group": is_group,
"topic": is_topic,
} }
config.include('pyramid_jinja2')
config.include("pyramid_jinja2")
config.add_jinja2_search_path('tildes:templates/')
config.add_jinja2_search_path("tildes:templates/")
config.add_jinja2_extension('jinja2.ext.do')
config.add_jinja2_extension('webassets.ext.jinja2.AssetsExtension')
config.add_jinja2_extension("jinja2.ext.do")
config.add_jinja2_extension("webassets.ext.jinja2.AssetsExtension")
# attach webassets to jinja2 environment (via scheduled action) # attach webassets to jinja2 environment (via scheduled action)
def attach_webassets_to_jinja2() -> None: def attach_webassets_to_jinja2() -> None:

6
tildes/tildes/json.py

@ -23,8 +23,8 @@ def serialize_model(model_item: DatabaseModel, request: Request) -> dict:
def serialize_topic(topic: Topic, request: Request) -> dict: def serialize_topic(topic: Topic, request: Request) -> dict:
"""Return serializable data for a Topic.""" """Return serializable data for a Topic."""
context = {} context = {}
if not request.has_permission('view_author', topic):
context['hide_username'] = True
if not request.has_permission("view_author", topic):
context["hide_username"] = True
return topic.schema_class(context=context).dump(topic) return topic.schema_class(context=context).dump(topic)
@ -40,4 +40,4 @@ def includeme(config: Configurator) -> None:
# add specific adapters # add specific adapters
json_renderer.add_adapter(Topic, serialize_topic) json_renderer.add_adapter(Topic, serialize_topic)
config.add_renderer('json', json_renderer)
config.add_renderer("json", json_renderer)

16
tildes/tildes/lib/amqp.py

@ -24,30 +24,24 @@ class PgsqlQueueConsumer(AbstractConsumer):
JSON format. JSON format.
""" """
PGSQL_EXCHANGE_NAME = 'pgsql_events'
PGSQL_EXCHANGE_NAME = "pgsql_events"
def __init__( def __init__(
self,
queue_name: str,
routing_keys: Sequence[str],
uses_db: bool = True,
self, queue_name: str, routing_keys: Sequence[str], uses_db: bool = True
) -> None: ) -> None:
"""Initialize a new queue, bindings, and consumer for it.""" """Initialize a new queue, bindings, and consumer for it."""
self.connection = Connection() self.connection = Connection()
self.channel = self.connection.channel() self.channel = self.connection.channel()
self.channel.queue_declare(
queue_name, durable=True, auto_delete=False)
self.channel.queue_declare(queue_name, durable=True, auto_delete=False)
for routing_key in routing_keys: for routing_key in routing_keys:
self.channel.queue_bind( self.channel.queue_bind(
queue_name,
exchange=self.PGSQL_EXCHANGE_NAME,
routing_key=routing_key,
queue_name, exchange=self.PGSQL_EXCHANGE_NAME, routing_key=routing_key
) )
if uses_db: if uses_db:
self.db_session = get_session_from_config(os.environ['INI_FILE'])
self.db_session = get_session_from_config(os.environ["INI_FILE"])
else: else:
self.db_session = None self.db_session = None

12
tildes/tildes/lib/cmark.py

@ -4,14 +4,14 @@
from ctypes import CDLL, c_char_p, c_int, c_size_t, c_void_p from ctypes import CDLL, c_char_p, c_int, c_size_t, c_void_p
CMARK_DLL = CDLL('/usr/local/lib/libcmark-gfm.so')
CMARK_EXT_DLL = CDLL('/usr/local/lib/libcmark-gfmextensions.so')
CMARK_DLL = CDLL("/usr/local/lib/libcmark-gfm.so")
CMARK_EXT_DLL = CDLL("/usr/local/lib/libcmark-gfmextensions.so")
# enables the --hardbreaks option for cmark # enables the --hardbreaks option for cmark
# (can I import this? it's defined in cmark.h as CMARK_OPT_HARDBREAKS) # (can I import this? it's defined in cmark.h as CMARK_OPT_HARDBREAKS)
CMARK_OPTS = 4 CMARK_OPTS = 4
CMARK_EXTENSIONS = (b'strikethrough', b'table')
CMARK_EXTENSIONS = (b"strikethrough", b"table")
cmark_parser_new = CMARK_DLL.cmark_parser_new cmark_parser_new = CMARK_DLL.cmark_parser_new
cmark_parser_new.restype = c_void_p cmark_parser_new.restype = c_void_p
@ -25,13 +25,11 @@ cmark_parser_finish = CMARK_DLL.cmark_parser_finish
cmark_parser_finish.restype = c_void_p cmark_parser_finish.restype = c_void_p
cmark_parser_finish.argtypes = (c_void_p,) cmark_parser_finish.argtypes = (c_void_p,)
cmark_parser_attach_syntax_extension = (
CMARK_DLL.cmark_parser_attach_syntax_extension)
cmark_parser_attach_syntax_extension = CMARK_DLL.cmark_parser_attach_syntax_extension
cmark_parser_attach_syntax_extension.restype = c_int cmark_parser_attach_syntax_extension.restype = c_int
cmark_parser_attach_syntax_extension.argtypes = (c_void_p, c_void_p) cmark_parser_attach_syntax_extension.argtypes = (c_void_p, c_void_p)
cmark_parser_get_syntax_extensions = (
CMARK_DLL.cmark_parser_get_syntax_extensions)
cmark_parser_get_syntax_extensions = CMARK_DLL.cmark_parser_get_syntax_extensions
cmark_parser_get_syntax_extensions.restype = c_void_p cmark_parser_get_syntax_extensions.restype = c_void_p
cmark_parser_get_syntax_extensions.argtypes = (c_void_p,) cmark_parser_get_syntax_extensions.argtypes = (c_void_p,)

26
tildes/tildes/lib/database.py

@ -20,7 +20,7 @@ NOT_NULL_ERROR_CODE = 23502
def get_session_from_config(config_path: str) -> Session: def get_session_from_config(config_path: str) -> Session:
"""Get a database session from a config file (specified by path).""" """Get a database session from a config file (specified by path)."""
env = bootstrap(config_path) env = bootstrap(config_path)
session_factory = env['registry']['db_session_factory']
session_factory = env["registry"]["db_session_factory"]
return session_factory() return session_factory()
@ -31,9 +31,7 @@ class LockSpaces(enum.Enum):
def obtain_transaction_lock( def obtain_transaction_lock(
session: Session,
lock_space: Optional[str],
lock_value: int,
session: Session, lock_space: Optional[str], lock_value: int
) -> None: ) -> None:
"""Obtain a transaction-level advisory lock from PostgreSQL. """Obtain a transaction-level advisory lock from PostgreSQL.
@ -45,11 +43,9 @@ def obtain_transaction_lock(
try: try:
lock_space_value = LockSpaces[lock_space.upper()].value lock_space_value = LockSpaces[lock_space.upper()].value
except KeyError: except KeyError:
raise ValueError('Invalid lock space: %s' % lock_space)
raise ValueError("Invalid lock space: %s" % lock_space)
session.query(
func.pg_advisory_xact_lock(lock_space_value, lock_value)
).one()
session.query(func.pg_advisory_xact_lock(lock_space_value, lock_value)).one()
else: else:
session.query(func.pg_advisory_xact_lock(lock_value)).one() session.query(func.pg_advisory_xact_lock(lock_value)).one()
@ -66,10 +62,11 @@ class CIText(UserDefinedType):
def get_col_spec(self, **kw: Any) -> str: def get_col_spec(self, **kw: Any) -> str:
"""Return the type name (for creating columns and so on).""" """Return the type name (for creating columns and so on)."""
# pylint: disable=no-self-use,unused-argument # pylint: disable=no-self-use,unused-argument
return 'CITEXT'
return "CITEXT"
def bind_processor(self, dialect: Dialect) -> Callable: def bind_processor(self, dialect: Dialect) -> Callable:
"""Return a conversion function for processing bind values.""" """Return a conversion function for processing bind values."""
def process(value: Any) -> Any: def process(value: Any) -> Any:
return value return value
@ -77,6 +74,7 @@ class CIText(UserDefinedType):
def result_processor(self, dialect: Dialect, coltype: Any) -> Callable: def result_processor(self, dialect: Dialect, coltype: Any) -> Callable:
"""Return a conversion function for processing result row values.""" """Return a conversion function for processing result row values."""
def process(value: Any) -> Any: def process(value: Any) -> Any:
return value return value
@ -103,8 +101,8 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors
super_rp = super().result_processor(dialect, coltype) super_rp = super().result_processor(dialect, coltype)
def handle_raw_string(value: str) -> List[str]: def handle_raw_string(value: str) -> List[str]:
if not (value.startswith('{') and value.endswith('}')):
raise ValueError('%s is not an array value' % value)
if not (value.startswith("{") and value.endswith("}")):
raise ValueError("%s is not an array value" % value)
# trim off the surrounding braces # trim off the surrounding braces
value = value[1:-1] value = value[1:-1]
@ -113,7 +111,7 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors
if not value: if not value:
return [] return []
return value.split(',')
return value.split(",")
def process(value: Optional[str]) -> Optional[List[str]]: def process(value: Optional[str]) -> Optional[List[str]]:
if value is None: if value is None:
@ -133,8 +131,8 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors
def ancestor_of(self, other): # type: ignore def ancestor_of(self, other): # type: ignore
"""Return whether the array contains any ancestor of `other`.""" """Return whether the array contains any ancestor of `other`."""
return self.op('@>')(other)
return self.op("@>")(other)
def descendant_of(self, other): # type: ignore def descendant_of(self, other): # type: ignore
"""Return whether the array contains any descendant of `other`.""" """Return whether the array contains any descendant of `other`."""
return self.op('<@')(other)
return self.op("<@")(other)

30
tildes/tildes/lib/datetime.py

@ -10,32 +10,32 @@ from ago import human
class SimpleHoursPeriod: class SimpleHoursPeriod:
"""A simple class that represents a time period of hours or days.""" """A simple class that represents a time period of hours or days."""
_SHORT_FORM_REGEX = re.compile(r'\d+[hd]', re.IGNORECASE)
_SHORT_FORM_REGEX = re.compile(r"\d+[hd]", re.IGNORECASE)
def __init__(self, hours: int) -> None: def __init__(self, hours: int) -> None:
"""Initialize a SimpleHoursPeriod from a number of hours.""" """Initialize a SimpleHoursPeriod from a number of hours."""
if hours <= 0: if hours <= 0:
raise ValueError('Period must be at least 1 hour.')
raise ValueError("Period must be at least 1 hour.")
self.hours = hours self.hours = hours
try: try:
self.timedelta = timedelta(hours=hours) self.timedelta = timedelta(hours=hours)
except OverflowError: except OverflowError:
raise ValueError('Time period is too large')
raise ValueError("Time period is too large")
@classmethod @classmethod
def from_short_form(cls, short_form: str) -> 'SimpleHoursPeriod':
def from_short_form(cls, short_form: str) -> "SimpleHoursPeriod":
"""Initialize a period from a "short form" string (e.g. "2h", "4d").""" """Initialize a period from a "short form" string (e.g. "2h", "4d")."""
if not cls._SHORT_FORM_REGEX.match(short_form): if not cls._SHORT_FORM_REGEX.match(short_form):
raise ValueError('Invalid time period')
raise ValueError("Invalid time period")
unit = short_form[-1].lower() unit = short_form[-1].lower()
count = int(short_form[:-1]) count = int(short_form[:-1])
if unit == 'h':
if unit == "h":
hours = count hours = count
elif unit == 'd':
elif unit == "d":
hours = count * 24 hours = count * 24
return cls(hours=hours) return cls(hours=hours)
@ -47,9 +47,9 @@ class SimpleHoursPeriod:
for the special case of exactly "1 day", which is replaced with "24 for the special case of exactly "1 day", which is replaced with "24
hours". hours".
""" """
string = human(self.timedelta, past_tense='{}')
if string == '1 day':
string = '24 hours'
string = human(self.timedelta, past_tense="{}")
if string == "1 day":
string = "24 hours"
return string return string
@ -67,9 +67,9 @@ class SimpleHoursPeriod:
24 hours (except for 24 hours itself). 24 hours (except for 24 hours itself).
""" """
if self.hours % 24 == 0 and self.hours != 24: if self.hours % 24 == 0 and self.hours != 24:
return '{}d'.format(self.hours // 24)
return "{}d".format(self.hours // 24)
return f'{self.hours}h'
return f"{self.hours}h"
def utc_now() -> datetime: def utc_now() -> datetime:
@ -93,7 +93,7 @@ def descriptive_timedelta(target: datetime, abbreviate: bool = False) -> str:
""" """
seconds_ago = (utc_now() - target).total_seconds() seconds_ago = (utc_now() - target).total_seconds()
if seconds_ago < 1: if seconds_ago < 1:
return 'a moment ago'
return "a moment ago"
# determine whether one or two precision levels is appropriate # determine whether one or two precision levels is appropriate
if seconds_ago < 3600: if seconds_ago < 3600:
@ -103,7 +103,7 @@ def descriptive_timedelta(target: datetime, abbreviate: bool = False) -> str:
# try a precision=2 version, and check the units it ends up with # try a precision=2 version, and check the units it ends up with
result = human(target, precision=2) result = human(target, precision=2)
units = ('year', 'day', 'hour', 'minute', 'second')
units = ("year", "day", "hour", "minute", "second")
unit_indices = [i for (i, unit) in enumerate(units) if unit in result] unit_indices = [i for (i, unit) in enumerate(units) if unit in result]
# if there was only one unit in it, or they're adjacent, this is fine # if there was only one unit in it, or they're adjacent, this is fine
@ -117,6 +117,6 @@ def descriptive_timedelta(target: datetime, abbreviate: bool = False) -> str:
# remove commas if abbreviating ("3d 2h ago", not "3d, 2h ago") # remove commas if abbreviating ("3d 2h ago", not "3d, 2h ago")
if abbreviate: if abbreviate:
result = result.replace(',', '')
result = result.replace(",", "")
return result return result

3
tildes/tildes/lib/hash.py

@ -11,7 +11,8 @@ ARGON2_TIME_COST = 4
ARGON2_MEMORY_COST = 8092 ARGON2_MEMORY_COST = 8092
ARGON2_HASHER = PasswordHasher( ARGON2_HASHER = PasswordHasher(
time_cost=ARGON2_TIME_COST, memory_cost=ARGON2_MEMORY_COST)
time_cost=ARGON2_TIME_COST, memory_cost=ARGON2_MEMORY_COST
)
def hash_string(string: str) -> str: def hash_string(string: str) -> str:

10
tildes/tildes/lib/id.py

@ -4,13 +4,13 @@ import re
import string import string
ID36_REGEX = re.compile('^[a-z0-9]+$', re.IGNORECASE)
ID36_REGEX = re.compile("^[a-z0-9]+$", re.IGNORECASE)
def id_to_id36(id_val: int) -> str: def id_to_id36(id_val: int) -> str:
"""Convert an integer ID to the string ID36 representation.""" """Convert an integer ID to the string ID36 representation."""
if id_val < 1: if id_val < 1:
raise ValueError('ID values should never be zero or negative')
raise ValueError("ID values should never be zero or negative")
reversed_chars = [] reversed_chars = []
@ -29,13 +29,13 @@ def id_to_id36(id_val: int) -> str:
reversed_chars.append(alphabet[index]) reversed_chars.append(alphabet[index])
# join the characters in reversed order and return as the result # join the characters in reversed order and return as the result
return ''.join(reversed(reversed_chars))
return "".join(reversed(reversed_chars))
def id36_to_id(id36_val: str) -> int: def id36_to_id(id36_val: str) -> int:
"""Convert a string ID36 to the integer ID representation.""" """Convert a string ID36 to the integer ID representation."""
if id36_val.startswith('-') or id36_val == '0':
raise ValueError('ID values should never be zero or negative')
if id36_val.startswith("-") or id36_val == "0":
raise ValueError("ID values should never be zero or negative")
# Python's stdlib can handle this, much simpler in this direction # Python's stdlib can handle this, much simpler in this direction
return int(id36_val, 36) return int(id36_val, 36)

181
tildes/tildes/lib/markdown.py

@ -40,51 +40,51 @@ from .cmark import (
HTML_TAG_WHITELIST = ( HTML_TAG_WHITELIST = (
'a',
'b',
'blockquote',
'br',
'code',
'del',
'em',
'h1',
'h2',
'h3',
'h4',
'h5',
'h6',
'hr',
'i',
'ins',
'li',
'ol',
'p',
'pre',
'strong',
'sub',
'sup',
'table',
'tbody',
'td',
'th',
'thead',
'tr',
'ul',
"a",
"b",
"blockquote",
"br",
"code",
"del",
"em",
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"hr",
"i",
"ins",
"li",
"ol",
"p",
"pre",
"strong",
"sub",
"sup",
"table",
"tbody",
"td",
"th",
"thead",
"tr",
"ul",
) )
HTML_ATTRIBUTE_WHITELIST = { HTML_ATTRIBUTE_WHITELIST = {
'a': ['href', 'title'],
'ol': ['start'],
'td': ['align'],
'th': ['align'],
"a": ["href", "title"],
"ol": ["start"],
"td": ["align"],
"th": ["align"],
} }
PROTOCOL_WHITELIST = ('http', 'https')
PROTOCOL_WHITELIST = ("http", "https")
# Regex that finds ordered list markdown that was probably accidental - ones # Regex that finds ordered list markdown that was probably accidental - ones
# being initiated by anything except "1." # being initiated by anything except "1."
BAD_ORDERED_LIST_REGEX = re.compile( BAD_ORDERED_LIST_REGEX = re.compile(
r'((?:\A|\n\n)' # Either the start of the entire text, or a new paragraph
r'(?!1\.)\d+)' # A number that isn't "1"
r'\.\s', # Followed by a period and a space
r"((?:\A|\n\n)" # Either the start of the entire text, or a new paragraph
r"(?!1\.)\d+)" # A number that isn't "1"
r"\.\s" # Followed by a period and a space
) )
# Type alias for the "namespaced attr dict" used inside bleach.linkify # Type alias for the "namespaced attr dict" used inside bleach.linkify
@ -95,12 +95,11 @@ NamespacedAttrDict = Dict[Union[Tuple[Optional[str], str], str], str] # noqa
def linkify_protocol_whitelist( def linkify_protocol_whitelist(
attrs: NamespacedAttrDict,
new: bool = False,
attrs: NamespacedAttrDict, new: bool = False
) -> Optional[NamespacedAttrDict]: ) -> Optional[NamespacedAttrDict]:
"""bleach.linkify callback: prevent links to non-whitelisted protocols.""" """bleach.linkify callback: prevent links to non-whitelisted protocols."""
# pylint: disable=unused-argument # pylint: disable=unused-argument
href = attrs.get((None, 'href'))
href = attrs.get((None, "href"))
if not href: if not href:
return attrs return attrs
@ -112,13 +111,13 @@ def linkify_protocol_whitelist(
return attrs return attrs
@histogram_timer('markdown_processing')
@histogram_timer("markdown_processing")
def convert_markdown_to_safe_html(markdown: str) -> str: def convert_markdown_to_safe_html(markdown: str) -> str:
"""Convert markdown to sanitized HTML.""" """Convert markdown to sanitized HTML."""
# apply custom pre-processing to markdown # apply custom pre-processing to markdown
markdown = preprocess_markdown(markdown) markdown = preprocess_markdown(markdown)
markdown_bytes = markdown.encode('utf8')
markdown_bytes = markdown.encode("utf8")
parser = cmark_parser_new(CMARK_OPTS) parser = cmark_parser_new(CMARK_OPTS)
for name in CMARK_EXTENSIONS: for name in CMARK_EXTENSIONS:
@ -134,7 +133,7 @@ def convert_markdown_to_safe_html(markdown: str) -> str:
cmark_parser_free(parser) cmark_parser_free(parser)
cmark_node_free(doc) cmark_node_free(doc)
html = html_bytes.decode('utf8')
html = html_bytes.decode("utf8")
# apply custom post-processing to HTML # apply custom post-processing to HTML
html = postprocess_markdown_html(html) html = postprocess_markdown_html(html)
@ -148,7 +147,7 @@ def preprocess_markdown(markdown: str) -> str:
markdown = escape_accidental_ordered_lists(markdown) markdown = escape_accidental_ordered_lists(markdown)
# fix the "shrug" emoji ¯\_(ツ)_/¯ to prevent markdown mangling it # fix the "shrug" emoji ¯\_(ツ)_/¯ to prevent markdown mangling it
markdown = markdown.replace(r'¯\_(ツ)_/¯', r'¯\\\_(ツ)\_/¯')
markdown = markdown.replace(r"¯\_(ツ)_/¯", r"¯\\\_(ツ)\_/¯")
return markdown return markdown
@ -166,19 +165,17 @@ def escape_accidental_ordered_lists(markdown: str) -> str:
numbered list except for "1. ". This will cause a few other edge cases, but numbered list except for "1. ". This will cause a few other edge cases, but
I believe they're less common/important than fixing this common error. I believe they're less common/important than fixing this common error.
""" """
return BAD_ORDERED_LIST_REGEX.sub(r'\1\\. ', markdown)
return BAD_ORDERED_LIST_REGEX.sub(r"\1\\. ", markdown)
def postprocess_markdown_html(html: str) -> str: def postprocess_markdown_html(html: str) -> str:
"""Apply post-processing to HTML generated by markdown parser.""" """Apply post-processing to HTML generated by markdown parser."""
# list of tag names to exclude from linkification # list of tag names to exclude from linkification
linkify_skipped_tags = ['pre']
linkify_skipped_tags = ["pre"]
# search for text that looks like urls and convert to actual links # search for text that looks like urls and convert to actual links
html = bleach.linkify( html = bleach.linkify(
html,
callbacks=[linkify_protocol_whitelist],
skip_tags=linkify_skipped_tags,
html, callbacks=[linkify_protocol_whitelist], skip_tags=linkify_skipped_tags
) )
# run the HTML through our custom linkification process as well # run the HTML through our custom linkification process as well
@ -187,20 +184,17 @@ def postprocess_markdown_html(html: str) -> str:
return html return html
def apply_linkification(
html: str,
skip_tags: Optional[List[str]] = None,
) -> str:
def apply_linkification(html: str, skip_tags: Optional[List[str]] = None) -> str:
"""Apply custom linkification filter to convert text patterns to links.""" """Apply custom linkification filter to convert text patterns to links."""
parser = HTMLParser(namespaceHTMLElements=False) parser = HTMLParser(namespaceHTMLElements=False)
html_tree = parser.parseFragment(html) html_tree = parser.parseFragment(html)
walker_stream = html5lib.getTreeWalker('etree')(html_tree)
walker_stream = html5lib.getTreeWalker("etree")(html_tree)
filtered_html_tree = LinkifyFilter(walker_stream, skip_tags) filtered_html_tree = LinkifyFilter(walker_stream, skip_tags)
serializer = HTMLSerializer( serializer = HTMLSerializer(
quote_attr_values='always',
quote_attr_values="always",
omit_optional_tags=False, omit_optional_tags=False,
sanitize=False, sanitize=False,
alphabetical_attributes=False, alphabetical_attributes=False,
@ -224,17 +218,15 @@ class LinkifyFilter(Filter):
# Note: currently specifically excludes paths immediately followed by a # Note: currently specifically excludes paths immediately followed by a
# tilde, but this may be possible to remove once strikethrough is # tilde, but this may be possible to remove once strikethrough is
# implemented (since that's probably what they were trying to do) # implemented (since that's probably what they were trying to do)
GROUP_REFERENCE_REGEX = re.compile(r'(?<!\w)~([\w.]+)\b(?!~)')
GROUP_REFERENCE_REGEX = re.compile(r"(?<!\w)~([\w.]+)\b(?!~)")
# Regex that finds probable references to users. As above, this isn't # Regex that finds probable references to users. As above, this isn't
# "perfect" either but works as an initial pass with the validity of # "perfect" either but works as an initial pass with the validity of
# the username checked more carefully later. # the username checked more carefully later.
USERNAME_REFERENCE_REGEX = re.compile(r'(?<!\w)(?:/?u/|@)([\w-]+)\b')
USERNAME_REFERENCE_REGEX = re.compile(r"(?<!\w)(?:/?u/|@)([\w-]+)\b")
def __init__( def __init__(
self,
source: NonRecursiveTreeWalker,
skip_tags: Optional[List[str]] = None,
self, source: NonRecursiveTreeWalker, skip_tags: Optional[List[str]] = None
) -> None: ) -> None:
"""Initialize a linkification filter to apply to HTML. """Initialize a linkification filter to apply to HTML.
@ -245,28 +237,30 @@ class LinkifyFilter(Filter):
self.skip_tags = skip_tags or [] self.skip_tags = skip_tags or []
# always skip the contents of <a> tags in addition to any others # always skip the contents of <a> tags in addition to any others
self.skip_tags.append('a')
self.skip_tags.append("a")
def __iter__(self) -> Iterator[dict]: def __iter__(self) -> Iterator[dict]:
"""Iterate over the tree, modifying it as necessary before yielding.""" """Iterate over the tree, modifying it as necessary before yielding."""
inside_skipped_tags = [] inside_skipped_tags = []
for token in super().__iter__(): for token in super().__iter__():
if (token['type'] in ('StartTag', 'EmptyTag') and
token['name'] in self.skip_tags):
if (
token["type"] in ("StartTag", "EmptyTag")
and token["name"] in self.skip_tags
):
# if this is the start of a tag we want to skip, add it to the # if this is the start of a tag we want to skip, add it to the
# list of skipped tags that we're currently inside # list of skipped tags that we're currently inside
inside_skipped_tags.append(token['name'])
inside_skipped_tags.append(token["name"])
elif inside_skipped_tags: elif inside_skipped_tags:
# if we're currently inside any skipped tags, the only thing we # if we're currently inside any skipped tags, the only thing we
# want to do is look for all the end tags we need to be able to # want to do is look for all the end tags we need to be able to
# finish skipping # finish skipping
if token['type'] == 'EndTag':
if token["type"] == "EndTag":
try: try:
inside_skipped_tags.remove(token['name'])
inside_skipped_tags.remove(token["name"])
except ValueError: except ValueError:
pass pass
elif token['type'] == 'Characters':
elif token["type"] == "Characters":
# this is only reachable if inside_skipped_tags is empty, so # this is only reachable if inside_skipped_tags is empty, so
# this is a text token not inside a skipped tag - do the actual # this is a text token not inside a skipped tag - do the actual
# linkification replacements # linkification replacements
@ -300,9 +294,7 @@ class LinkifyFilter(Filter):
@staticmethod @staticmethod
def _linkify_tokens( def _linkify_tokens(
tokens: List[dict],
filter_regex: Pattern,
linkify_function: Callable,
tokens: List[dict], filter_regex: Pattern, linkify_function: Callable
) -> List[dict]: ) -> List[dict]:
"""Check tokens for text that matches a regex and linkify it. """Check tokens for text that matches a regex and linkify it.
@ -316,21 +308,23 @@ class LinkifyFilter(Filter):
for token in tokens: for token in tokens:
# we don't want to touch any tokens other than character ones # we don't want to touch any tokens other than character ones
if token['type'] != 'Characters':
if token["type"] != "Characters":
new_tokens.append(token) new_tokens.append(token)
continue continue
original_text = token['data']
original_text = token["data"]
current_index = 0 current_index = 0
for match in filter_regex.finditer(original_text): for match in filter_regex.finditer(original_text):
# if there were some characters between the previous match and # if there were some characters between the previous match and
# this one, add a token containing those first # this one, add a token containing those first
if match.start() > current_index: if match.start() > current_index:
new_tokens.append({
'type': 'Characters',
'data': original_text[current_index:match.start()],
})
new_tokens.append(
{
"type": "Characters",
"data": original_text[current_index : match.start()],
}
)
# call the linkify function to convert this match into tokens # call the linkify function to convert this match into tokens
linkified_tokens = linkify_function(match) linkified_tokens = linkify_function(match)
@ -342,10 +336,9 @@ class LinkifyFilter(Filter):
# if there's still some text left over, add one more token for it # if there's still some text left over, add one more token for it
# (this will be the entire thing if there weren't any matches) # (this will be the entire thing if there weren't any matches)
if current_index < len(original_text): if current_index < len(original_text):
new_tokens.append({
'type': 'Characters',
'data': original_text[current_index:],
})
new_tokens.append(
{"type": "Characters", "data": original_text[current_index:]}
)
return new_tokens return new_tokens
@ -360,22 +353,22 @@ class LinkifyFilter(Filter):
# things like "~10" or "~4.5" since that's just going to be someone # things like "~10" or "~4.5" since that's just going to be someone
# using it in the "approximately" sense. So if the path consists of # using it in the "approximately" sense. So if the path consists of
# only numbers and/or periods, we won't linkify it # only numbers and/or periods, we won't linkify it
is_numeric = all(char in '0123456789.' for char in group_path)
is_numeric = all(char in "0123456789." for char in group_path)
# if it's a valid group path and not totally numeric, convert to <a> # if it's a valid group path and not totally numeric, convert to <a>
if is_valid_group_path(group_path) and not is_numeric: if is_valid_group_path(group_path) and not is_numeric:
return [ return [
{ {
'type': 'StartTag',
'name': 'a',
'data': {(None, 'href'): f'/~{group_path}'},
"type": "StartTag",
"name": "a",
"data": {(None, "href"): f"/~{group_path}"},
}, },
{'type': 'Characters', 'data': match[0]},
{'type': 'EndTag', 'name': 'a'},
{"type": "Characters", "data": match[0]},
{"type": "EndTag", "name": "a"},
] ]
# one of the checks failed, so just keep it as the original text # one of the checks failed, so just keep it as the original text
return [{'type': 'Characters', 'data': match[0]}]
return [{"type": "Characters", "data": match[0]}]
@staticmethod @staticmethod
def _tokenize_username_match(match: Match) -> List[dict]: def _tokenize_username_match(match: Match) -> List[dict]:
@ -384,16 +377,16 @@ class LinkifyFilter(Filter):
if is_valid_username(match[1]): if is_valid_username(match[1]):
return [ return [
{ {
'type': 'StartTag',
'name': 'a',
'data': {(None, 'href'): f'/user/{match[1]}'},
"type": "StartTag",
"name": "a",
"data": {(None, "href"): f"/user/{match[1]}"},
}, },
{'type': 'Characters', 'data': match[0]},
{'type': 'EndTag', 'name': 'a'},
{"type": "Characters", "data": match[0]},
{"type": "EndTag", "name": "a"},
] ]
# the username wasn't valid, so just keep it as the original text # the username wasn't valid, so just keep it as the original text
return [{'type': 'Characters', 'data': match[0]}]
return [{"type": "Characters", "data": match[0]}]
def sanitize_html(html: str) -> str: def sanitize_html(html: str) -> str:

2
tildes/tildes/lib/message.py

@ -1,6 +1,6 @@
"""Functions/constants related to messages.""" """Functions/constants related to messages."""
WELCOME_MESSAGE_SUBJECT = 'Welcome to the Tildes alpha'
WELCOME_MESSAGE_SUBJECT = "Welcome to the Tildes alpha"
# pylama:ignore=E501 # pylama:ignore=E501
WELCOME_MESSAGE_TEXT = """ WELCOME_MESSAGE_TEXT = """

11
tildes/tildes/lib/password.py

@ -5,21 +5,22 @@ from hashlib import sha1
from redis import ConnectionError, ResponseError, StrictRedis # noqa from redis import ConnectionError, ResponseError, StrictRedis # noqa
# unix socket path for redis server with the breached passwords bloom filter # unix socket path for redis server with the breached passwords bloom filter
BREACHED_PASSWORDS_REDIS_SOCKET = '/run/redis_breached_passwords/socket'
BREACHED_PASSWORDS_REDIS_SOCKET = "/run/redis_breached_passwords/socket"
# Key where the bloom filter of password hashes from data breaches is stored # Key where the bloom filter of password hashes from data breaches is stored
BREACHED_PASSWORDS_BF_KEY = 'breached_passwords_bloom'
BREACHED_PASSWORDS_BF_KEY = "breached_passwords_bloom"
def is_breached_password(password: str) -> bool: def is_breached_password(password: str) -> bool:
"""Return whether the password is in the breached-passwords list.""" """Return whether the password is in the breached-passwords list."""
redis = StrictRedis(unix_socket_path=BREACHED_PASSWORDS_REDIS_SOCKET) redis = StrictRedis(unix_socket_path=BREACHED_PASSWORDS_REDIS_SOCKET)
hashed = sha1(password.encode('utf-8')).hexdigest()
hashed = sha1(password.encode("utf-8")).hexdigest()
try: try:
return bool(redis.execute_command(
'BF.EXISTS', BREACHED_PASSWORDS_BF_KEY, hashed))
return bool(
redis.execute_command("BF.EXISTS", BREACHED_PASSWORDS_BF_KEY, hashed)
)
except (ConnectionError, ResponseError): except (ConnectionError, ResponseError):
# server isn't running, bloom filter doesn't exist or the key is a # server isn't running, bloom filter doesn't exist or the key is a
# different data type # different data type

95
tildes/tildes/lib/ratelimit.py

@ -25,18 +25,17 @@ class RateLimitResult:
""" """
def __init__( def __init__(
self,
is_allowed: bool,
total_limit: int,
remaining_limit: int,
time_until_max: timedelta,
time_until_retry: Optional[timedelta] = None,
self,
is_allowed: bool,
total_limit: int,
remaining_limit: int,
time_until_max: timedelta,
time_until_retry: Optional[timedelta] = None,
) -> None: ) -> None:
"""Initialize a RateLimitResult.""" """Initialize a RateLimitResult."""
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
if is_allowed and time_until_retry is not None: if is_allowed and time_until_retry is not None:
raise ValueError(
'time_until_retry must be None if is_allowed is True')
raise ValueError("time_until_retry must be None if is_allowed is True")
self.is_allowed = is_allowed self.is_allowed = is_allowed
self.total_limit = total_limit self.total_limit = total_limit
@ -58,7 +57,7 @@ class RateLimitResult:
) )
@classmethod @classmethod
def unlimited_result(cls) -> 'RateLimitResult':
def unlimited_result(cls) -> "RateLimitResult":
"""Return a "blank" result representing an unlimited action.""" """Return a "blank" result representing an unlimited action."""
return cls( return cls(
is_allowed=True, is_allowed=True,
@ -68,7 +67,7 @@ class RateLimitResult:
) )
@classmethod @classmethod
def from_redis_cell_result(cls, result: List[int]) -> 'RateLimitResult':
def from_redis_cell_result(cls, result: List[int]) -> "RateLimitResult":
"""Convert the response from CL.THROTTLE command to a RateLimitResult. """Convert the response from CL.THROTTLE command to a RateLimitResult.
CL.THROTTLE responds with an array of 5 integers: CL.THROTTLE responds with an array of 5 integers:
@ -98,10 +97,7 @@ class RateLimitResult:
) )
@classmethod @classmethod
def merged_result(
cls,
results: Sequence['RateLimitResult'],
) -> 'RateLimitResult':
def merged_result(cls, results: Sequence["RateLimitResult"]) -> "RateLimitResult":
"""Merge any number of RateLimitResults into a single result. """Merge any number of RateLimitResults into a single result.
Basically, the merged result should be the "most restrictive" Basically, the merged result should be the "most restrictive"
@ -125,7 +121,8 @@ class RateLimitResult:
time_until_retry = None time_until_retry = None
else: else:
time_until_retry = max( time_until_retry = max(
r.time_until_retry for r in results if r.time_until_retry)
r.time_until_retry for r in results if r.time_until_retry
)
return cls( return cls(
is_allowed=all(r.is_allowed for r in results), is_allowed=all(r.is_allowed for r in results),
@ -140,18 +137,18 @@ class RateLimitResult:
# Retry-After: seconds the client should wait until retrying # Retry-After: seconds the client should wait until retrying
if self.time_until_retry: if self.time_until_retry:
retry_seconds = int(self.time_until_retry.total_seconds()) retry_seconds = int(self.time_until_retry.total_seconds())
response.headers['Retry-After'] = str(retry_seconds)
response.headers["Retry-After"] = str(retry_seconds)
# X-RateLimit-Limit: the total action limit (including used) # X-RateLimit-Limit: the total action limit (including used)
response.headers['X-RateLimit-Limit'] = str(self.total_limit)
response.headers["X-RateLimit-Limit"] = str(self.total_limit)
# X-RateLimit-Remaining: remaining actions before client hits the limit # X-RateLimit-Remaining: remaining actions before client hits the limit
response.headers['X-RateLimit-Remaining'] = str(self.remaining_limit)
response.headers["X-RateLimit-Remaining"] = str(self.remaining_limit)
# X-RateLimit-Reset: epoch timestamp when limit will be back to full # X-RateLimit-Reset: epoch timestamp when limit will be back to full
reset_time = utc_now() + self.time_until_max reset_time = utc_now() + self.time_until_max
reset_timestamp = int(reset_time.timestamp()) reset_timestamp = int(reset_time.timestamp())
response.headers['X-RateLimit-Reset'] = str(reset_timestamp)
response.headers["X-RateLimit-Reset"] = str(reset_timestamp)
return response return response
@ -165,14 +162,14 @@ class RateLimitedAction:
""" """
def __init__( def __init__(
self,
name: str,
period: timedelta,
limit: int,
max_burst: Optional[int] = None,
by_user: bool = True,
by_ip: bool = True,
redis: Optional[StrictRedis] = None,
self,
name: str,
period: timedelta,
limit: int,
max_burst: Optional[int] = None,
by_user: bool = True,
by_ip: bool = True,
redis: Optional[StrictRedis] = None,
) -> None: ) -> None:
"""Initialize the limits on a particular action. """Initialize the limits on a particular action.
@ -187,10 +184,10 @@ class RateLimitedAction:
""" """
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
if max_burst and not 1 <= max_burst <= limit: if max_burst and not 1 <= max_burst <= limit:
raise ValueError('max_burst must be at least 1 and <= limit')
raise ValueError("max_burst must be at least 1 and <= limit")
if not (by_user or by_ip): if not (by_user or by_ip):
raise ValueError('At least one of by_user or by_ip must be True')
raise ValueError("At least one of by_user or by_ip must be True")
self.name = name self.name = name
self.period = period self.period = period
@ -213,7 +210,7 @@ class RateLimitedAction:
def redis(self) -> StrictRedis: def redis(self) -> StrictRedis:
"""Return the redis connection.""" """Return the redis connection."""
if not self._redis: if not self._redis:
raise RateLimitError('No redis connection set')
raise RateLimitError("No redis connection set")
return self._redis return self._redis
@ -224,19 +221,14 @@ class RateLimitedAction:
def _build_redis_key(self, by_type: str, value: Any) -> str: def _build_redis_key(self, by_type: str, value: Any) -> str:
"""Build the Redis key where this rate limit is maintained.""" """Build the Redis key where this rate limit is maintained."""
parts = [
'ratelimit',
self.name,
by_type,
str(value),
]
parts = ["ratelimit", self.name, by_type, str(value)]
return ':'.join(parts)
return ":".join(parts)
def _call_redis_command(self, key: str) -> List[int]: def _call_redis_command(self, key: str) -> List[int]:
"""Call the redis-cell CL.THROTTLE command for this action.""" """Call the redis-cell CL.THROTTLE command for this action."""
return self.redis.execute_command( return self.redis.execute_command(
'CL.THROTTLE',
"CL.THROTTLE",
key, key,
self.max_burst - 1, self.max_burst - 1,
self.limit, self.limit,
@ -246,10 +238,9 @@ class RateLimitedAction:
def check_for_user_id(self, user_id: int) -> RateLimitResult: def check_for_user_id(self, user_id: int) -> RateLimitResult:
"""Check whether a particular user_id can perform this action.""" """Check whether a particular user_id can perform this action."""
if not self.by_user: if not self.by_user:
raise RateLimitError(
'check_for_user_id called on non-user-limited action')
raise RateLimitError("check_for_user_id called on non-user-limited action")
key = self._build_redis_key('user', user_id)
key = self._build_redis_key("user", user_id)
result = self._call_redis_command(key) result = self._call_redis_command(key)
return RateLimitResult.from_redis_cell_result(result) return RateLimitResult.from_redis_cell_result(result)
@ -257,22 +248,20 @@ class RateLimitedAction:
def reset_for_user_id(self, user_id: int) -> None: def reset_for_user_id(self, user_id: int) -> None:
"""Reset the ratelimit on this action for a particular user_id.""" """Reset the ratelimit on this action for a particular user_id."""
if not self.by_user: if not self.by_user:
raise RateLimitError(
'reset_for_user_id called on non-user-limited action')
raise RateLimitError("reset_for_user_id called on non-user-limited action")
key = self._build_redis_key('user', user_id)
key = self._build_redis_key("user", user_id)
self.redis.delete(key) self.redis.delete(key)
def check_for_ip(self, ip_str: str) -> RateLimitResult: def check_for_ip(self, ip_str: str) -> RateLimitResult:
"""Check whether a particular IP can perform this action.""" """Check whether a particular IP can perform this action."""
if not self.by_ip: if not self.by_ip:
raise RateLimitError(
'check_for_ip called on non-IP-limited action')
raise RateLimitError("check_for_ip called on non-IP-limited action")
# check if ip_str is a valid address, will ValueError if not # check if ip_str is a valid address, will ValueError if not
ip_address(ip_str) ip_address(ip_str)
key = self._build_redis_key('ip', ip_str)
key = self._build_redis_key("ip", ip_str)
result = self._call_redis_command(key) result = self._call_redis_command(key)
return RateLimitResult.from_redis_cell_result(result) return RateLimitResult.from_redis_cell_result(result)
@ -280,23 +269,21 @@ class RateLimitedAction:
def reset_for_ip(self, ip_str: str) -> None: def reset_for_ip(self, ip_str: str) -> None:
"""Reset the ratelimit on this action for a particular IP.""" """Reset the ratelimit on this action for a particular IP."""
if not self.by_ip: if not self.by_ip:
raise RateLimitError(
'reset_for_ip called on non-user-limited action')
raise RateLimitError("reset_for_ip called on non-user-limited action")
# check if ip_str is a valid address, will ValueError if not # check if ip_str is a valid address, will ValueError if not
ip_address(ip_str) ip_address(ip_str)
key = self._build_redis_key('ip', ip_str)
key = self._build_redis_key("ip", ip_str)
self.redis.delete(key) self.redis.delete(key)
# the actual list of actions with rate-limit restrictions # the actual list of actions with rate-limit restrictions
# each action must have a unique name to prevent key collisions # each action must have a unique name to prevent key collisions
_RATE_LIMITED_ACTIONS = ( _RATE_LIMITED_ACTIONS = (
RateLimitedAction('login', timedelta(hours=1), 20),
RateLimitedAction('register', timedelta(hours=1), 50),
RateLimitedAction("login", timedelta(hours=1), 20),
RateLimitedAction("register", timedelta(hours=1), 50),
) )
# (public) dict to be able to look up the actions by name # (public) dict to be able to look up the actions by name
RATE_LIMITED_ACTIONS = {
action.name: action for action in _RATE_LIMITED_ACTIONS}
RATE_LIMITED_ACTIONS = {action.name: action for action in _RATE_LIMITED_ACTIONS}

45
tildes/tildes/lib/string.py

@ -20,16 +20,16 @@ def convert_to_url_slug(original: str, max_length: int = 100) -> str:
slug = original.lower() slug = original.lower()
# remove apostrophes so contractions don't get broken up by underscores # remove apostrophes so contractions don't get broken up by underscores
slug = re.sub("['’]", '', slug)
slug = re.sub("['’]", "", slug)
# replace all remaining non-word characters with underscores # replace all remaining non-word characters with underscores
slug = re.sub(r'\W+', '_', slug)
slug = re.sub(r"\W+", "_", slug)
# remove any consecutive underscores # remove any consecutive underscores
slug = re.sub('_{2,}', '_', slug)
slug = re.sub("_{2,}", "_", slug)
# remove "hanging" underscores on the start and/or end # remove "hanging" underscores on the start and/or end
slug = slug.strip('_')
slug = slug.strip("_")
# url-encode the slug # url-encode the slug
encoded_slug = quote(slug) encoded_slug = quote(slug)
@ -42,7 +42,7 @@ def convert_to_url_slug(original: str, max_length: int = 100) -> str:
# Truncating a url-encoded slug can be tricky if there are any multi-byte # Truncating a url-encoded slug can be tricky if there are any multi-byte
# unicode characters, since the %-encoded forms of them can be quite long. # unicode characters, since the %-encoded forms of them can be quite long.
# Check to see if the slug looks like it might contain any of those. # Check to see if the slug looks like it might contain any of those.
maybe_multi_bytes = bool(re.search('%..%', encoded_slug))
maybe_multi_bytes = bool(re.search("%..%", encoded_slug))
# if that matched, we need to take a more complicated approach # if that matched, we need to take a more complicated approach
if maybe_multi_bytes: if maybe_multi_bytes:
@ -50,10 +50,7 @@ def convert_to_url_slug(original: str, max_length: int = 100) -> str:
# simple truncate - break at underscore if possible, no overflow string # simple truncate - break at underscore if possible, no overflow string
return truncate_string( return truncate_string(
encoded_slug,
max_length,
truncate_at_chars='_',
overflow_str=None,
encoded_slug, max_length, truncate_at_chars="_", overflow_str=None
) )
@ -62,7 +59,7 @@ def _truncate_multibyte_slug(original: str, max_length: int) -> str:
# instead of the normal method of truncating "backwards" from the end of # instead of the normal method of truncating "backwards" from the end of
# the string, build it up one encoded character at a time from the start # the string, build it up one encoded character at a time from the start
# until it's too long # until it's too long
encoded_slug = ''
encoded_slug = ""
for character in original: for character in original:
encoded_char = quote(character) encoded_char = quote(character)
@ -82,7 +79,7 @@ def _truncate_multibyte_slug(original: str, max_length: int) -> str:
# determining the word edges is not simple. # determining the word edges is not simple.
acceptable_truncation = 0.7 acceptable_truncation = 0.7
truncated_slug = truncate_string_at_char(encoded_slug, '_')
truncated_slug = truncate_string_at_char(encoded_slug, "_")
if len(truncated_slug) / len(encoded_slug) >= acceptable_truncation: if len(truncated_slug) / len(encoded_slug) >= acceptable_truncation:
return truncated_slug return truncated_slug
@ -91,10 +88,10 @@ def _truncate_multibyte_slug(original: str, max_length: int) -> str:
def truncate_string( def truncate_string(
original: str,
length: int,
truncate_at_chars: Optional[str] = None,
overflow_str: Optional[str] = '...',
original: str,
length: int,
truncate_at_chars: Optional[str] = None,
overflow_str: Optional[str] = "...",
) -> str: ) -> str:
"""Truncate a string to be no longer than a specified length. """Truncate a string to be no longer than a specified length.
@ -109,7 +106,7 @@ def truncate_string(
string will be kept. string will be kept.
""" """
if overflow_str is None: if overflow_str is None:
overflow_str = ''
overflow_str = ""
# no need to do anything if the string is already short enough # no need to do anything if the string is already short enough
if len(original) <= length: if len(original) <= length:
@ -117,7 +114,7 @@ def truncate_string(
# cut the string down to the max desired length (leaving space for the # cut the string down to the max desired length (leaving space for the
# overflow string if one is specified) # overflow string if one is specified)
truncated = original[:length - len(overflow_str)]
truncated = original[: length - len(overflow_str)]
# if we don't want to truncate at particular characters, we're done # if we don't want to truncate at particular characters, we're done
if not truncate_at_chars: if not truncate_at_chars:
@ -167,7 +164,7 @@ def simplify_string(original: str) -> str:
simplified = _sanitize_characters(original) simplified = _sanitize_characters(original)
# replace consecutive spaces with a single space # replace consecutive spaces with a single space
simplified = re.sub(r'\s{2,}', ' ', simplified)
simplified = re.sub(r"\s{2,}", " ", simplified)
# remove any remaining leading/trailing whitespace # remove any remaining leading/trailing whitespace
simplified = simplified.strip() simplified = simplified.strip()
@ -182,16 +179,16 @@ def _sanitize_characters(original: str) -> str:
for char in original: for char in original:
category = unicodedata.category(char) category = unicodedata.category(char)
if category.startswith('Z'):
if category.startswith("Z"):
# "separator" chars - replace with a normal space # "separator" chars - replace with a normal space
final_characters.append(' ')
elif category.startswith('C'):
final_characters.append(" ")
elif category.startswith("C"):
# "other" chars (control, formatting, etc.) - filter them out # "other" chars (control, formatting, etc.) - filter them out
# except for newlines, which are replaced with normal spaces # except for newlines, which are replaced with normal spaces
if char == '\n':
final_characters.append(' ')
if char == "\n":
final_characters.append(" ")
else: else:
# any other type of character, just keep it # any other type of character, just keep it
final_characters.append(char) final_characters.append(char)
return ''.join(final_characters)
return "".join(final_characters)

4
tildes/tildes/lib/url.py

@ -8,9 +8,9 @@ def get_domain_from_url(url: str, strip_www: bool = True) -> str:
domain = urlparse(url).netloc domain = urlparse(url).netloc
if not domain: if not domain:
raise ValueError('Invalid url or domain could not be determined')
raise ValueError("Invalid url or domain could not be determined")
if strip_www and domain.startswith('www.'):
if strip_www and domain.startswith("www."):
domain = domain[4:] domain = domain[4:]
return domain return domain

60
tildes/tildes/metrics.py

@ -11,50 +11,32 @@ from prometheus_client.core import _LabelWrapper
_COUNTERS = { _COUNTERS = {
'votes': Counter(
'tildes_votes_total',
'Votes',
labelnames=['target_type'],
),
'comments': Counter('tildes_comments_total', 'Comments'),
'invite_code_failures': Counter(
'tildes_invite_code_failures_total',
'Invite Code Failures',
),
'logins': Counter('tildes_logins_total', 'Login Attempts'),
'login_failures': Counter(
'tildes_login_failures_total',
'Login Failures',
),
'messages': Counter(
'tildes_messages_total',
'Messages',
labelnames=['type'],
),
'registrations': Counter(
'tildes_registrations_total',
'User Registrations',
),
'topics': Counter('tildes_topics_total', 'Topics', labelnames=['type']),
'subscriptions': Counter('tildes_subscriptions_total', 'Subscriptions'),
'unsubscriptions': Counter(
'tildes_unsubscriptions_total',
'Unsubscriptions',
"votes": Counter("tildes_votes_total", "Votes", labelnames=["target_type"]),
"comments": Counter("tildes_comments_total", "Comments"),
"invite_code_failures": Counter(
"tildes_invite_code_failures_total", "Invite Code Failures"
), ),
"logins": Counter("tildes_logins_total", "Login Attempts"),
"login_failures": Counter("tildes_login_failures_total", "Login Failures"),
"messages": Counter("tildes_messages_total", "Messages", labelnames=["type"]),
"registrations": Counter("tildes_registrations_total", "User Registrations"),
"topics": Counter("tildes_topics_total", "Topics", labelnames=["type"]),
"subscriptions": Counter("tildes_subscriptions_total", "Subscriptions"),
"unsubscriptions": Counter("tildes_unsubscriptions_total", "Unsubscriptions"),
} }
_HISTOGRAMS = { _HISTOGRAMS = {
'markdown_processing': Histogram(
'tildes_markdown_processing_seconds',
'Markdown processing',
"markdown_processing": Histogram(
"tildes_markdown_processing_seconds",
"Markdown processing",
buckets=[.001, .0025, .005, .01, 0.025, .05, .1, .5, 1.0], buckets=[.001, .0025, .005, .01, 0.025, .05, .1, .5, 1.0],
), ),
'comment_tree_sorting': Histogram(
'tildes_comment_tree_sorting_seconds',
'Comment tree sorting time',
labelnames=['num_comments_range', 'order'],
"comment_tree_sorting": Histogram(
"tildes_comment_tree_sorting_seconds",
"Comment tree sorting time",
labelnames=["num_comments_range", "order"],
buckets=[.00001, .0001, .001, .01, .05, .1, .5, 1.0], buckets=[.00001, .0001, .001, .01, .05, .1, .5, 1.0],
)
),
} }
@ -63,7 +45,7 @@ def incr_counter(name: str, amount: int = 1, **labels: str) -> None:
try: try:
counter = _COUNTERS[name] counter = _COUNTERS[name]
except KeyError: except KeyError:
raise ValueError('Invalid counter name')
raise ValueError("Invalid counter name")
if isinstance(counter, _LabelWrapper): if isinstance(counter, _LabelWrapper):
counter = counter.labels(**labels) counter = counter.labels(**labels)
@ -76,7 +58,7 @@ def get_histogram(name: str, **labels: str) -> Histogram:
try: try:
hist = _HISTOGRAMS[name] hist = _HISTOGRAMS[name]
except KeyError: except KeyError:
raise ValueError('Invalid histogram name')
raise ValueError("Invalid histogram name")
if isinstance(hist, _LabelWrapper): if isinstance(hist, _LabelWrapper):
hist = hist.labels(**labels) hist = hist.labels(**labels)

97
tildes/tildes/models/comment/comment.py

@ -5,14 +5,7 @@ from datetime import datetime, timedelta
from typing import Any, Optional, Sequence, Tuple from typing import Any, Optional, Sequence, Tuple
from pyramid.security import Allow, Authenticated, Deny, DENY_ALL, Everyone from pyramid.security import Allow, Authenticated, Deny, DENY_ALL, Everyone
from sqlalchemy import (
Boolean,
Column,
ForeignKey,
Integer,
Text,
TIMESTAMP,
)
from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text, TIMESTAMP
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import deferred, relationship from sqlalchemy.orm import deferred, relationship
from sqlalchemy.sql.expression import text from sqlalchemy.sql.expression import text
@ -60,47 +53,40 @@ class Comment(DatabaseModel):
schema_class = CommentSchema schema_class = CommentSchema
__tablename__ = 'comments'
__tablename__ = "comments"
comment_id: int = Column(Integer, primary_key=True) comment_id: int = Column(Integer, primary_key=True)
topic_id: int = Column( topic_id: int = Column(
Integer,
ForeignKey('topics.topic_id'),
nullable=False,
index=True,
Integer, ForeignKey("topics.topic_id"), nullable=False, index=True
) )
user_id: int = Column( user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
) )
parent_comment_id: Optional[int] = Column( parent_comment_id: Optional[int] = Column(
Integer,
ForeignKey('comments.comment_id'),
index=True,
Integer, ForeignKey("comments.comment_id"), index=True
) )
created_time: datetime = Column( created_time: datetime = Column(
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
is_deleted: bool = Column( is_deleted: bool = Column(
Boolean, nullable=False, server_default='false', index=True)
Boolean, nullable=False, server_default="false", index=True
)
deleted_time: Optional[datetime] = Column(TIMESTAMP(timezone=True)) deleted_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
is_removed: bool = Column(Boolean, nullable=False, server_default='false')
is_removed: bool = Column(Boolean, nullable=False, server_default="false")
removed_time: Optional[datetime] = Column(TIMESTAMP(timezone=True)) removed_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
last_edited_time: Optional[datetime] = Column(TIMESTAMP(timezone=True)) last_edited_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
_markdown: str = deferred(Column('markdown', Text, nullable=False))
_markdown: str = deferred(Column("markdown", Text, nullable=False))
rendered_html: str = Column(Text, nullable=False) rendered_html: str = Column(Text, nullable=False)
num_votes: int = Column(
Integer, nullable=False, server_default='0', index=True)
num_votes: int = Column(Integer, nullable=False, server_default="0", index=True)
user: User = relationship('User', lazy=False, innerjoin=True)
topic: Topic = relationship('Topic', innerjoin=True)
parent_comment: Optional['Comment'] = relationship(
'Comment', uselist=False, remote_side=[comment_id])
user: User = relationship("User", lazy=False, innerjoin=True)
topic: Topic = relationship("Topic", innerjoin=True)
parent_comment: Optional["Comment"] = relationship(
"Comment", uselist=False, remote_side=[comment_id]
)
@hybrid_property @hybrid_property
def markdown(self) -> str: def markdown(self) -> str:
@ -117,20 +103,19 @@ class Comment(DatabaseModel):
self._markdown = new_markdown self._markdown = new_markdown
self.rendered_html = convert_markdown_to_safe_html(new_markdown) self.rendered_html = convert_markdown_to_safe_html(new_markdown)
if (self.created_time and
utc_now() - self.created_time > EDIT_GRACE_PERIOD):
if self.created_time and utc_now() - self.created_time > EDIT_GRACE_PERIOD:
self.last_edited_time = utc_now() self.last_edited_time = utc_now()
def __repr__(self) -> str: def __repr__(self) -> str:
"""Display the comment's ID as its repr format.""" """Display the comment's ID as its repr format."""
return f'<Comment ({self.comment_id})>'
return f"<Comment ({self.comment_id})>"
def __init__( def __init__(
self,
topic: Topic,
author: User,
markdown: str,
parent_comment: Optional['Comment'] = None,
self,
topic: Topic,
author: User,
markdown: str,
parent_comment: Optional["Comment"] = None,
) -> None: ) -> None:
"""Create a new comment.""" """Create a new comment."""
self.topic = topic self.topic = topic
@ -142,7 +127,7 @@ class Comment(DatabaseModel):
self.markdown = markdown self.markdown = markdown
incr_counter('comments')
incr_counter("comments")
def __acl__(self) -> Sequence[Tuple[str, Any, str]]: def __acl__(self) -> Sequence[Tuple[str, Any, str]]:
"""Pyramid security ACL.""" """Pyramid security ACL."""
@ -156,49 +141,49 @@ class Comment(DatabaseModel):
# - removed comments can only be viewed by admins and the author # - removed comments can only be viewed by admins and the author
# - otherwise, everyone can view # - otherwise, everyone can view
if self.is_removed: if self.is_removed:
acl.append((Allow, 'admin', 'view'))
acl.append((Allow, self.user_id, 'view'))
acl.append((Deny, Everyone, 'view'))
acl.append((Allow, "admin", "view"))
acl.append((Allow, self.user_id, "view"))
acl.append((Deny, Everyone, "view"))
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# vote: # vote:
# - removed comments can't be voted on by anyone # - removed comments can't be voted on by anyone
# - otherwise, logged-in users except the author can vote # - otherwise, logged-in users except the author can vote
if self.is_removed: if self.is_removed:
acl.append((Deny, Everyone, 'vote'))
acl.append((Deny, Everyone, "vote"))
acl.append((Deny, self.user_id, 'vote'))
acl.append((Allow, Authenticated, 'vote'))
acl.append((Deny, self.user_id, "vote"))
acl.append((Allow, Authenticated, "vote"))
# tag: # tag:
# - temporary: nobody can tag comments # - temporary: nobody can tag comments
acl.append((Deny, Everyone, 'tag'))
acl.append((Deny, Everyone, "tag"))
# reply: # reply:
# - removed comments can't be replied to by anyone # - removed comments can't be replied to by anyone
# - if the topic is locked, only admins can reply # - if the topic is locked, only admins can reply
# - otherwise, logged-in users can reply # - otherwise, logged-in users can reply
if self.is_removed: if self.is_removed:
acl.append((Deny, Everyone, 'reply'))
acl.append((Deny, Everyone, "reply"))
if self.topic.is_locked: if self.topic.is_locked:
acl.append((Allow, 'admin', 'reply'))
acl.append((Deny, Everyone, 'reply'))
acl.append((Allow, "admin", "reply"))
acl.append((Deny, Everyone, "reply"))
acl.append((Allow, Authenticated, 'reply'))
acl.append((Allow, Authenticated, "reply"))
# edit: # edit:
# - only the author can edit # - only the author can edit
acl.append((Allow, self.user_id, 'edit'))
acl.append((Allow, self.user_id, "edit"))
# delete: # delete:
# - only the author can delete # - only the author can delete
acl.append((Allow, self.user_id, 'delete'))
acl.append((Allow, self.user_id, "delete"))
# mark_read: # mark_read:
# - logged-in users can mark comments read # - logged-in users can mark comments read
acl.append((Allow, Authenticated, 'mark_read'))
acl.append((Allow, Authenticated, "mark_read"))
acl.append(DENY_ALL) acl.append(DENY_ALL)
@ -220,7 +205,7 @@ class Comment(DatabaseModel):
@property @property
def permalink(self) -> str: def permalink(self) -> str:
"""Return the permalink for this comment.""" """Return the permalink for this comment."""
return f'{self.topic.permalink}#comment-{self.comment_id36}'
return f"{self.topic.permalink}#comment-{self.comment_id36}"
@property @property
def parent_comment_permalink(self) -> str: def parent_comment_permalink(self) -> str:
@ -228,7 +213,7 @@ class Comment(DatabaseModel):
if not self.parent_comment_id: if not self.parent_comment_id:
raise AttributeError raise AttributeError
return f'{self.topic.permalink}#comment-{self.parent_comment_id36}'
return f"{self.topic.permalink}#comment-{self.parent_comment_id36}"
@property @property
def tag_counts(self) -> Counter: def tag_counts(self) -> Counter:

71
tildes/tildes/models/comment/comment_notification.py

@ -29,38 +29,27 @@ class CommentNotification(DatabaseModel):
decrement num_unread_notifications for the relevant user. decrement num_unread_notifications for the relevant user.
""" """
__tablename__ = 'comment_notifications'
__tablename__ = "comment_notifications"
user_id: int = Column( user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
) )
comment_id: int = Column( comment_id: int = Column(
Integer,
ForeignKey('comments.comment_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("comments.comment_id"), nullable=False, primary_key=True
) )
notification_type: CommentNotificationType = Column( notification_type: CommentNotificationType = Column(
ENUM(CommentNotificationType), nullable=False)
ENUM(CommentNotificationType), nullable=False
)
created_time: datetime = Column( created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=text('NOW()'),
TIMESTAMP(timezone=True), nullable=False, server_default=text("NOW()")
) )
is_unread: bool = Column(
Boolean, nullable=False, server_default='true', index=True)
is_unread: bool = Column(Boolean, nullable=False, server_default="true", index=True)
user: User = relationship('User', innerjoin=True)
comment: Comment = relationship('Comment', innerjoin=True)
user: User = relationship("User", innerjoin=True)
comment: Comment = relationship("Comment", innerjoin=True)
def __init__( def __init__(
self,
user: User,
comment: Comment,
notification_type: CommentNotificationType,
self, user: User, comment: Comment, notification_type: CommentNotificationType
) -> None: ) -> None:
"""Create a new notification for a user from a comment.""" """Create a new notification for a user from a comment."""
self.user = user self.user = user
@ -70,7 +59,7 @@ class CommentNotification(DatabaseModel):
def __acl__(self) -> Sequence[Tuple[str, Any, str]]: def __acl__(self) -> Sequence[Tuple[str, Any, str]]:
"""Pyramid security ACL.""" """Pyramid security ACL."""
acl = [] acl = []
acl.append((Allow, self.user_id, 'mark_read'))
acl.append((Allow, self.user_id, "mark_read"))
acl.append(DENY_ALL) acl.append(DENY_ALL)
return acl return acl
@ -91,17 +80,12 @@ class CommentNotification(DatabaseModel):
@classmethod @classmethod
def get_mentions_for_comment( def get_mentions_for_comment(
cls,
db_session: Session,
comment: Comment,
) -> List['CommentNotification']:
cls, db_session: Session, comment: Comment
) -> List["CommentNotification"]:
"""Get a list of notifications for user mentions in the comment.""" """Get a list of notifications for user mentions in the comment."""
notifications = [] notifications = []
raw_names = re.findall(
LinkifyFilter.USERNAME_REFERENCE_REGEX,
comment.markdown,
)
raw_names = re.findall(LinkifyFilter.USERNAME_REFERENCE_REGEX, comment.markdown)
users_to_mention = ( users_to_mention = (
db_session.query(User) db_session.query(User)
.filter(User.username.in_(raw_names)) # type: ignore .filter(User.username.in_(raw_names)) # type: ignore
@ -124,17 +108,18 @@ class CommentNotification(DatabaseModel):
continue continue
mention_notification = cls( mention_notification = cls(
user, comment, CommentNotificationType.USER_MENTION)
user, comment, CommentNotificationType.USER_MENTION
)
notifications.append(mention_notification) notifications.append(mention_notification)
return notifications return notifications
@staticmethod @staticmethod
def prevent_duplicate_notifications( def prevent_duplicate_notifications(
db_session: Session,
comment: Comment,
new_notifications: List['CommentNotification'],
) -> Tuple[List['CommentNotification'], List['CommentNotification']]:
db_session: Session,
comment: Comment,
new_notifications: List["CommentNotification"],
) -> Tuple[List["CommentNotification"], List["CommentNotification"]]:
"""Filter new notifications for edited comments. """Filter new notifications for edited comments.
Protect against sending a notification for the same comment to Protect against sending a notification for the same comment to
@ -149,13 +134,13 @@ class CommentNotification(DatabaseModel):
that need to be added, as they're new. that need to be added, as they're new.
""" """
previous_notifications = ( previous_notifications = (
db_session
.query(CommentNotification)
db_session.query(CommentNotification)
.filter( .filter(
CommentNotification.comment_id == comment.comment_id, CommentNotification.comment_id == comment.comment_id,
CommentNotification.notification_type ==
CommentNotificationType.USER_MENTION,
).all()
CommentNotification.notification_type
== CommentNotificationType.USER_MENTION,
)
.all()
) )
new_mention_user_ids = [ new_mention_user_ids = [
@ -167,12 +152,14 @@ class CommentNotification(DatabaseModel):
] ]
to_delete = [ to_delete = [
notification for notification in previous_notifications
notification
for notification in previous_notifications
if notification.user.user_id not in new_mention_user_ids if notification.user.user_id not in new_mention_user_ids
] ]
to_add = [ to_add = [
notification for notification in new_notifications
notification
for notification in new_notifications
if notification.user.user_id not in previous_mention_user_ids if notification.user.user_id not in previous_mention_user_ids
] ]

10
tildes/tildes/models/comment/comment_notification_query.py

@ -17,7 +17,7 @@ class CommentNotificationQuery(ModelQuery):
"""Initialize a CommentNotificationQuery for the request.""" """Initialize a CommentNotificationQuery for the request."""
super().__init__(CommentNotification, request) super().__init__(CommentNotification, request)
def _attach_extra_data(self) -> 'CommentNotificationQuery':
def _attach_extra_data(self) -> "CommentNotificationQuery":
"""Attach the user's comment votes to the query.""" """Attach the user's comment votes to the query."""
vote_subquery = ( vote_subquery = (
self.request.query(CommentVote) self.request.query(CommentVote)
@ -26,16 +26,16 @@ class CommentNotificationQuery(ModelQuery):
CommentVote.user == self.request.user, CommentVote.user == self.request.user,
) )
.exists() .exists()
.label('user_voted')
.label("user_voted")
) )
return self.add_columns(vote_subquery) return self.add_columns(vote_subquery)
def join_all_relationships(self) -> 'CommentNotificationQuery':
def join_all_relationships(self) -> "CommentNotificationQuery":
"""Eagerly join the comment, topic, and group to the notification.""" """Eagerly join the comment, topic, and group to the notification."""
self = self.options( self = self.options(
joinedload(CommentNotification.comment) joinedload(CommentNotification.comment)
.joinedload('topic')
.joinedload('group')
.joinedload("topic")
.joinedload("group")
) )
return self return self

6
tildes/tildes/models/comment/comment_query.py

@ -21,14 +21,14 @@ class CommentQuery(PaginatedQuery):
""" """
super().__init__(Comment, request) super().__init__(Comment, request)
def _attach_extra_data(self) -> 'CommentQuery':
def _attach_extra_data(self) -> "CommentQuery":
"""Attach the extra user data to the query.""" """Attach the extra user data to the query."""
if not self.request.user: if not self.request.user:
return self return self
return self._attach_vote_data() return self._attach_vote_data()
def _attach_vote_data(self) -> 'CommentQuery':
def _attach_vote_data(self) -> "CommentQuery":
"""Add a subquery to include whether the user has voted.""" """Add a subquery to include whether the user has voted."""
vote_subquery = ( vote_subquery = (
self.request.query(CommentVote) self.request.query(CommentVote)
@ -37,7 +37,7 @@ class CommentQuery(PaginatedQuery):
CommentVote.user_id == self.request.user.user_id, CommentVote.user_id == self.request.user.user_id,
) )
.exists() .exists()
.label('user_voted')
.label("user_voted")
) )
return self.add_columns(vote_subquery) return self.add_columns(vote_subquery)

29
tildes/tildes/models/comment/comment_tag.py

@ -16,37 +16,24 @@ from .comment import Comment
class CommentTag(DatabaseModel): class CommentTag(DatabaseModel):
"""Model for the tags attached to comments by users.""" """Model for the tags attached to comments by users."""
__tablename__ = 'comment_tags'
__tablename__ = "comment_tags"
comment_id: int = Column( comment_id: int = Column(
Integer,
ForeignKey('comments.comment_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("comments.comment_id"), nullable=False, primary_key=True
) )
tag: CommentTagOption = Column( tag: CommentTagOption = Column(
ENUM(CommentTagOption), nullable=False, primary_key=True)
ENUM(CommentTagOption), nullable=False, primary_key=True
)
user_id: int = Column( user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
) )
created_time: datetime = Column( created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=text('NOW()'),
TIMESTAMP(timezone=True), nullable=False, server_default=text("NOW()")
) )
comment: Comment = relationship(
Comment, backref=backref('tags', lazy=False))
comment: Comment = relationship(Comment, backref=backref("tags", lazy=False))
def __init__(
self,
comment: Comment,
user: User,
tag: CommentTagOption,
) -> None:
def __init__(self, comment: Comment, user: User, tag: CommentTagOption) -> None:
"""Add a new tag to a comment.""" """Add a new tag to a comment."""
self.comment_id = comment.comment_id self.comment_id = comment.comment_id
self.user_id = user.user_id self.user_id = user.user_id

23
tildes/tildes/models/comment/comment_tree.py

@ -18,11 +18,7 @@ class CommentTree:
descendants (if not, it can be pruned from the tree) descendants (if not, it can be pruned from the tree)
""" """
def __init__(
self,
comments: Sequence[Comment],
sort: CommentSortOption,
) -> None:
def __init__(self, comments: Sequence[Comment], sort: CommentSortOption) -> None:
"""Create a sorted CommentTree from a flat list of Comments.""" """Create a sorted CommentTree from a flat list of Comments."""
self.tree: List[Comment] = [] self.tree: List[Comment] = []
self.sort = sort self.sort = sort
@ -76,10 +72,7 @@ class CommentTree:
self.tree.append(comment) self.tree.append(comment)
@staticmethod @staticmethod
def _sort_tree(
tree: List[Comment],
sort: CommentSortOption,
) -> List[Comment]:
def _sort_tree(tree: List[Comment], sort: CommentSortOption) -> List[Comment]:
"""Sort the tree by the desired ordering (recursively). """Sort the tree by the desired ordering (recursively).
Because Python's sorted() function is stable, the ordering of any Because Python's sorted() function is stable, the ordering of any
@ -149,18 +142,18 @@ class CommentTree:
# make an "order of magnitude" label based on the number of comments # make an "order of magnitude" label based on the number of comments
if num_comments == 0: if num_comments == 0:
raise ValueError('Attempting to time an empty comment tree sort')
raise ValueError("Attempting to time an empty comment tree sort")
if num_comments < 10: if num_comments < 10:
num_comments_range = '1 - 9'
num_comments_range = "1 - 9"
elif num_comments < 100: elif num_comments < 100:
num_comments_range = '10 - 99'
num_comments_range = "10 - 99"
elif num_comments < 1000: elif num_comments < 1000:
num_comments_range = '100 - 999'
num_comments_range = "100 - 999"
else: else:
num_comments_range = '1000+'
num_comments_range = "1000+"
return get_histogram( return get_histogram(
'comment_tree_sorting',
"comment_tree_sorting",
num_comments_range=num_comments_range, num_comments_range=num_comments_range,
order=self.sort.name, order=self.sort.name,
) )

20
tildes/tildes/models/comment/comment_vote.py

@ -21,33 +21,27 @@ class CommentVote(DatabaseModel):
column for the relevant comment. column for the relevant comment.
""" """
__tablename__ = 'comment_votes'
__tablename__ = "comment_votes"
user_id: int = Column( user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
) )
comment_id: int = Column( comment_id: int = Column(
Integer,
ForeignKey('comments.comment_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("comments.comment_id"), nullable=False, primary_key=True
) )
created_time: datetime = Column( created_time: datetime = Column(
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
user: User = relationship('User', innerjoin=True)
comment: Comment = relationship('Comment', innerjoin=True)
user: User = relationship("User", innerjoin=True)
comment: Comment = relationship("Comment", innerjoin=True)
def __init__(self, user: User, comment: Comment) -> None: def __init__(self, user: User, comment: Comment) -> None:
"""Create a new vote on a comment.""" """Create a new vote on a comment."""
self.user = user self.user = user
self.comment = comment self.comment = comment
incr_counter('votes', target_type='comment')
incr_counter("votes", target_type="comment")

34
tildes/tildes/models/database_model.py

@ -11,36 +11,31 @@ from sqlalchemy.schema import MetaData
from sqlalchemy.sql.schema import Table from sqlalchemy.sql.schema import Table
ModelType = TypeVar('ModelType') # pylint: disable=invalid-name
ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
# SQLAlchemy naming convention for constraints and indexes # SQLAlchemy naming convention for constraints and indexes
NAMING_CONVENTION = { NAMING_CONVENTION = {
'pk': 'pk_%(table_name)s',
'fk': 'fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s',
'ix': 'ix_%(table_name)s_%(column_0_name)s',
'ck': 'ck_%(table_name)s_%(constraint_name)s',
'uq': 'uq_%(table_name)s_%(column_0_name)s',
"pk": "pk_%(table_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"ix": "ix_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
} }
def attach_set_listener( def attach_set_listener(
class_: Type['DatabaseModelBase'],
attribute: str,
instance: 'DatabaseModelBase',
class_: Type["DatabaseModelBase"], attribute: str, instance: "DatabaseModelBase"
) -> None: ) -> None:
"""Attach the SQLAlchemy ORM "set" attribute listener.""" """Attach the SQLAlchemy ORM "set" attribute listener."""
# pylint: disable=unused-argument # pylint: disable=unused-argument
def set_handler( def set_handler(
target: 'DatabaseModelBase',
value: Any,
oldvalue: Any,
initiator: Any,
target: "DatabaseModelBase", value: Any, oldvalue: Any, initiator: Any
) -> Any: ) -> Any:
"""Handle an SQLAlchemy ORM "set" attribute event.""" """Handle an SQLAlchemy ORM "set" attribute event."""
# pylint: disable=protected-access # pylint: disable=protected-access
return target._validate_new_value(attribute, value) return target._validate_new_value(attribute, value)
event.listen(instance, 'set', set_handler, retval=True)
event.listen(instance, "set", set_handler, retval=True)
class DatabaseModelBase: class DatabaseModelBase:
@ -71,8 +66,7 @@ class DatabaseModelBase:
key columns used in __eq__, as recommended in the Python documentation. key columns used in __eq__, as recommended in the Python documentation.
""" """
primary_key_values = tuple( primary_key_values = tuple(
getattr(self, column.name)
for column in self.__table__.primary_key
getattr(self, column.name) for column in self.__table__.primary_key
) )
return hash(primary_key_values) return hash(primary_key_values)
@ -82,7 +76,7 @@ class DatabaseModelBase:
if not self.schema_class: if not self.schema_class:
raise AttributeError raise AttributeError
if not hasattr(self, '_schema'):
if not hasattr(self, "_schema"):
self._schema = self.schema_class(partial=True) # noqa self._schema = self.schema_class(partial=True) # noqa
return self._schema return self._schema
@ -112,7 +106,7 @@ class DatabaseModelBase:
# set starts with an underscore, assume that it's due to being set up # set starts with an underscore, assume that it's due to being set up
# as a hybrid property, and remove the underscore prefix when looking # as a hybrid property, and remove the underscore prefix when looking
# for a field to validate against. # for a field to validate against.
if attribute.startswith('_'):
if attribute.startswith("_"):
attribute = attribute[1:] attribute = attribute[1:]
field = self.schema.fields.get(attribute) field = self.schema.fields.get(attribute)
@ -126,13 +120,13 @@ class DatabaseModelBase:
DatabaseModel = declarative_base( # pylint: disable=invalid-name DatabaseModel = declarative_base( # pylint: disable=invalid-name
cls=DatabaseModelBase, cls=DatabaseModelBase,
name='DatabaseModel',
name="DatabaseModel",
metadata=MetaData(naming_convention=NAMING_CONVENTION), metadata=MetaData(naming_convention=NAMING_CONVENTION),
) )
# attach the listener for SQLAlchemy ORM attribute "set" events to all models # attach the listener for SQLAlchemy ORM attribute "set" events to all models
event.listen(DatabaseModel, 'attribute_instrument', attach_set_listener)
event.listen(DatabaseModel, "attribute_instrument", attach_set_listener)
# associate JSONB columns with MutableDict so value changes are detected # associate JSONB columns with MutableDict so value changes are detected
MutableDict.associate_with(JSONB) MutableDict.associate_with(JSONB)

44
tildes/tildes/models/group/group.py

@ -4,15 +4,7 @@ from datetime import datetime
from typing import Any, Optional, Sequence, Tuple from typing import Any, Optional, Sequence, Tuple
from pyramid.security import Allow, Authenticated, Deny, DENY_ALL, Everyone from pyramid.security import Allow, Authenticated, Deny, DENY_ALL, Everyone
from sqlalchemy import (
Boolean,
CheckConstraint,
Column,
Index,
Integer,
Text,
TIMESTAMP,
)
from sqlalchemy import Boolean, CheckConstraint, Column, Index, Integer, Text, TIMESTAMP
from sqlalchemy.sql.expression import text from sqlalchemy.sql.expression import text
from sqlalchemy_utils import Ltree, LtreeType from sqlalchemy_utils import Ltree, LtreeType
@ -31,7 +23,7 @@ class Group(DatabaseModel):
schema_class = GroupSchema schema_class = GroupSchema
__tablename__ = 'groups'
__tablename__ = "groups"
group_id: int = Column(Integer, primary_key=True) group_id: int = Column(Integer, primary_key=True)
path: Ltree = Column(LtreeType, nullable=False, index=True, unique=True) path: Ltree = Column(LtreeType, nullable=False, index=True, unique=True)
@ -39,36 +31,34 @@ class Group(DatabaseModel):
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
short_description: Optional[str] = Column( short_description: Optional[str] = Column(
Text, Text,
CheckConstraint( CheckConstraint(
f'LENGTH(short_description) <= {SHORT_DESCRIPTION_MAX_LENGTH}',
name='short_description_length',
)
f"LENGTH(short_description) <= {SHORT_DESCRIPTION_MAX_LENGTH}",
name="short_description_length",
),
) )
num_subscriptions: int = Column(
Integer, nullable=False, server_default='0')
num_subscriptions: int = Column(Integer, nullable=False, server_default="0")
is_admin_posting_only: bool = Column( is_admin_posting_only: bool = Column(
Boolean, nullable=False, server_default='false')
Boolean, nullable=False, server_default="false"
)
# Create a GiST index on path as well as the btree one that will be created # Create a GiST index on path as well as the btree one that will be created
# by the index=True/unique=True keyword args to Column above. The GiST # by the index=True/unique=True keyword args to Column above. The GiST
# index supports additional operators for ltree queries: @>, <@, @, ~, ? # index supports additional operators for ltree queries: @>, <@, @, ~, ?
__table_args__ = (
Index('ix_groups_path_gist', path, postgresql_using='gist'),
)
__table_args__ = (Index("ix_groups_path_gist", path, postgresql_using="gist"),)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Display the group's path and ID as its repr format.""" """Display the group's path and ID as its repr format."""
return f'<Group {self.path} ({self.group_id})>'
return f"<Group {self.path} ({self.group_id})>"
def __str__(self) -> str: def __str__(self) -> str:
"""Use the group path for the string representation.""" """Use the group path for the string representation."""
return str(self.path) return str(self.path)
def __lt__(self, other: 'Group') -> bool:
def __lt__(self, other: "Group") -> bool:
"""Order groups by their string representation.""" """Order groups by their string representation."""
return str(self) < str(other) return str(self) < str(other)
@ -83,20 +73,20 @@ class Group(DatabaseModel):
# view: # view:
# - all groups can be viewed by everyone # - all groups can be viewed by everyone
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# subscribe: # subscribe:
# - all groups can be subscribed to by logged-in users # - all groups can be subscribed to by logged-in users
acl.append((Allow, Authenticated, 'subscribe'))
acl.append((Allow, Authenticated, "subscribe"))
# post_topic: # post_topic:
# - only admins can post in admin-posting-only groups # - only admins can post in admin-posting-only groups
# - otherwise, all logged-in users can post # - otherwise, all logged-in users can post
if self.is_admin_posting_only: if self.is_admin_posting_only:
acl.append((Allow, 'admin', 'post_topic'))
acl.append((Deny, Everyone, 'post_topic'))
acl.append((Allow, "admin", "post_topic"))
acl.append((Deny, Everyone, "post_topic"))
acl.append((Allow, Authenticated, 'post_topic'))
acl.append((Allow, Authenticated, "post_topic"))
acl.append(DENY_ALL) acl.append(DENY_ALL)

6
tildes/tildes/models/group/group_query.py

@ -21,14 +21,14 @@ class GroupQuery(ModelQuery):
""" """
super().__init__(Group, request) super().__init__(Group, request)
def _attach_extra_data(self) -> 'GroupQuery':
def _attach_extra_data(self) -> "GroupQuery":
"""Attach the extra user data to the query.""" """Attach the extra user data to the query."""
if not self.request.user: if not self.request.user:
return self return self
return self._attach_subscription_data() return self._attach_subscription_data()
def _attach_subscription_data(self) -> 'GroupQuery':
def _attach_subscription_data(self) -> "GroupQuery":
"""Add a subquery to include whether the user is subscribed.""" """Add a subquery to include whether the user is subscribed."""
subscription_subquery = ( subscription_subquery = (
self.request.query(GroupSubscription) self.request.query(GroupSubscription)
@ -37,7 +37,7 @@ class GroupQuery(ModelQuery):
GroupSubscription.user == self.request.user, GroupSubscription.user == self.request.user,
) )
.exists() .exists()
.label('user_subscribed')
.label("user_subscribed")
) )
return self.add_columns(subscription_subquery) return self.add_columns(subscription_subquery)

20
tildes/tildes/models/group/group_subscription.py

@ -21,33 +21,27 @@ class GroupSubscription(DatabaseModel):
num_subscriptions column for the relevant group. num_subscriptions column for the relevant group.
""" """
__tablename__ = 'group_subscriptions'
__tablename__ = "group_subscriptions"
user_id: int = Column( user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
) )
group_id: int = Column( group_id: int = Column(
Integer,
ForeignKey('groups.group_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("groups.group_id"), nullable=False, primary_key=True
) )
created_time: datetime = Column( created_time: datetime = Column(
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
user: User = relationship('User', innerjoin=True, backref='subscriptions')
group: Group = relationship('Group', innerjoin=True, lazy=False)
user: User = relationship("User", innerjoin=True, backref="subscriptions")
group: Group = relationship("Group", innerjoin=True, lazy=False)
def __init__(self, user: User, group: Group) -> None: def __init__(self, user: User, group: Group) -> None:
"""Create a new subscription to a group.""" """Create a new subscription to a group."""
self.user = user self.user = user
self.group = group self.group = group
incr_counter('subscriptions')
incr_counter("subscriptions")

135
tildes/tildes/models/log/log.py

@ -3,15 +3,7 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pyramid.request import Request from pyramid.request import Request
from sqlalchemy import (
BigInteger,
Column,
event,
ForeignKey,
Integer,
Table,
TIMESTAMP,
)
from sqlalchemy import BigInteger, Column, event, ForeignKey, Integer, Table, TIMESTAMP
from sqlalchemy.dialects.postgresql import ENUM, INET, JSONB from sqlalchemy.dialects.postgresql import ENUM, INET, JSONB
from sqlalchemy.engine import Connection from sqlalchemy.engine import Connection
from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.declarative import declared_attr
@ -23,7 +15,7 @@ from tildes.models import DatabaseModel
from tildes.models.topic import Topic from tildes.models.topic import Topic
class BaseLog():
class BaseLog:
"""Mixin class with the shared columns/relationships for log classes.""" """Mixin class with the shared columns/relationships for log classes."""
@declared_attr @declared_attr
@ -34,7 +26,7 @@ class BaseLog():
@declared_attr @declared_attr
def user_id(self) -> Column: def user_id(self) -> Column:
"""Return the user_id column.""" """Return the user_id column."""
return Column(Integer, ForeignKey('users.user_id'), index=True)
return Column(Integer, ForeignKey("users.user_id"), index=True)
@declared_attr @declared_attr
def event_type(self) -> Column: def event_type(self) -> Column:
@ -53,7 +45,7 @@ class BaseLog():
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
@declared_attr @declared_attr
@ -64,21 +56,21 @@ class BaseLog():
@declared_attr @declared_attr
def user(self) -> Any: def user(self) -> Any:
"""Return the user relationship.""" """Return the user relationship."""
return relationship('User', lazy=False)
return relationship("User", lazy=False)
class Log(DatabaseModel, BaseLog): class Log(DatabaseModel, BaseLog):
"""Model for a basic log entry.""" """Model for a basic log entry."""
__tablename__ = 'log'
__tablename__ = "log"
INHERITED_TABLES = ['log_topics']
INHERITED_TABLES = ["log_topics"]
def __init__( def __init__(
self,
event_type: LogEventType,
request: Request,
info: Optional[Dict[str, Any]] = None,
self,
event_type: LogEventType,
request: Request,
info: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
"""Create a new log entry. """Create a new log entry.
@ -97,19 +89,20 @@ class Log(DatabaseModel, BaseLog):
class LogTopic(DatabaseModel, BaseLog): class LogTopic(DatabaseModel, BaseLog):
"""Model for a log entry related to a specific topic.""" """Model for a log entry related to a specific topic."""
__tablename__ = 'log_topics'
__tablename__ = "log_topics"
topic_id: int = Column( topic_id: int = Column(
Integer, ForeignKey('topics.topic_id'), index=True, nullable=False)
Integer, ForeignKey("topics.topic_id"), index=True, nullable=False
)
topic: Topic = relationship('Topic')
topic: Topic = relationship("Topic")
def __init__( def __init__(
self,
event_type: LogEventType,
request: Request,
topic: Topic,
info: Optional[Dict[str, Any]] = None,
self,
event_type: LogEventType,
request: Request,
topic: Topic,
info: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
"""Create a new log entry related to a specific topic.""" """Create a new log entry related to a specific topic."""
# pylint: disable=non-parent-init-called # pylint: disable=non-parent-init-called
@ -122,57 +115,55 @@ class LogTopic(DatabaseModel, BaseLog):
if self.event_type == LogEventType.TOPIC_TAG: if self.event_type == LogEventType.TOPIC_TAG:
return self._tag_event_description() return self._tag_event_description()
elif self.event_type == LogEventType.TOPIC_MOVE: elif self.event_type == LogEventType.TOPIC_MOVE:
old_group = self.info['old'] # noqa
new_group = self.info['new'] # noqa
return f'moved from ~{old_group} to ~{new_group}'
old_group = self.info["old"] # noqa
new_group = self.info["new"] # noqa
return f"moved from ~{old_group} to ~{new_group}"
elif self.event_type == LogEventType.TOPIC_LOCK: elif self.event_type == LogEventType.TOPIC_LOCK:
return 'locked comments'
return "locked comments"
elif self.event_type == LogEventType.TOPIC_UNLOCK: elif self.event_type == LogEventType.TOPIC_UNLOCK:
return 'unlocked comments'
return "unlocked comments"
elif self.event_type == LogEventType.TOPIC_TITLE_EDIT: elif self.event_type == LogEventType.TOPIC_TITLE_EDIT:
old_title = self.info['old'] # noqa
new_title = self.info['new'] # noqa
old_title = self.info["old"] # noqa
new_title = self.info["new"] # noqa
return f'changed title from "{old_title}" to "{new_title}"' return f'changed title from "{old_title}" to "{new_title}"'
return f'performed action {self.event_type.name}' # noqa
return f"performed action {self.event_type.name}" # noqa
def _tag_event_description(self) -> str: def _tag_event_description(self) -> str:
"""Return a description of a TOPIC_TAG event as a string.""" """Return a description of a TOPIC_TAG event as a string."""
if self.event_type != LogEventType.TOPIC_TAG: if self.event_type != LogEventType.TOPIC_TAG:
raise TypeError raise TypeError
old_tags = set(self.info['old']) # noqa
new_tags = set(self.info['new']) # noqa
old_tags = set(self.info["old"]) # noqa
new_tags = set(self.info["new"]) # noqa
added_tags = new_tags - old_tags added_tags = new_tags - old_tags
removed_tags = old_tags - new_tags removed_tags = old_tags - new_tags
description = ''
description = ""
if added_tags: if added_tags:
tag_str = ', '.join([f"'{tag}'" for tag in added_tags])
tag_str = ", ".join([f"'{tag}'" for tag in added_tags])
if len(added_tags) == 1: if len(added_tags) == 1:
description += f'added tag {tag_str}'
description += f"added tag {tag_str}"
else: else:
description += f'added tags {tag_str}'
description += f"added tags {tag_str}"
if removed_tags: if removed_tags:
description += ' and '
description += " and "
if removed_tags: if removed_tags:
tag_str = ', '.join([f"'{tag}'" for tag in removed_tags])
tag_str = ", ".join([f"'{tag}'" for tag in removed_tags])
if len(removed_tags) == 1: if len(removed_tags) == 1:
description += f'removed tag {tag_str}'
description += f"removed tag {tag_str}"
else: else:
description += f'removed tags {tag_str}'
description += f"removed tags {tag_str}"
return description return description
@event.listens_for(Log.__table__, 'after_create')
@event.listens_for(Log.__table__, "after_create")
def create_inherited_tables( def create_inherited_tables(
target: Table,
connection: Connection,
**kwargs: Any,
target: Table, connection: Connection, **kwargs: Any
) -> None: ) -> None:
"""Create all the tables that inherit from the base "log" one.""" """Create all the tables that inherit from the base "log" one."""
# pylint: disable=unused-argument # pylint: disable=unused-argument
@ -180,43 +171,39 @@ def create_inherited_tables(
# log_topics # log_topics
connection.execute( connection.execute(
'CREATE TABLE log_topics (topic_id integer not null) INHERITS (log)')
"CREATE TABLE log_topics (topic_id integer not null) INHERITS (log)"
)
fk_name = naming['fk'] % {
'table_name': 'log_topics',
'column_0_name': 'topic_id',
'referred_table_name': 'topics',
fk_name = naming["fk"] % {
"table_name": "log_topics",
"column_0_name": "topic_id",
"referred_table_name": "topics",
} }
connection.execute( connection.execute(
f'ALTER TABLE log_topics ADD CONSTRAINT {fk_name} '
'FOREIGN KEY (topic_id) REFERENCES topics (topic_id)'
f"ALTER TABLE log_topics ADD CONSTRAINT {fk_name} "
"FOREIGN KEY (topic_id) REFERENCES topics (topic_id)"
) )
ix_name = naming['ix'] % {
'table_name': 'log_topics',
'column_0_name': 'topic_id',
}
connection.execute(f'CREATE INDEX {ix_name} ON log_topics (topic_id)')
ix_name = naming["ix"] % {"table_name": "log_topics", "column_0_name": "topic_id"}
connection.execute(f"CREATE INDEX {ix_name} ON log_topics (topic_id)")
# duplicate all the indexes/constraints from the base log table # duplicate all the indexes/constraints from the base log table
for table in Log.INHERITED_TABLES: for table in Log.INHERITED_TABLES:
pk_name = naming['pk'] % {'table_name': table}
pk_name = naming["pk"] % {"table_name": table}
connection.execute( connection.execute(
f'ALTER TABLE {table} '
f'ADD CONSTRAINT {pk_name} PRIMARY KEY (log_id)'
f"ALTER TABLE {table} ADD CONSTRAINT {pk_name} PRIMARY KEY (log_id)"
) )
for col in ('event_time', 'event_type', 'ip_address', 'user_id'):
ix_name = naming['ix'] % {
'table_name': table, 'column_0_name': col}
connection.execute(f'CREATE INDEX {ix_name} ON {table} ({col})')
for col in ("event_time", "event_type", "ip_address", "user_id"):
ix_name = naming["ix"] % {"table_name": table, "column_0_name": col}
connection.execute(f"CREATE INDEX {ix_name} ON {table} ({col})")
fk_name = naming['fk'] % {
'table_name': table,
'column_0_name': 'user_id',
'referred_table_name': 'users',
fk_name = naming["fk"] % {
"table_name": table,
"column_0_name": "user_id",
"referred_table_name": "users",
} }
connection.execute( connection.execute(
f'ALTER TABLE {table} ADD CONSTRAINT {fk_name} '
'FOREIGN KEY (user_id) REFERENCES users (user_id)'
f"ALTER TABLE {table} ADD CONSTRAINT {fk_name} "
"FOREIGN KEY (user_id) REFERENCES users (user_id)"
) )

85
tildes/tildes/models/message/message.py

@ -55,78 +55,70 @@ class MessageConversation(DatabaseModel):
schema_class = MessageConversationSchema schema_class = MessageConversationSchema
__tablename__ = 'message_conversations'
__tablename__ = "message_conversations"
conversation_id: int = Column(Integer, primary_key=True) conversation_id: int = Column(Integer, primary_key=True)
sender_id: int = Column( sender_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
) )
recipient_id: int = Column( recipient_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
) )
created_time: datetime = Column( created_time: datetime = Column(
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
subject: str = Column( subject: str = Column(
Text, Text,
CheckConstraint( CheckConstraint(
f'LENGTH(subject) <= {SUBJECT_MAX_LENGTH}',
name='subject_length',
f"LENGTH(subject) <= {SUBJECT_MAX_LENGTH}", name="subject_length"
), ),
nullable=False, nullable=False,
) )
markdown: str = deferred(Column(Text, nullable=False)) markdown: str = deferred(Column(Text, nullable=False))
rendered_html: str = Column(Text, nullable=False) rendered_html: str = Column(Text, nullable=False)
num_replies: int = Column(Integer, nullable=False, server_default='0')
last_reply_time: Optional[datetime] = Column(
TIMESTAMP(timezone=True), index=True)
num_replies: int = Column(Integer, nullable=False, server_default="0")
last_reply_time: Optional[datetime] = Column(TIMESTAMP(timezone=True), index=True)
unread_user_ids: List[int] = Column( unread_user_ids: List[int] = Column(
ARRAY(Integer), nullable=False, server_default='{}')
ARRAY(Integer), nullable=False, server_default="{}"
)
sender: User = relationship( sender: User = relationship(
'User', lazy=False, innerjoin=True, foreign_keys=[sender_id])
"User", lazy=False, innerjoin=True, foreign_keys=[sender_id]
)
recipient: User = relationship( recipient: User = relationship(
'User', lazy=False, innerjoin=True, foreign_keys=[recipient_id])
replies: Sequence['MessageReply'] = relationship(
'MessageReply', order_by='MessageReply.created_time')
"User", lazy=False, innerjoin=True, foreign_keys=[recipient_id]
)
replies: Sequence["MessageReply"] = relationship(
"MessageReply", order_by="MessageReply.created_time"
)
# Create a GIN index on the unread_user_ids column using the gin__int_ops # Create a GIN index on the unread_user_ids column using the gin__int_ops
# operator class supplied by the intarray module. This should be the best # operator class supplied by the intarray module. This should be the best
# index for "array contains" queries. # index for "array contains" queries.
__table_args__ = ( __table_args__ = (
Index( Index(
'ix_message_conversations_unread_user_ids_gin',
"ix_message_conversations_unread_user_ids_gin",
unread_user_ids, unread_user_ids,
postgresql_using='gin',
postgresql_ops={'unread_user_ids': 'gin__int_ops'},
postgresql_using="gin",
postgresql_ops={"unread_user_ids": "gin__int_ops"},
), ),
) )
def __init__( def __init__(
self,
sender: User,
recipient: User,
subject: str,
markdown: str,
self, sender: User, recipient: User, subject: str, markdown: str
) -> None: ) -> None:
"""Create a new message conversation between two users.""" """Create a new message conversation between two users."""
self.sender_id = sender.user_id self.sender_id = sender.user_id
self.recipient_id = recipient.user_id self.recipient_id = recipient.user_id
self.unread_user_ids = ([self.recipient_id])
self.unread_user_ids = [self.recipient_id]
self.subject = subject self.subject = subject
self.markdown = markdown self.markdown = markdown
self.rendered_html = convert_markdown_to_safe_html(markdown) self.rendered_html = convert_markdown_to_safe_html(markdown)
incr_counter('messages', type='conversation')
incr_counter("messages", type="conversation")
def __acl__(self) -> Sequence[Tuple[str, Any, str]]: def __acl__(self) -> Sequence[Tuple[str, Any, str]]:
"""Pyramid security ACL.""" """Pyramid security ACL."""
@ -163,7 +155,7 @@ class MessageConversation(DatabaseModel):
vice versa. vice versa.
""" """
if not self.is_participant(viewer): if not self.is_participant(viewer):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
if viewer == self.sender: if viewer == self.sender:
return self.recipient return self.recipient
@ -173,7 +165,7 @@ class MessageConversation(DatabaseModel):
def is_unread_by_user(self, user: User) -> bool: def is_unread_by_user(self, user: User) -> bool:
"""Return whether the conversation is unread by the specified user.""" """Return whether the conversation is unread by the specified user."""
if not self.is_participant(user): if not self.is_participant(user):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
return user.user_id in self.unread_user_ids return user.user_id in self.unread_user_ids
@ -184,9 +176,9 @@ class MessageConversation(DatabaseModel):
worry about duplicate values, race conditions, etc. worry about duplicate values, race conditions, etc.
""" """
if not self.is_participant(user): if not self.is_participant(user):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
union = MessageConversation.unread_user_ids.op('|') # type: ignore
union = MessageConversation.unread_user_ids.op("|") # type: ignore
self.unread_user_ids = union(user.user_id) self.unread_user_ids = union(user.user_id)
def mark_read_for_user(self, user: User) -> None: def mark_read_for_user(self, user: User) -> None:
@ -197,11 +189,12 @@ class MessageConversation(DatabaseModel):
race conditions, etc. race conditions, etc.
""" """
if not self.is_participant(user): if not self.is_participant(user):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
user_id = user.user_id user_id = user.user_id
self.unread_user_ids = ( # type: ignore self.unread_user_ids = ( # type: ignore
MessageConversation.unread_user_ids - user_id) # type: ignore
MessageConversation.unread_user_ids - user_id # type: ignore
)
class MessageReply(DatabaseModel): class MessageReply(DatabaseModel):
@ -217,37 +210,31 @@ class MessageReply(DatabaseModel):
schema_class = MessageReplySchema schema_class = MessageReplySchema
__tablename__ = 'message_replies'
__tablename__ = "message_replies"
reply_id: int = Column(Integer, primary_key=True) reply_id: int = Column(Integer, primary_key=True)
conversation_id: int = Column( conversation_id: int = Column(
Integer, Integer,
ForeignKey('message_conversations.conversation_id'),
ForeignKey("message_conversations.conversation_id"),
nullable=False, nullable=False,
index=True, index=True,
) )
sender_id: int = Column( sender_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
) )
created_time: datetime = Column( created_time: datetime = Column(
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
markdown: str = deferred(Column(Text, nullable=False)) markdown: str = deferred(Column(Text, nullable=False))
rendered_html: str = Column(Text, nullable=False) rendered_html: str = Column(Text, nullable=False)
sender: User = relationship('User', lazy=False, innerjoin=True)
sender: User = relationship("User", lazy=False, innerjoin=True)
def __init__( def __init__(
self,
conversation: MessageConversation,
sender: User,
markdown: str,
self, conversation: MessageConversation, sender: User, markdown: str
) -> None: ) -> None:
"""Add a new reply to a message conversation.""" """Add a new reply to a message conversation."""
self.conversation_id = conversation.conversation_id self.conversation_id = conversation.conversation_id
@ -255,7 +242,7 @@ class MessageReply(DatabaseModel):
self.markdown = markdown self.markdown = markdown
self.rendered_html = convert_markdown_to_safe_html(markdown) self.rendered_html = convert_markdown_to_safe_html(markdown)
incr_counter('messages', type='reply')
incr_counter("messages", type="reply")
@property @property
def reply_id36(self) -> str: def reply_id36(self) -> str:

37
tildes/tildes/models/model_query.py

@ -8,7 +8,7 @@ from sqlalchemy.orm import Load, undefer
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
ModelType = TypeVar('ModelType') # pylint: disable=invalid-name
ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
class ModelQuery(Query): class ModelQuery(Query):
@ -22,10 +22,10 @@ class ModelQuery(Query):
self.request = request self.request = request
# can only filter deleted items if the table has an 'is_deleted' column # can only filter deleted items if the table has an 'is_deleted' column
self.filter_deleted = bool('is_deleted' in model_cls.__table__.columns)
self.filter_deleted = bool("is_deleted" in model_cls.__table__.columns)
# can only filter removed items if the table has an 'is_removed' column # can only filter removed items if the table has an 'is_removed' column
self.filter_removed = bool('is_removed' in model_cls.__table__.columns)
self.filter_removed = bool("is_removed" in model_cls.__table__.columns)
def __iter__(self) -> Iterator[ModelType]: def __iter__(self) -> Iterator[ModelType]:
"""Iterate over the (processed) results of the query. """Iterate over the (processed) results of the query.
@ -36,11 +36,11 @@ class ModelQuery(Query):
results = super().__iter__() results = super().__iter__()
return iter([self._process_result(result) for result in results]) return iter([self._process_result(result) for result in results])
def _attach_extra_data(self) -> 'ModelQuery':
def _attach_extra_data(self) -> "ModelQuery":
"""Override to attach extra data to query before execution.""" """Override to attach extra data to query before execution."""
return self return self
def _finalize(self) -> 'ModelQuery':
def _finalize(self) -> "ModelQuery":
"""Finalize the query before it's executed.""" """Finalize the query before it's executed."""
# pylint: disable=protected-access # pylint: disable=protected-access
@ -49,14 +49,13 @@ class ModelQuery(Query):
# is potentially dangerous, but should be fine with the existing # is potentially dangerous, but should be fine with the existing
# straightforward usage patterns. # straightforward usage patterns.
return ( return (
self
.enable_assertions(False)
self.enable_assertions(False)
._attach_extra_data() ._attach_extra_data()
._filter_deleted_if_necessary() ._filter_deleted_if_necessary()
._filter_removed_if_necessary() ._filter_removed_if_necessary()
) )
def _before_compile_listener(self) -> 'ModelQuery':
def _before_compile_listener(self) -> "ModelQuery":
"""Do any final adjustments to the query before it's compiled. """Do any final adjustments to the query before it's compiled.
Note that this method cannot be overridden by subclasses because of Note that this method cannot be overridden by subclasses because of
@ -65,21 +64,21 @@ class ModelQuery(Query):
""" """
return self._finalize() return self._finalize()
def _filter_deleted_if_necessary(self) -> 'ModelQuery':
def _filter_deleted_if_necessary(self) -> "ModelQuery":
"""Filter out deleted rows unless they were explicitly included.""" """Filter out deleted rows unless they were explicitly included."""
if not self.filter_deleted: if not self.filter_deleted:
return self return self
return self.filter(self.model_cls.is_deleted == False) # noqa return self.filter(self.model_cls.is_deleted == False) # noqa
def _filter_removed_if_necessary(self) -> 'ModelQuery':
def _filter_removed_if_necessary(self) -> "ModelQuery":
"""Filter out removed rows unless they were explicitly included.""" """Filter out removed rows unless they were explicitly included."""
if not self.filter_removed: if not self.filter_removed:
return self return self
return self.filter(self.model_cls.is_removed == False) # noqa return self.filter(self.model_cls.is_removed == False) # noqa
def lock_based_on_request_method(self) -> 'ModelQuery':
def lock_based_on_request_method(self) -> "ModelQuery":
"""Lock the rows if request method implies it's needed (generative). """Lock the rows if request method implies it's needed (generative).
Applying this function to a query will cause the database to acquire Applying this function to a query will cause the database to acquire
@ -90,37 +89,37 @@ class ModelQuery(Query):
Note that POST is specifically not included, because the item being Note that POST is specifically not included, because the item being
POSTed to is not usually modified in a "dangerous" way as a result. POSTed to is not usually modified in a "dangerous" way as a result.
""" """
if self.request.method in {'DELETE', 'PATCH', 'PUT'}:
if self.request.method in {"DELETE", "PATCH", "PUT"}:
return self.with_for_update(of=self.model_cls) return self.with_for_update(of=self.model_cls)
return self return self
def include_deleted(self) -> 'ModelQuery':
def include_deleted(self) -> "ModelQuery":
"""Specify that deleted rows should be included (generative).""" """Specify that deleted rows should be included (generative)."""
self.filter_deleted = False self.filter_deleted = False
return self return self
def include_removed(self) -> 'ModelQuery':
def include_removed(self) -> "ModelQuery":
"""Specify that removed rows should be included (generative).""" """Specify that removed rows should be included (generative)."""
self.filter_removed = False self.filter_removed = False
return self return self
def join_all_relationships(self) -> 'ModelQuery':
def join_all_relationships(self) -> "ModelQuery":
"""Eagerly join all lazy relationships (generative). """Eagerly join all lazy relationships (generative).
This is useful for being able to load an item "fully" in a single This is useful for being able to load an item "fully" in a single
query and avoid needing to make additional queries for related items. query and avoid needing to make additional queries for related items.
""" """
# pylint: disable=no-member # pylint: disable=no-member
self = self.options(Load(self.model_cls).joinedload('*'))
self = self.options(Load(self.model_cls).joinedload("*"))
return self return self
def undefer_all_columns(self) -> 'ModelQuery':
def undefer_all_columns(self) -> "ModelQuery":
"""Undefer all columns (generative).""" """Undefer all columns (generative)."""
self = self.options(undefer('*'))
self = self.options(undefer("*"))
return self return self
@ -134,7 +133,7 @@ class ModelQuery(Query):
# before the query executes # before the query executes
event.listen( event.listen(
ModelQuery, ModelQuery,
'before_compile',
"before_compile",
ModelQuery._before_compile_listener, # pylint: disable=protected-access ModelQuery._before_compile_listener, # pylint: disable=protected-access
retval=True, retval=True,
) )

14
tildes/tildes/models/pagination.py

@ -9,7 +9,7 @@ from tildes.lib.id import id_to_id36, id36_to_id
from .model_query import ModelQuery from .model_query import ModelQuery
ModelType = TypeVar('ModelType') # pylint: disable=invalid-name
ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
class PaginatedQuery(ModelQuery): class PaginatedQuery(ModelQuery):
@ -18,7 +18,7 @@ class PaginatedQuery(ModelQuery):
def __init__(self, model_cls: Any, request: Request) -> None: def __init__(self, model_cls: Any, request: Request) -> None:
"""Initialize a PaginatedQuery for the specified model and request.""" """Initialize a PaginatedQuery for the specified model and request."""
if len(model_cls.__table__.primary_key) > 1: if len(model_cls.__table__.primary_key) > 1:
raise TypeError('Only single-col primary key tables are supported')
raise TypeError("Only single-col primary key tables are supported")
super().__init__(model_cls, request) super().__init__(model_cls, request)
@ -75,7 +75,7 @@ class PaginatedQuery(ModelQuery):
""" """
return bool(self.before_id) return bool(self.before_id)
def after_id36(self, id36: str) -> 'PaginatedQuery':
def after_id36(self, id36: str) -> "PaginatedQuery":
"""Restrict the query to results after an id36 (generative).""" """Restrict the query to results after an id36 (generative)."""
if self.before_id: if self.before_id:
raise ValueError("Can't set both before and after restrictions") raise ValueError("Can't set both before and after restrictions")
@ -84,7 +84,7 @@ class PaginatedQuery(ModelQuery):
return self return self
def before_id36(self, id36: str) -> 'PaginatedQuery':
def before_id36(self, id36: str) -> "PaginatedQuery":
"""Restrict the query to results before an id36 (generative).""" """Restrict the query to results before an id36 (generative)."""
if self.after_id: if self.after_id:
raise ValueError("Can't set both before and after restrictions") raise ValueError("Can't set both before and after restrictions")
@ -93,7 +93,7 @@ class PaginatedQuery(ModelQuery):
return self return self
def _apply_before_or_after(self) -> 'PaginatedQuery':
def _apply_before_or_after(self) -> "PaginatedQuery":
"""Apply the "before" or "after" restrictions if necessary.""" """Apply the "before" or "after" restrictions if necessary."""
if not (self.after_id or self.before_id): if not (self.after_id or self.before_id):
return self return self
@ -132,7 +132,7 @@ class PaginatedQuery(ModelQuery):
return query return query
def _finalize(self) -> 'PaginatedQuery':
def _finalize(self) -> "PaginatedQuery":
"""Finalize the query before execution.""" """Finalize the query before execution."""
query = super()._finalize() query = super()._finalize()
@ -152,7 +152,7 @@ class PaginatedQuery(ModelQuery):
return query return query
def get_page(self, per_page: int) -> 'PaginatedResults':
def get_page(self, per_page: int) -> "PaginatedResults":
"""Get a page worth of results from the query (`per page` items).""" """Get a page worth of results from the query (`per page` items)."""
return PaginatedResults(self, per_page) return PaginatedResults(self, per_page)

174
tildes/tildes/models/topic/topic.py

@ -31,17 +31,14 @@ from tildes.metrics import incr_counter
from tildes.models import DatabaseModel from tildes.models import DatabaseModel
from tildes.models.group import Group from tildes.models.group import Group
from tildes.models.user import User from tildes.models.user import User
from tildes.schemas.topic import (
TITLE_MAX_LENGTH,
TopicSchema,
)
from tildes.schemas.topic import TITLE_MAX_LENGTH, TopicSchema
# edits inside this period after creation will not mark the topic as edited # edits inside this period after creation will not mark the topic as edited
EDIT_GRACE_PERIOD = timedelta(minutes=5) EDIT_GRACE_PERIOD = timedelta(minutes=5)
# special tags to put at the front of the tag list # special tags to put at the front of the tag list
SPECIAL_TAGS = ['nsfw', 'spoiler']
SPECIAL_TAGS = ["nsfw", "spoiler"]
class Topic(DatabaseModel): class Topic(DatabaseModel):
@ -64,76 +61,67 @@ class Topic(DatabaseModel):
schema_class = TopicSchema schema_class = TopicSchema
__tablename__ = 'topics'
__tablename__ = "topics"
topic_id: int = Column(Integer, primary_key=True) topic_id: int = Column(Integer, primary_key=True)
group_id: int = Column( group_id: int = Column(
Integer,
ForeignKey('groups.group_id'),
nullable=False,
index=True,
Integer, ForeignKey("groups.group_id"), nullable=False, index=True
) )
user_id: int = Column( user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
) )
created_time: datetime = Column( created_time: datetime = Column(
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
last_edited_time: Optional[datetime] = Column(TIMESTAMP(timezone=True)) last_edited_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
last_activity_time: datetime = Column( last_activity_time: datetime = Column(
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
is_deleted: bool = Column( is_deleted: bool = Column(
Boolean, nullable=False, server_default='false', index=True)
Boolean, nullable=False, server_default="false", index=True
)
deleted_time: Optional[datetime] = Column(TIMESTAMP(timezone=True)) deleted_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
is_removed: bool = Column( is_removed: bool = Column(
Boolean, nullable=False, server_default='false', index=True)
Boolean, nullable=False, server_default="false", index=True
)
removed_time: Optional[datetime] = Column(TIMESTAMP(timezone=True)) removed_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
title: str = Column( title: str = Column(
Text, Text,
CheckConstraint(
f'LENGTH(title) <= {TITLE_MAX_LENGTH}',
name='title_length',
),
CheckConstraint(f"LENGTH(title) <= {TITLE_MAX_LENGTH}", name="title_length"),
nullable=False, nullable=False,
) )
topic_type: TopicType = Column( topic_type: TopicType = Column(
ENUM(TopicType), nullable=False, server_default='TEXT')
_markdown: Optional[str] = deferred(Column('markdown', Text))
ENUM(TopicType), nullable=False, server_default="TEXT"
)
_markdown: Optional[str] = deferred(Column("markdown", Text))
rendered_html: Optional[str] = Column(Text) rendered_html: Optional[str] = Column(Text)
link: Optional[str] = Column(Text) link: Optional[str] = Column(Text)
content_metadata: Dict[str, Any] = Column(JSONB) content_metadata: Dict[str, Any] = Column(JSONB)
num_comments: int = Column(
Integer, nullable=False, server_default='0', index=True)
num_votes: int = Column(
Integer, nullable=False, server_default='0', index=True)
num_comments: int = Column(Integer, nullable=False, server_default="0", index=True)
num_votes: int = Column(Integer, nullable=False, server_default="0", index=True)
_tags: List[Ltree] = Column( _tags: List[Ltree] = Column(
'tags', ArrayOfLtree, nullable=False, server_default='{}')
is_official: bool = Column(Boolean, nullable=False, server_default='false')
is_locked: bool = Column(Boolean, nullable=False, server_default='false')
"tags", ArrayOfLtree, nullable=False, server_default="{}"
)
is_official: bool = Column(Boolean, nullable=False, server_default="false")
is_locked: bool = Column(Boolean, nullable=False, server_default="false")
user: User = relationship('User', lazy=False, innerjoin=True)
group: Group = relationship('Group', innerjoin=True)
user: User = relationship("User", lazy=False, innerjoin=True)
group: Group = relationship("Group", innerjoin=True)
# Create a GiST index on the tags column # Create a GiST index on the tags column
__table_args__ = (
Index('ix_topics_tags_gist', _tags, postgresql_using='gist'),
)
__table_args__ = (Index("ix_topics_tags_gist", _tags, postgresql_using="gist"),)
@hybrid_property @hybrid_property
def markdown(self) -> Optional[str]: def markdown(self) -> Optional[str]:
"""Return the topic's markdown.""" """Return the topic's markdown."""
if not self.is_text_type: if not self.is_text_type:
raise AttributeError('Only text topics have markdown')
raise AttributeError("Only text topics have markdown")
return self._markdown return self._markdown
@ -141,7 +129,7 @@ class Topic(DatabaseModel):
def markdown(self, new_markdown: str) -> None: def markdown(self, new_markdown: str) -> None:
"""Set the topic's markdown and render its HTML.""" """Set the topic's markdown and render its HTML."""
if not self.is_text_type: if not self.is_text_type:
raise AttributeError('Can only set markdown for text topics')
raise AttributeError("Can only set markdown for text topics")
if new_markdown == self.markdown: if new_markdown == self.markdown:
return return
@ -149,21 +137,19 @@ class Topic(DatabaseModel):
self._markdown = new_markdown self._markdown = new_markdown
self.rendered_html = convert_markdown_to_safe_html(new_markdown) self.rendered_html = convert_markdown_to_safe_html(new_markdown)
if (self.created_time and
utc_now() - self.created_time > EDIT_GRACE_PERIOD):
if self.created_time and utc_now() - self.created_time > EDIT_GRACE_PERIOD:
self.last_edited_time = utc_now() self.last_edited_time = utc_now()
@hybrid_property @hybrid_property
def tags(self) -> List[str]: def tags(self) -> List[str]:
"""Return the topic's tags.""" """Return the topic's tags."""
sorted_tags = [str(tag).replace('_', ' ') for tag in self._tags]
sorted_tags = [str(tag).replace("_", " ") for tag in self._tags]
# move special tags in front # move special tags in front
# reverse so that tags at the start of the list appear first # reverse so that tags at the start of the list appear first
for tag in reversed(SPECIAL_TAGS): for tag in reversed(SPECIAL_TAGS):
if tag in sorted_tags: if tag in sorted_tags:
sorted_tags.insert(
0, sorted_tags.pop(sorted_tags.index(tag)))
sorted_tags.insert(0, sorted_tags.pop(sorted_tags.index(tag)))
return sorted_tags return sorted_tags
@ -176,12 +162,7 @@ class Topic(DatabaseModel):
return f'<Topic "{self.title}" ({self.topic_id})>' return f'<Topic "{self.title}" ({self.topic_id})>'
@classmethod @classmethod
def _create_base_topic(
cls,
group: Group,
author: User,
title: str,
) -> 'Topic':
def _create_base_topic(cls, group: Group, author: User, title: str) -> "Topic":
"""Create the "base" for a new topic.""" """Create the "base" for a new topic."""
new_topic = cls() new_topic = cls()
new_topic.group_id = group.group_id new_topic.group_id = group.group_id
@ -192,35 +173,27 @@ class Topic(DatabaseModel):
@classmethod @classmethod
def create_text_topic( def create_text_topic(
cls,
group: Group,
author: User,
title: str,
markdown: str = '',
) -> 'Topic':
cls, group: Group, author: User, title: str, markdown: str = ""
) -> "Topic":
"""Create a new text topic.""" """Create a new text topic."""
new_topic = cls._create_base_topic(group, author, title) new_topic = cls._create_base_topic(group, author, title)
new_topic.topic_type = TopicType.TEXT new_topic.topic_type = TopicType.TEXT
new_topic.markdown = markdown new_topic.markdown = markdown
incr_counter('topics', type='text')
incr_counter("topics", type="text")
return new_topic return new_topic
@classmethod @classmethod
def create_link_topic( def create_link_topic(
cls,
group: Group,
author: User,
title: str,
link: str,
) -> 'Topic':
cls, group: Group, author: User, title: str, link: str
) -> "Topic":
"""Create a new link topic.""" """Create a new link topic."""
new_topic = cls._create_base_topic(group, author, title) new_topic = cls._create_base_topic(group, author, title)
new_topic.topic_type = TopicType.LINK new_topic.topic_type = TopicType.LINK
new_topic.link = link new_topic.link = link
incr_counter('topics', type='link')
incr_counter("topics", type="link")
return new_topic return new_topic
@ -230,74 +203,74 @@ class Topic(DatabaseModel):
# deleted topics allow "general" viewing, but nothing else # deleted topics allow "general" viewing, but nothing else
if self.is_deleted: if self.is_deleted:
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
acl.append(DENY_ALL) acl.append(DENY_ALL)
# view: # view:
# - everyone gets "general" viewing permission for all topics # - everyone gets "general" viewing permission for all topics
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# view_author: # view_author:
# - removed topics' author is only visible to the author and admins # - removed topics' author is only visible to the author and admins
# - otherwise, everyone can view the author # - otherwise, everyone can view the author
if self.is_removed: if self.is_removed:
acl.append((Allow, 'admin', 'view_author'))
acl.append((Allow, self.user_id, 'view_author'))
acl.append((Deny, Everyone, 'view_author'))
acl.append((Allow, "admin", "view_author"))
acl.append((Allow, self.user_id, "view_author"))
acl.append((Deny, Everyone, "view_author"))
acl.append((Allow, Everyone, 'view_author'))
acl.append((Allow, Everyone, "view_author"))
# view_content: # view_content:
# - removed topics' content is only visible to the author and admins # - removed topics' content is only visible to the author and admins
# - otherwise, everyone can view the content # - otherwise, everyone can view the content
if self.is_removed: if self.is_removed:
acl.append((Allow, 'admin', 'view_content'))
acl.append((Allow, self.user_id, 'view_content'))
acl.append((Deny, Everyone, 'view_content'))
acl.append((Allow, "admin", "view_content"))
acl.append((Allow, self.user_id, "view_content"))
acl.append((Deny, Everyone, "view_content"))
acl.append((Allow, Everyone, 'view_content'))
acl.append((Allow, Everyone, "view_content"))
# vote: # vote:
# - removed topics can't be voted on by anyone # - removed topics can't be voted on by anyone
# - otherwise, logged-in users except the author can vote # - otherwise, logged-in users except the author can vote
if self.is_removed: if self.is_removed:
acl.append((Deny, Everyone, 'vote'))
acl.append((Deny, Everyone, "vote"))
acl.append((Deny, self.user_id, 'vote'))
acl.append((Allow, Authenticated, 'vote'))
acl.append((Deny, self.user_id, "vote"))
acl.append((Allow, Authenticated, "vote"))
# comment: # comment:
# - removed topics can only be commented on by admins # - removed topics can only be commented on by admins
# - locked topics can only be commented on by admins # - locked topics can only be commented on by admins
# - otherwise, logged-in users can comment # - otherwise, logged-in users can comment
if self.is_removed: if self.is_removed:
acl.append((Allow, 'admin', 'comment'))
acl.append((Deny, Everyone, 'comment'))
acl.append((Allow, "admin", "comment"))
acl.append((Deny, Everyone, "comment"))
if self.is_locked: if self.is_locked:
acl.append((Allow, 'admin', 'comment'))
acl.append((Deny, Everyone, 'comment'))
acl.append((Allow, "admin", "comment"))
acl.append((Deny, Everyone, "comment"))
acl.append((Allow, Authenticated, 'comment'))
acl.append((Allow, Authenticated, "comment"))
# edit: # edit:
# - only text topics can be edited, only by the author # - only text topics can be edited, only by the author
if self.is_text_type: if self.is_text_type:
acl.append((Allow, self.user_id, 'edit'))
acl.append((Allow, self.user_id, "edit"))
# delete: # delete:
# - only the author can delete # - only the author can delete
acl.append((Allow, self.user_id, 'delete'))
acl.append((Allow, self.user_id, "delete"))
# tag: # tag:
# - only the author and admins can tag topics # - only the author and admins can tag topics
acl.append((Allow, self.user_id, 'tag'))
acl.append((Allow, 'admin', 'tag'))
acl.append((Allow, self.user_id, "tag"))
acl.append((Allow, "admin", "tag"))
# admin tools # admin tools
acl.append((Allow, 'admin', 'lock'))
acl.append((Allow, 'admin', 'move'))
acl.append((Allow, 'admin', 'edit_title'))
acl.append((Allow, "admin", "lock"))
acl.append((Allow, "admin", "move"))
acl.append((Allow, "admin", "edit_title"))
acl.append(DENY_ALL) acl.append(DENY_ALL)
@ -316,7 +289,7 @@ class Topic(DatabaseModel):
@property @property
def permalink(self) -> str: def permalink(self) -> str:
"""Return the permalink for this topic.""" """Return the permalink for this topic."""
return f'/~{self.group.path}/{self.topic_id36}/{self.url_slug}'
return f"/~{self.group.path}/{self.topic_id36}/{self.url_slug}"
@property @property
def is_text_type(self) -> bool: def is_text_type(self) -> bool:
@ -332,27 +305,26 @@ class Topic(DatabaseModel):
def type_for_display(self) -> str: def type_for_display(self) -> str:
"""Return a string of the topic's type, suitable for display.""" """Return a string of the topic's type, suitable for display."""
if self.is_text_type: if self.is_text_type:
return 'Text'
return "Text"
elif self.is_link_type: elif self.is_link_type:
return 'Link'
return "Link"
return 'Topic'
return "Topic"
@property @property
def link_domain(self) -> str: def link_domain(self) -> str:
"""Return the link's domain (for link topics only).""" """Return the link's domain (for link topics only)."""
if not self.is_link_type or not self.link: if not self.is_link_type or not self.link:
raise ValueError('Non-link topics do not have a domain')
raise ValueError("Non-link topics do not have a domain")
# get the domain from the content metadata if possible, but fall back # get the domain from the content metadata if possible, but fall back
# to just parsing it from the link if it's not present # to just parsing it from the link if it's not present
return (self.get_content_metadata('domain')
or get_domain_from_url(self.link))
return self.get_content_metadata("domain") or get_domain_from_url(self.link)
@property @property
def is_spoiler(self) -> bool: def is_spoiler(self) -> bool:
"""Return whether the topic is marked as a spoiler.""" """Return whether the topic is marked as a spoiler."""
return 'spoiler' in self.tags
return "spoiler" in self.tags
def get_content_metadata(self, key: str) -> Any: def get_content_metadata(self, key: str) -> Any:
"""Get a piece of content metadata "safely". """Get a piece of content metadata "safely".
@ -371,13 +343,13 @@ class Topic(DatabaseModel):
metadata_strings = [] metadata_strings = []
if self.is_text_type: if self.is_text_type:
word_count = self.get_content_metadata('word_count')
word_count = self.get_content_metadata("word_count")
if word_count is not None: if word_count is not None:
if word_count == 1: if word_count == 1:
metadata_strings.append('1 word')
metadata_strings.append("1 word")
else: else:
metadata_strings.append(f'{word_count} words')
metadata_strings.append(f"{word_count} words")
elif self.is_link_type: elif self.is_link_type:
metadata_strings.append(f'{self.link_domain}')
metadata_strings.append(f"{self.link_domain}")
return ', '.join(metadata_strings)
return ", ".join(metadata_strings)

45
tildes/tildes/models/topic/topic_query.py

@ -28,7 +28,7 @@ class TopicQuery(PaginatedQuery):
""" """
super().__init__(Topic, request) super().__init__(Topic, request)
def _attach_extra_data(self) -> 'TopicQuery':
def _attach_extra_data(self) -> "TopicQuery":
"""Attach the extra user data to the query.""" """Attach the extra user data to the query."""
if not self.request.user: if not self.request.user:
return self return self
@ -36,7 +36,7 @@ class TopicQuery(PaginatedQuery):
# pylint: disable=protected-access # pylint: disable=protected-access
return self._attach_vote_data()._attach_visit_data() return self._attach_vote_data()._attach_visit_data()
def _attach_vote_data(self) -> 'TopicQuery':
def _attach_vote_data(self) -> "TopicQuery":
"""Add a subquery to include whether the user has voted.""" """Add a subquery to include whether the user has voted."""
vote_subquery = ( vote_subquery = (
self.request.query(TopicVote) self.request.query(TopicVote)
@ -45,24 +45,25 @@ class TopicQuery(PaginatedQuery):
TopicVote.user == self.request.user, TopicVote.user == self.request.user,
) )
.exists() .exists()
.label('user_voted')
.label("user_voted")
) )
return self.add_columns(vote_subquery) return self.add_columns(vote_subquery)
def _attach_visit_data(self) -> 'TopicQuery':
def _attach_visit_data(self) -> "TopicQuery":
"""Join the data related to the user's last visit to the topic(s).""" """Join the data related to the user's last visit to the topic(s)."""
if self.request.user.track_comment_visits: if self.request.user.track_comment_visits:
query = self.outerjoin(TopicVisit, and_(
TopicVisit.topic_id == Topic.topic_id,
TopicVisit.user == self.request.user,
))
query = query.add_columns(
TopicVisit.visit_time, TopicVisit.num_comments)
query = self.outerjoin(
TopicVisit,
and_(
TopicVisit.topic_id == Topic.topic_id,
TopicVisit.user == self.request.user,
),
)
query = query.add_columns(TopicVisit.visit_time, TopicVisit.num_comments)
else: else:
# if the user has the feature disabled, just add literal NULLs # if the user has the feature disabled, just add literal NULLs
query = self.add_columns( query = self.add_columns(
null().label('visit_time'),
null().label('num_comments'),
null().label("visit_time"), null().label("num_comments")
) )
return query return query
@ -90,10 +91,8 @@ class TopicQuery(PaginatedQuery):
return topic return topic
def apply_sort_option( def apply_sort_option(
self,
sort: TopicSortOption,
desc: bool = True,
) -> 'TopicQuery':
self, sort: TopicSortOption, desc: bool = True
) -> "TopicQuery":
"""Apply a TopicSortOption sorting method (generative).""" """Apply a TopicSortOption sorting method (generative)."""
if sort == TopicSortOption.VOTES: if sort == TopicSortOption.VOTES:
self._sort_column = Topic.num_votes self._sort_column = Topic.num_votes
@ -108,18 +107,16 @@ class TopicQuery(PaginatedQuery):
return self return self
def inside_groups(self, groups: Sequence[Group]) -> 'TopicQuery':
def inside_groups(self, groups: Sequence[Group]) -> "TopicQuery":
"""Restrict the topics to inside specific groups (generative).""" """Restrict the topics to inside specific groups (generative)."""
query_paths = [group.path for group in groups] query_paths = [group.path for group in groups]
subgroup_subquery = (
self.request.db_session.query(Group.group_id)
.filter(Group.path.descendant_of(query_paths))
subgroup_subquery = self.request.db_session.query(Group.group_id).filter(
Group.path.descendant_of(query_paths)
) )
return self.filter(
Topic.group_id.in_(subgroup_subquery)) # type: ignore
return self.filter(Topic.group_id.in_(subgroup_subquery)) # type: ignore
def inside_time_period(self, period: SimpleHoursPeriod) -> 'TopicQuery':
def inside_time_period(self, period: SimpleHoursPeriod) -> "TopicQuery":
"""Restrict the topics to inside a time period (generative).""" """Restrict the topics to inside a time period (generative)."""
# if the time period is too long, this will crash by creating a # if the time period is too long, this will crash by creating a
# datetime outside the valid range - catch that and just don't filter # datetime outside the valid range - catch that and just don't filter
@ -131,7 +128,7 @@ class TopicQuery(PaginatedQuery):
return self.filter(Topic.created_time > start_time) return self.filter(Topic.created_time > start_time)
def has_tag(self, tag: Ltree) -> 'TopicQuery':
def has_tag(self, tag: Ltree) -> "TopicQuery":
"""Restrict the topics to ones with a specific tag (generative).""" """Restrict the topics to ones with a specific tag (generative)."""
# casting tag to string really shouldn't be necessary, but some kind of # casting tag to string really shouldn't be necessary, but some kind of
# strange interaction seems to be happening with the ArrayOfLtree # strange interaction seems to be happening with the ArrayOfLtree

26
tildes/tildes/models/topic/topic_visit.py

@ -28,28 +28,19 @@ class TopicVisit(DatabaseModel):
visits to the topic that were after it was posted. visits to the topic that were after it was posted.
""" """
__tablename__ = 'topic_visits'
__tablename__ = "topic_visits"
user_id: int = Column( user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
) )
topic_id: int = Column( topic_id: int = Column(
Integer,
ForeignKey('topics.topic_id'),
nullable=False,
primary_key=True,
)
visit_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
Integer, ForeignKey("topics.topic_id"), nullable=False, primary_key=True
) )
visit_time: datetime = Column(TIMESTAMP(timezone=True), nullable=False)
num_comments: int = Column(Integer, nullable=False) num_comments: int = Column(Integer, nullable=False)
user: User = relationship('User', innerjoin=True)
topic: Topic = relationship('Topic', innerjoin=True)
user: User = relationship("User", innerjoin=True)
topic: Topic = relationship("Topic", innerjoin=True)
@classmethod @classmethod
def generate_insert_statement(cls, user: User, topic: Topic) -> Insert: def generate_insert_statement(cls, user: User, topic: Topic) -> Insert:
@ -65,9 +56,6 @@ class TopicVisit(DatabaseModel):
) )
.on_conflict_do_update( .on_conflict_do_update(
constraint=cls.__table__.primary_key, constraint=cls.__table__.primary_key,
set_={
'visit_time': visit_time,
'num_comments': topic.num_comments,
},
set_={"visit_time": visit_time, "num_comments": topic.num_comments},
) )
) )

20
tildes/tildes/models/topic/topic_vote.py

@ -21,33 +21,27 @@ class TopicVote(DatabaseModel):
column for the relevant topic. column for the relevant topic.
""" """
__tablename__ = 'topic_votes'
__tablename__ = "topic_votes"
user_id: int = Column( user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
) )
topic_id: int = Column( topic_id: int = Column(
Integer,
ForeignKey('topics.topic_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("topics.topic_id"), nullable=False, primary_key=True
) )
created_time: datetime = Column( created_time: datetime = Column(
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
user: User = relationship('User', innerjoin=True)
topic: Topic = relationship('Topic', innerjoin=True)
user: User = relationship("User", innerjoin=True)
topic: Topic = relationship("Topic", innerjoin=True)
def __init__(self, user: User, topic: Topic) -> None: def __init__(self, user: User, topic: Topic) -> None:
"""Create a new vote on a topic.""" """Create a new vote on a topic."""
self.user = user self.user = user
self.topic = topic self.topic = topic
incr_counter('votes', target_type='topic')
incr_counter("votes", target_type="topic")

70
tildes/tildes/models/user/user.py

@ -49,7 +49,7 @@ class User(DatabaseModel):
schema_class = UserSchema schema_class = UserSchema
__tablename__ = 'users'
__tablename__ = "users"
user_id: int = Column(Integer, primary_key=True) user_id: int = Column(Integer, primary_key=True)
username: str = Column(CIText, nullable=False, unique=True) username: str = Column(CIText, nullable=False, unique=True)
@ -59,9 +59,8 @@ class User(DatabaseModel):
Column( Column(
Text, Text,
CheckConstraint( CheckConstraint(
'LENGTH(email_address_note) <= '
f'{EMAIL_ADDRESS_NOTE_MAX_LENGTH}',
name='email_address_note_length',
f"LENGTH(email_address_note) <= {EMAIL_ADDRESS_NOTE_MAX_LENGTH}",
name="email_address_note_length",
), ),
) )
) )
@ -69,44 +68,35 @@ class User(DatabaseModel):
TIMESTAMP(timezone=True), TIMESTAMP(timezone=True),
nullable=False, nullable=False,
index=True, index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
) )
num_unread_messages: int = Column(
Integer, nullable=False, server_default='0')
num_unread_notifications: int = Column(
Integer, nullable=False, server_default='0')
inviter_id: int = Column(Integer, ForeignKey('users.user_id'))
invite_codes_remaining: int = Column(
Integer, nullable=False, server_default='0')
track_comment_visits: bool = Column(
Boolean, nullable=False, server_default='false')
num_unread_messages: int = Column(Integer, nullable=False, server_default="0")
num_unread_notifications: int = Column(Integer, nullable=False, server_default="0")
inviter_id: int = Column(Integer, ForeignKey("users.user_id"))
invite_codes_remaining: int = Column(Integer, nullable=False, server_default="0")
track_comment_visits: bool = Column(Boolean, nullable=False, server_default="false")
auto_mark_notifications_read: bool = Column( auto_mark_notifications_read: bool = Column(
Boolean, nullable=False, server_default='false')
Boolean, nullable=False, server_default="false"
)
open_new_tab_external: bool = Column( open_new_tab_external: bool = Column(
Boolean, nullable=False, server_default='false')
Boolean, nullable=False, server_default="false"
)
open_new_tab_internal: bool = Column( open_new_tab_internal: bool = Column(
Boolean, nullable=False, server_default='false')
open_new_tab_text: bool = Column(
Boolean, nullable=False, server_default='false')
is_banned: bool = Column(Boolean, nullable=False, server_default='false')
is_admin: bool = Column(Boolean, nullable=False, server_default='false')
home_default_order: Optional[TopicSortOption] = Column(
ENUM(TopicSortOption))
Boolean, nullable=False, server_default="false"
)
open_new_tab_text: bool = Column(Boolean, nullable=False, server_default="false")
is_banned: bool = Column(Boolean, nullable=False, server_default="false")
is_admin: bool = Column(Boolean, nullable=False, server_default="false")
home_default_order: Optional[TopicSortOption] = Column(ENUM(TopicSortOption))
home_default_period: Optional[str] = Column(Text) home_default_period: Optional[str] = Column(Text)
_filtered_topic_tags: List[Ltree] = Column( _filtered_topic_tags: List[Ltree] = Column(
'filtered_topic_tags',
ArrayOfLtree,
nullable=False,
server_default='{}',
"filtered_topic_tags", ArrayOfLtree, nullable=False, server_default="{}"
) )
@hybrid_property @hybrid_property
def filtered_topic_tags(self) -> List[str]: def filtered_topic_tags(self) -> List[str]:
"""Return the user's list of filtered topic tags.""" """Return the user's list of filtered topic tags."""
return [
str(tag).replace('_', ' ')
for tag in self._filtered_topic_tags
]
return [str(tag).replace("_", " ") for tag in self._filtered_topic_tags]
@filtered_topic_tags.setter # type: ignore @filtered_topic_tags.setter # type: ignore
def filtered_topic_tags(self, new_tags: List[str]) -> None: def filtered_topic_tags(self, new_tags: List[str]) -> None:
@ -114,7 +104,7 @@ class User(DatabaseModel):
def __repr__(self) -> str: def __repr__(self) -> str:
"""Display the user's username and ID as its repr format.""" """Display the user's username and ID as its repr format."""
return f'<User {self.username} ({self.user_id})>'
return f"<User {self.username} ({self.user_id})>"
def __str__(self) -> str: def __str__(self) -> str:
"""Use the username for the string representation.""" """Use the username for the string representation."""
@ -131,12 +121,12 @@ class User(DatabaseModel):
# view: # view:
# - everyone can view all users # - everyone can view all users
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# message: # message:
# - anyone can message a user except themself # - anyone can message a user except themself
acl.append((Deny, self.user_id, 'message'))
acl.append((Allow, Authenticated, 'message'))
acl.append((Deny, self.user_id, "message"))
acl.append((Allow, Authenticated, "message"))
# grant the user all other permissions on themself # grant the user all other permissions on themself
acl.append((Allow, self.user_id, ALL_PERMISSIONS)) acl.append((Allow, self.user_id, ALL_PERMISSIONS))
@ -148,13 +138,13 @@ class User(DatabaseModel):
@property @property
def password(self) -> NoReturn: def password(self) -> NoReturn:
"""Return an error since reading the password isn't possible.""" """Return an error since reading the password isn't possible."""
raise AttributeError('Password is write-only')
raise AttributeError("Password is write-only")
@password.setter @password.setter
def password(self, value: str) -> None: def password(self, value: str) -> None:
# need to do manual validation since some password checks depend on # need to do manual validation since some password checks depend on
# checking the username at the same time (for similarity) # checking the username at the same time (for similarity)
self.schema.validate({'username': self.username, 'password': value})
self.schema.validate({"username": self.username, "password": value})
self.password_hash = hash_string(value) self.password_hash = hash_string(value)
@ -165,10 +155,10 @@ class User(DatabaseModel):
def change_password(self, old_password: str, new_password: str) -> None: def change_password(self, old_password: str, new_password: str) -> None:
"""Change the user's password from the old one to a new one.""" """Change the user's password from the old one to a new one."""
if not self.is_correct_password(old_password): if not self.is_correct_password(old_password):
raise ValueError('Old password was not correct')
raise ValueError("Old password was not correct")
if new_password == old_password: if new_password == old_password:
raise ValueError('New password is the same as old password')
raise ValueError("New password is the same as old password")
# disable mypy on this line because it doesn't handle setters correctly # disable mypy on this line because it doesn't handle setters correctly
self.password = new_password # type: ignore self.password = new_password # type: ignore
@ -176,7 +166,7 @@ class User(DatabaseModel):
@property @property
def email_address(self) -> NoReturn: def email_address(self) -> NoReturn:
"""Return an error since reading the email address isn't possible.""" """Return an error since reading the email address isn't possible."""
raise AttributeError('Email address is write-only')
raise AttributeError("Email address is write-only")
@email_address.setter @email_address.setter
def email_address(self, value: Optional[str]) -> None: def email_address(self, value: Optional[str]) -> None:

16
tildes/tildes/models/user/user_group_settings.py

@ -15,22 +15,16 @@ from tildes.models.user import User
class UserGroupSettings(DatabaseModel): class UserGroupSettings(DatabaseModel):
"""Model for a user's settings related to a specific group.""" """Model for a user's settings related to a specific group."""
__tablename__ = 'user_group_settings'
__tablename__ = "user_group_settings"
user_id: int = Column( user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
) )
group_id: int = Column( group_id: int = Column(
Integer,
ForeignKey('groups.group_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("groups.group_id"), nullable=False, primary_key=True
) )
default_order: Optional[TopicSortOption] = Column(ENUM(TopicSortOption)) default_order: Optional[TopicSortOption] = Column(ENUM(TopicSortOption))
default_period: Optional[str] = Column(Text) default_period: Optional[str] = Column(Text)
user: User = relationship('User', innerjoin=True)
group: Group = relationship('Group', innerjoin=True)
user: User = relationship("User", innerjoin=True)
group: Group = relationship("Group", innerjoin=True)

37
tildes/tildes/models/user/user_invite_code.py

@ -4,14 +4,7 @@ from datetime import datetime
import random import random
import string import string
from sqlalchemy import (
CheckConstraint,
Column,
ForeignKey,
Integer,
Text,
TIMESTAMP,
)
from sqlalchemy import CheckConstraint, Column, ForeignKey, Integer, Text, TIMESTAMP
from sqlalchemy.sql.expression import text from sqlalchemy.sql.expression import text
from tildes.models import DatabaseModel from tildes.models import DatabaseModel
@ -21,7 +14,7 @@ from .user import User
class UserInviteCode(DatabaseModel): class UserInviteCode(DatabaseModel):
"""Model for invite codes that allow new users to register.""" """Model for invite codes that allow new users to register."""
__tablename__ = 'user_invite_codes'
__tablename__ = "user_invite_codes"
# the character set to generate codes using # the character set to generate codes using
ALPHABET = string.ascii_uppercase + string.digits ALPHABET = string.ascii_uppercase + string.digits
@ -30,33 +23,25 @@ class UserInviteCode(DatabaseModel):
code: str = Column( code: str = Column(
Text, Text,
CheckConstraint(
f'LENGTH(code) <= {LENGTH}',
name='code_length',
),
CheckConstraint(f"LENGTH(code) <= {LENGTH}", name="code_length"),
primary_key=True, primary_key=True,
) )
user_id: int = Column( user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
) )
created_time: datetime = Column( created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=text('NOW()'),
TIMESTAMP(timezone=True), nullable=False, server_default=text("NOW()")
) )
invitee_id: int = Column(Integer, ForeignKey('users.user_id'))
invitee_id: int = Column(Integer, ForeignKey("users.user_id"))
def __str__(self) -> str: def __str__(self) -> str:
"""Format the code into a more easily readable version.""" """Format the code into a more easily readable version."""
formatted = ''
formatted = ""
for count, char in enumerate(self.code): for count, char in enumerate(self.code):
# add a dash every 5 chars # add a dash every 5 chars
if count > 0 and count % 5 == 0: if count > 0 and count % 5 == 0:
formatted += '-'
formatted += "-"
formatted += char.upper() formatted += char.upper()
@ -71,7 +56,7 @@ class UserInviteCode(DatabaseModel):
self.user_id = user.user_id self.user_id = user.user_id
code_chars = random.choices(self.ALPHABET, k=self.LENGTH) code_chars = random.choices(self.ALPHABET, k=self.LENGTH)
self.code = ''.join(code_chars)
self.code = "".join(code_chars)
@classmethod @classmethod
def prepare_code_for_lookup(cls, code: str) -> str: def prepare_code_for_lookup(cls, code: str) -> str:
@ -81,9 +66,9 @@ class UserInviteCode(DatabaseModel):
# remove any characters that aren't in the code alphabet (allows # remove any characters that aren't in the code alphabet (allows
# dashes, spaces, etc. to be used to make the codes more readable) # dashes, spaces, etc. to be used to make the codes more readable)
code = ''.join(letter for letter in code if letter in cls.ALPHABET)
code = "".join(letter for letter in code if letter in cls.ALPHABET)
if len(code) > cls.LENGTH: if len(code) > cls.LENGTH:
raise ValueError('Code is longer than the maximum length')
raise ValueError("Code is longer than the maximum length")
return code return code

6
tildes/tildes/resources/__init__.py

@ -14,11 +14,7 @@ def get_resource(request: Request, base_query: ModelQuery) -> DatabaseModel:
if not request.user: if not request.user:
raise HTTPForbidden raise HTTPForbidden
query = (
base_query
.lock_based_on_request_method()
.join_all_relationships()
)
query = base_query.lock_based_on_request_method().join_all_relationships()
if not request.is_safe_method: if not request.is_safe_method:
query = query.undefer_all_columns() query = query.undefer_all_columns()

16
tildes/tildes/resources/comment.py

@ -10,10 +10,7 @@ from tildes.resources import get_resource
from tildes.schemas.comment import CommentSchema from tildes.schemas.comment import CommentSchema
@use_kwargs(
CommentSchema(only=('comment_id36',)),
locations=('matchdict',),
)
@use_kwargs(CommentSchema(only=("comment_id36",)), locations=("matchdict",))
def comment_by_id36(request: Request, comment_id36: str) -> Comment: def comment_by_id36(request: Request, comment_id36: str) -> Comment:
"""Get a comment specified by {comment_id36} in the route (or 404).""" """Get a comment specified by {comment_id36} in the route (or 404)."""
query = ( query = (
@ -25,13 +22,9 @@ def comment_by_id36(request: Request, comment_id36: str) -> Comment:
return get_resource(request, query) return get_resource(request, query)
@use_kwargs(
CommentSchema(only=('comment_id36',)),
locations=('matchdict',),
)
@use_kwargs(CommentSchema(only=("comment_id36",)), locations=("matchdict",))
def notification_by_comment_id36( def notification_by_comment_id36(
request: Request,
comment_id36: str,
request: Request, comment_id36: str
) -> CommentNotification: ) -> CommentNotification:
"""Get a comment notification specified by {comment_id36} in the route. """Get a comment notification specified by {comment_id36} in the route.
@ -43,8 +36,7 @@ def notification_by_comment_id36(
comment_id = id36_to_id(comment_id36) comment_id = id36_to_id(comment_id36)
query = request.query(CommentNotification).filter_by( query = request.query(CommentNotification).filter_by(
user=request.user,
comment_id=comment_id,
user=request.user, comment_id=comment_id
) )
return get_resource(request, query) return get_resource(request, query)

11
tildes/tildes/resources/group.py

@ -11,8 +11,8 @@ from tildes.schemas.group import GroupSchema
@use_kwargs( @use_kwargs(
GroupSchema(only=('path',), context={'fix_path_capitalization': True}),
locations=('matchdict',),
GroupSchema(only=("path",), context={"fix_path_capitalization": True}),
locations=("matchdict",),
) )
def group_by_path(request: Request, path: str) -> Group: def group_by_path(request: Request, path: str) -> Group:
"""Get a group specified by {group_path} in the route (or 404).""" """Get a group specified by {group_path} in the route (or 404)."""
@ -20,10 +20,9 @@ def group_by_path(request: Request, path: str) -> Group:
# 301 redirect to the resulting group path. This will happen in cases like # 301 redirect to the resulting group path. This will happen in cases like
# the original url including capital letters in the group path, where we # the original url including capital letters in the group path, where we
# want to redirect to the proper all-lowercase path instead. # want to redirect to the proper all-lowercase path instead.
if path != request.matchdict['group_path']:
request.matchdict['group_path'] = path
proper_url = request.route_url(
request.matched_route.name, **request.matchdict)
if path != request.matchdict["group_path"]:
request.matchdict["group_path"] = path
proper_url = request.route_url(request.matched_route.name, **request.matchdict)
raise HTTPMovedPermanently(location=proper_url) raise HTTPMovedPermanently(location=proper_url)

11
tildes/tildes/resources/message.py

@ -10,17 +10,14 @@ from tildes.schemas.message import MessageConversationSchema
@use_kwargs( @use_kwargs(
MessageConversationSchema(only=('conversation_id36',)),
locations=('matchdict',),
MessageConversationSchema(only=("conversation_id36",)), locations=("matchdict",)
) )
def message_conversation_by_id36( def message_conversation_by_id36(
request: Request,
conversation_id36: str,
request: Request, conversation_id36: str
) -> MessageConversation: ) -> MessageConversation:
"""Get a conversation specified by {conversation_id36} in the route.""" """Get a conversation specified by {conversation_id36} in the route."""
query = (
request.query(MessageConversation)
.filter_by(conversation_id=id36_to_id(conversation_id36))
query = request.query(MessageConversation).filter_by(
conversation_id=id36_to_id(conversation_id36)
) )
return get_resource(request, query) return get_resource(request, query)

9
tildes/tildes/resources/topic.py

@ -10,10 +10,7 @@ from tildes.resources import get_resource
from tildes.schemas.topic import TopicSchema from tildes.schemas.topic import TopicSchema
@use_kwargs(
TopicSchema(only=('topic_id36',)),
locations=('matchdict',),
)
@use_kwargs(TopicSchema(only=("topic_id36",)), locations=("matchdict",))
def topic_by_id36(request: Request, topic_id36: str) -> Topic: def topic_by_id36(request: Request, topic_id36: str) -> Topic:
"""Get a topic specified by {topic_id36} in the route (or 404).""" """Get a topic specified by {topic_id36} in the route (or 404)."""
query = ( query = (
@ -27,8 +24,8 @@ def topic_by_id36(request: Request, topic_id36: str) -> Topic:
# if there's also a group specified in the route, check that it's the same # if there's also a group specified in the route, check that it's the same
# group as the topic was posted in, otherwise redirect to correct group # group as the topic was posted in, otherwise redirect to correct group
if 'group_path' in request.matchdict:
path_from_route = request.matchdict['group_path'].lower()
if "group_path" in request.matchdict:
path_from_route = request.matchdict["group_path"].lower()
if path_from_route != topic.group.path: if path_from_route != topic.group.path:
raise HTTPFound(topic.permalink) raise HTTPFound(topic.permalink)

5
tildes/tildes/resources/user.py

@ -8,10 +8,7 @@ from tildes.resources import get_resource
from tildes.schemas.user import UserSchema from tildes.schemas.user import UserSchema
@use_kwargs(
UserSchema(only=('username',)),
locations=('matchdict',),
)
@use_kwargs(UserSchema(only=("username",)), locations=("matchdict",))
def user_by_username(request: Request, username: str) -> User: def user_by_username(request: Request, username: str) -> User:
"""Get a user specified by {username} in the route or 404 if not found.""" """Get a user specified by {username} in the route or 404 if not found."""
query = request.query(User).filter(User.username == username) query = request.query(User).filter(User.username == username)

160
tildes/tildes/routes.py

@ -6,10 +6,7 @@ from pyramid.config import Configurator
from pyramid.request import Request from pyramid.request import Request
from pyramid.security import Allow, Authenticated from pyramid.security import Allow, Authenticated
from tildes.resources.comment import (
comment_by_id36,
notification_by_comment_id36,
)
from tildes.resources.comment import comment_by_id36, notification_by_comment_id36
from tildes.resources.group import group_by_path from tildes.resources.group import group_by_path
from tildes.resources.message import message_conversation_by_id36 from tildes.resources.message import message_conversation_by_id36
from tildes.resources.topic import topic_by_id36 from tildes.resources.topic import topic_by_id36
@ -18,172 +15,133 @@ from tildes.resources.user import user_by_username
def includeme(config: Configurator) -> None: def includeme(config: Configurator) -> None:
"""Set up application routes.""" """Set up application routes."""
config.add_route('home', '/')
config.add_route("home", "/")
config.add_route('groups', '/groups')
config.add_route("groups", "/groups")
config.add_route('login', '/login')
config.add_route('logout', '/logout', factory=LoggedInFactory)
config.add_route("login", "/login")
config.add_route("logout", "/logout", factory=LoggedInFactory)
config.add_route('register', '/register')
config.add_route("register", "/register")
config.add_route('group', '/~{group_path}', factory=group_by_path)
config.add_route(
'new_topic', '/~{group_path}/new_topic', factory=group_by_path)
config.add_route("group", "/~{group_path}", factory=group_by_path)
config.add_route("new_topic", "/~{group_path}/new_topic", factory=group_by_path)
config.add_route(
'group_topics', '/~{group_path}/topics', factory=group_by_path)
config.add_route("group_topics", "/~{group_path}/topics", factory=group_by_path)
config.add_route( config.add_route(
'topic', '/~{group_path}/{topic_id36}*title', factory=topic_by_id36)
"topic", "/~{group_path}/{topic_id36}*title", factory=topic_by_id36
)
config.add_route('user', '/user/{username}', factory=user_by_username)
config.add_route("user", "/user/{username}", factory=user_by_username)
config.add_route("notifications", "/notifications", factory=LoggedInFactory)
config.add_route( config.add_route(
'notifications', '/notifications', factory=LoggedInFactory)
config.add_route(
'notifications_unread',
'/notifications/unread',
factory=LoggedInFactory,
"notifications_unread", "/notifications/unread", factory=LoggedInFactory
) )
config.add_route('messages', '/messages', factory=LoggedInFactory)
config.add_route(
'messages_sent', '/messages/sent', factory=LoggedInFactory)
config.add_route("messages", "/messages", factory=LoggedInFactory)
config.add_route("messages_sent", "/messages/sent", factory=LoggedInFactory)
config.add_route("messages_unread", "/messages/unread", factory=LoggedInFactory)
config.add_route( config.add_route(
'messages_unread', '/messages/unread', factory=LoggedInFactory)
config.add_route(
'message_conversation',
'/messages/conversations/{conversation_id36}',
"message_conversation",
"/messages/conversations/{conversation_id36}",
factory=message_conversation_by_id36, factory=message_conversation_by_id36,
) )
config.add_route( config.add_route(
'new_message',
'/user/{username}/new_message',
factory=user_by_username,
"new_message", "/user/{username}/new_message", factory=user_by_username
) )
config.add_route( config.add_route(
'user_messages',
'/user/{username}/messages',
factory=user_by_username,
"user_messages", "/user/{username}/messages", factory=user_by_username
) )
config.add_route('settings', '/settings', factory=LoggedInFactory)
config.add_route("settings", "/settings", factory=LoggedInFactory)
config.add_route( config.add_route(
'settings_account_recovery',
'/settings/account_recovery',
"settings_account_recovery",
"/settings/account_recovery",
factory=LoggedInFactory, factory=LoggedInFactory,
) )
config.add_route( config.add_route(
'settings_comment_visits',
'/settings/comment_visits',
factory=LoggedInFactory,
"settings_comment_visits", "/settings/comment_visits", factory=LoggedInFactory
) )
config.add_route("settings_filters", "/settings/filters", factory=LoggedInFactory)
config.add_route( config.add_route(
'settings_filters', '/settings/filters', factory=LoggedInFactory)
config.add_route(
'settings_password_change',
'/settings/password_change',
factory=LoggedInFactory,
"settings_password_change", "/settings/password_change", factory=LoggedInFactory
) )
config.add_route('invite', '/invite', factory=LoggedInFactory)
config.add_route("invite", "/invite", factory=LoggedInFactory)
# Route to expose metrics to Prometheus # Route to expose metrics to Prometheus
config.add_route('metrics', '/metrics')
config.add_route("metrics", "/metrics")
# Route for Stripe donation processing page (POSTed to from docs site) # Route for Stripe donation processing page (POSTed to from docs site)
config.add_route('donate_stripe', '/donate_stripe')
config.add_route("donate_stripe", "/donate_stripe")
add_intercooler_routes(config) add_intercooler_routes(config)
def add_intercooler_routes(config: Configurator) -> None: def add_intercooler_routes(config: Configurator) -> None:
"""Set up all routes for the (internal-use) Intercooler API endpoints.""" """Set up all routes for the (internal-use) Intercooler API endpoints."""
def add_ic_route(name: str, path: str, **kwargs: Any) -> None: def add_ic_route(name: str, path: str, **kwargs: Any) -> None:
"""Add route with intercooler name prefix, base path, header check.""" """Add route with intercooler name prefix, base path, header check."""
name = 'ic_' + name
path = '/api/web' + path
config.add_route(
name,
path,
header='X-IC-Request:true',
**kwargs)
name = "ic_" + name
path = "/api/web" + path
config.add_route(name, path, header="X-IC-Request:true", **kwargs)
add_ic_route( add_ic_route(
'group_subscribe',
'/group/{group_path}/subscribe',
factory=group_by_path,
"group_subscribe", "/group/{group_path}/subscribe", factory=group_by_path
) )
add_ic_route( add_ic_route(
'group_user_settings',
'/group/{group_path}/user_settings',
"group_user_settings",
"/group/{group_path}/user_settings",
factory=group_by_path, factory=group_by_path,
) )
add_ic_route('topic', '/topics/{topic_id36}', factory=topic_by_id36)
add_ic_route("topic", "/topics/{topic_id36}", factory=topic_by_id36)
add_ic_route( add_ic_route(
'topic_comments',
'/topics/{topic_id36}/comments',
factory=topic_by_id36,
)
add_ic_route(
'topic_group', '/topics/{topic_id36}/group', factory=topic_by_id36)
add_ic_route(
'topic_lock', '/topics/{topic_id36}/lock', factory=topic_by_id36)
add_ic_route(
'topic_title', '/topics/{topic_id36}/title', factory=topic_by_id36)
add_ic_route(
'topic_vote', '/topics/{topic_id36}/vote', factory=topic_by_id36)
add_ic_route(
'topic_tags',
'/topics/{topic_id36}/tags',
factory=topic_by_id36,
"topic_comments", "/topics/{topic_id36}/comments", factory=topic_by_id36
) )
add_ic_route("topic_group", "/topics/{topic_id36}/group", factory=topic_by_id36)
add_ic_route("topic_lock", "/topics/{topic_id36}/lock", factory=topic_by_id36)
add_ic_route("topic_title", "/topics/{topic_id36}/title", factory=topic_by_id36)
add_ic_route("topic_vote", "/topics/{topic_id36}/vote", factory=topic_by_id36)
add_ic_route("topic_tags", "/topics/{topic_id36}/tags", factory=topic_by_id36)
add_ic_route("comment", "/comments/{comment_id36}", factory=comment_by_id36)
add_ic_route( add_ic_route(
'comment', '/comments/{comment_id36}', factory=comment_by_id36)
add_ic_route(
'comment_replies',
'/comments/{comment_id36}/replies',
factory=comment_by_id36,
"comment_replies", "/comments/{comment_id36}/replies", factory=comment_by_id36
) )
add_ic_route( add_ic_route(
'comment_vote',
'/comments/{comment_id36}/vote',
factory=comment_by_id36,
"comment_vote", "/comments/{comment_id36}/vote", factory=comment_by_id36
) )
add_ic_route( add_ic_route(
'comment_tag',
'/comments/{comment_id36}/tags/{name}',
factory=comment_by_id36,
"comment_tag", "/comments/{comment_id36}/tags/{name}", factory=comment_by_id36
) )
add_ic_route( add_ic_route(
'comment_mark_read',
'/comments/{comment_id36}/mark_read',
"comment_mark_read",
"/comments/{comment_id36}/mark_read",
factory=notification_by_comment_id36, factory=notification_by_comment_id36,
) )
add_ic_route( add_ic_route(
'message_conversation_replies',
'/messages/conversations/{conversation_id36}/replies',
"message_conversation_replies",
"/messages/conversations/{conversation_id36}/replies",
factory=message_conversation_by_id36, factory=message_conversation_by_id36,
) )
add_ic_route('user', '/user/{username}', factory=user_by_username)
add_ic_route("user", "/user/{username}", factory=user_by_username)
add_ic_route( add_ic_route(
'user_filtered_topic_tags',
'/user/{username}/filtered_topic_tags',
"user_filtered_topic_tags",
"/user/{username}/filtered_topic_tags",
factory=user_by_username, factory=user_by_username,
) )
add_ic_route( add_ic_route(
'user_invite_code',
'/user/{username}/invite_code',
factory=user_by_username,
"user_invite_code", "/user/{username}/invite_code", factory=user_by_username
) )
add_ic_route( add_ic_route(
'user_default_listing_options',
'/user/{username}/default_listing_options',
"user_default_listing_options",
"/user/{username}/default_listing_options",
factory=user_by_username, factory=user_by_username,
) )
@ -196,7 +154,7 @@ class LoggedInFactory:
checking access to a specific resource (such as a topic or message). checking access to a specific resource (such as a topic or message).
""" """
__acl__ = ((Allow, Authenticated, 'view'),)
__acl__ = ((Allow, Authenticated, "view"),)
def __init__(self, request: Request) -> None: def __init__(self, request: Request) -> None:
"""Initialize - no-op, but needs to take the request as an arg.""" """Initialize - no-op, but needs to take the request as an arg."""

45
tildes/tildes/schemas/fields.py

@ -16,12 +16,7 @@ from tildes.lib.string import simplify_string
class Enum(Field): class Enum(Field):
"""Field for a native Python Enum (or subclasses).""" """Field for a native Python Enum (or subclasses)."""
def __init__(
self,
enum_class: Type = None,
*args: Any,
**kwargs: Any,
) -> None:
def __init__(self, enum_class: Type = None, *args: Any, **kwargs: Any) -> None:
"""Initialize the field with an optional enum class.""" """Initialize the field with an optional enum class."""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._enum_class = enum_class self._enum_class = enum_class
@ -33,12 +28,12 @@ class Enum(Field):
def _deserialize(self, value: str, attr: str, data: dict) -> enum.Enum: def _deserialize(self, value: str, attr: str, data: dict) -> enum.Enum:
"""Deserialize a string to the enum member with that name.""" """Deserialize a string to the enum member with that name."""
if not self._enum_class: if not self._enum_class:
raise ValidationError('Cannot deserialize with no enum class.')
raise ValidationError("Cannot deserialize with no enum class.")
try: try:
return self._enum_class[value.upper()] return self._enum_class[value.upper()]
except KeyError: except KeyError:
raise ValidationError('Invalid enum member')
raise ValidationError("Invalid enum member")
class ID36(String): class ID36(String):
@ -56,25 +51,19 @@ class ShortTimePeriod(Field):
""" """
def _deserialize( def _deserialize(
self,
value: str,
attr: str,
data: dict,
self, value: str, attr: str, data: dict
) -> Optional[SimpleHoursPeriod]: ) -> Optional[SimpleHoursPeriod]:
"""Deserialize to a SimpleHoursPeriod object.""" """Deserialize to a SimpleHoursPeriod object."""
if value == 'all':
if value == "all":
return None return None
try: try:
return SimpleHoursPeriod.from_short_form(value) return SimpleHoursPeriod.from_short_form(value)
except ValueError: except ValueError:
raise ValidationError('Invalid time period')
raise ValidationError("Invalid time period")
def _serialize( def _serialize(
self,
value: Optional[SimpleHoursPeriod],
attr: str,
obj: object,
self, value: Optional[SimpleHoursPeriod], attr: str, obj: object
) -> Optional[str]: ) -> Optional[str]:
"""Serialize the value to the "short form" string.""" """Serialize the value to the "short form" string."""
if not value: if not value:
@ -100,11 +89,11 @@ class Markdown(Field):
super()._validate(value) super()._validate(value)
if value.isspace(): if value.isspace():
raise ValidationError('Cannot be entirely whitespace.')
raise ValidationError("Cannot be entirely whitespace.")
def _deserialize(self, value: str, attr: str, data: dict) -> str: def _deserialize(self, value: str, attr: str, data: dict) -> str:
"""Deserialize the string, removing carriage returns in the process.""" """Deserialize the string, removing carriage returns in the process."""
value = value.replace('\r', '')
value = value.replace("\r", "")
return value return value
@ -145,23 +134,13 @@ class SimpleString(Field):
class Ltree(Field): class Ltree(Field):
"""Field for postgresql ltree type.""" """Field for postgresql ltree type."""
def _serialize(
self,
value: sqlalchemy_utils.Ltree,
attr: str,
obj: object,
) -> str:
def _serialize(self, value: sqlalchemy_utils.Ltree, attr: str, obj: object) -> str:
"""Serialize the Ltree value - use the (string) path.""" """Serialize the Ltree value - use the (string) path."""
return value.path return value.path
def _deserialize(
self,
value: str,
attr: str,
data: dict,
) -> sqlalchemy_utils.Ltree:
def _deserialize(self, value: str, attr: str, data: dict) -> sqlalchemy_utils.Ltree:
"""Deserialize a string path to an Ltree object.""" """Deserialize a string path to an Ltree object."""
try: try:
return sqlalchemy_utils.Ltree(value) return sqlalchemy_utils.Ltree(value)
except (TypeError, ValueError): except (TypeError, ValueError):
raise ValidationError('Invalid path')
raise ValidationError("Invalid path")

27
tildes/tildes/schemas/group.py

@ -15,11 +15,13 @@ from tildes.schemas.fields import Ltree, SimpleString
# - must end with a number or lowercase letter # - must end with a number or lowercase letter
# - the middle can contain numbers, lowercase letters, and underscores # - the middle can contain numbers, lowercase letters, and underscores
# Note: this regex does not contain any length checks, must be done separately # Note: this regex does not contain any length checks, must be done separately
# fmt: off
GROUP_PATH_ELEMENT_VALID_REGEX = re.compile( GROUP_PATH_ELEMENT_VALID_REGEX = re.compile(
'^[a-z0-9]' # start
'([a-z0-9_]*' # middle
'[a-z0-9])?$', # end
"^[a-z0-9]" # start
"([a-z0-9_]*" # middle
"[a-z0-9])?$" # end
) )
# fmt: on
SHORT_DESCRIPTION_MAX_LENGTH = 200 SHORT_DESCRIPTION_MAX_LENGTH = 200
@ -27,21 +29,20 @@ SHORT_DESCRIPTION_MAX_LENGTH = 200
class GroupSchema(Schema): class GroupSchema(Schema):
"""Marshmallow schema for groups.""" """Marshmallow schema for groups."""
path = Ltree(required=True, load_from='group_path')
path = Ltree(required=True, load_from="group_path")
created_time = DateTime(dump_only=True) created_time = DateTime(dump_only=True)
short_description = SimpleString( short_description = SimpleString(
max_length=SHORT_DESCRIPTION_MAX_LENGTH,
allow_none=True,
max_length=SHORT_DESCRIPTION_MAX_LENGTH, allow_none=True
) )
@pre_load @pre_load
def prepare_path(self, data: dict) -> dict: def prepare_path(self, data: dict) -> dict:
"""Prepare the path value before it's validated.""" """Prepare the path value before it's validated."""
if not self.context.get('fix_path_capitalization'):
if not self.context.get("fix_path_capitalization"):
return data return data
# path can also be loaded from group_path, so we need to check both # path can also be loaded from group_path, so we need to check both
keys = ('path', 'group_path')
keys = ("path", "group_path")
for key in keys: for key in keys:
if key in data and isinstance(data[key], str): if key in data and isinstance(data[key], str):
@ -49,17 +50,17 @@ class GroupSchema(Schema):
return data return data
@validates('path')
@validates("path")
def validate_path(self, value: sqlalchemy_utils.Ltree) -> None: def validate_path(self, value: sqlalchemy_utils.Ltree) -> None:
"""Validate the path field, raising an error if an issue exists.""" """Validate the path field, raising an error if an issue exists."""
# check each element for length and against validity regex # check each element for length and against validity regex
path_elements = value.path.split('.')
path_elements = value.path.split(".")
for element in path_elements: for element in path_elements:
if len(element) > 256: if len(element) > 256:
raise ValidationError('Path element %s is too long' % element)
raise ValidationError("Path element %s is too long" % element)
if not GROUP_PATH_ELEMENT_VALID_REGEX.match(element): if not GROUP_PATH_ELEMENT_VALID_REGEX.match(element):
raise ValidationError('Path element %s is invalid' % element)
raise ValidationError("Path element %s is invalid" % element)
class Meta: class Meta:
"""Always use strict checking so error handlers are invoked.""" """Always use strict checking so error handlers are invoked."""
@ -71,7 +72,7 @@ def is_valid_group_path(path: str) -> bool:
"""Return whether the group path is valid or not.""" """Return whether the group path is valid or not."""
schema = GroupSchema(partial=True) schema = GroupSchema(partial=True)
try: try:
schema.validate({'path': path})
schema.validate({"path": path})
except ValidationError: except ValidationError:
return False return False

58
tildes/tildes/schemas/topic.py

@ -4,13 +4,7 @@ import re
import typing import typing
from urllib.parse import urlparse from urllib.parse import urlparse
from marshmallow import (
pre_load,
Schema,
validates,
validates_schema,
ValidationError,
)
from marshmallow import pre_load, Schema, validates, validates_schema, ValidationError
from marshmallow.fields import DateTime, List, Nested, String, URL from marshmallow.fields import DateTime, List, Nested, String, URL
import sqlalchemy_utils import sqlalchemy_utils
@ -29,7 +23,7 @@ class TopicSchema(Schema):
topic_type = Enum(dump_only=True) topic_type = Enum(dump_only=True)
markdown = Markdown(allow_none=True) markdown = Markdown(allow_none=True)
rendered_html = String(dump_only=True) rendered_html = String(dump_only=True)
link = URL(schemes={'http', 'https'}, allow_none=True)
link = URL(schemes={"http", "https"}, allow_none=True)
created_time = DateTime(dump_only=True) created_time = DateTime(dump_only=True)
tags = List(Ltree()) tags = List(Ltree())
@ -39,22 +33,22 @@ class TopicSchema(Schema):
@pre_load @pre_load
def prepare_tags(self, data: dict) -> dict: def prepare_tags(self, data: dict) -> dict:
"""Prepare the tags before they're validated.""" """Prepare the tags before they're validated."""
if 'tags' not in data:
if "tags" not in data:
return data return data
tags: typing.List[str] = [] tags: typing.List[str] = []
for tag in data['tags']:
for tag in data["tags"]:
tag = tag.lower() tag = tag.lower()
# replace spaces with underscores # replace spaces with underscores
tag = tag.replace(' ', '_')
tag = tag.replace(" ", "_")
# remove any consecutive underscores # remove any consecutive underscores
tag = re.sub('_{2,}', '_', tag)
tag = re.sub("_{2,}", "_", tag)
# remove any leading/trailing underscores # remove any leading/trailing underscores
tag = tag.strip('_')
tag = tag.strip("_")
# drop any empty tags # drop any empty tags
if not tag or tag.isspace(): if not tag or tag.isspace():
@ -66,15 +60,12 @@ class TopicSchema(Schema):
tags.append(tag) tags.append(tag)
data['tags'] = tags
data["tags"] = tags
return data return data
@validates('tags')
def validate_tags(
self,
value: typing.List[sqlalchemy_utils.Ltree],
) -> None:
@validates("tags")
def validate_tags(self, value: typing.List[sqlalchemy_utils.Ltree]) -> None:
"""Validate the tags field, raising an error if an issue exists. """Validate the tags field, raising an error if an issue exists.
Note that tags are validated by ensuring that each tag would be a valid Note that tags are validated by ensuring that each tag would be a valid
@ -86,52 +77,51 @@ class TopicSchema(Schema):
group_schema = GroupSchema(partial=True) group_schema = GroupSchema(partial=True)
for tag in value: for tag in value:
try: try:
group_schema.validate({'path': tag})
group_schema.validate({"path": tag})
except ValidationError: except ValidationError:
raise ValidationError('Tag %s is invalid' % tag)
raise ValidationError("Tag %s is invalid" % tag)
@pre_load @pre_load
def prepare_markdown(self, data: dict) -> dict: def prepare_markdown(self, data: dict) -> dict:
"""Prepare the markdown value before it's validated.""" """Prepare the markdown value before it's validated."""
if 'markdown' not in data:
if "markdown" not in data:
return data return data
# if the value is empty, convert it to None # if the value is empty, convert it to None
if not data['markdown'] or data['markdown'].isspace():
data['markdown'] = None
if not data["markdown"] or data["markdown"].isspace():
data["markdown"] = None
return data return data
@pre_load @pre_load
def prepare_link(self, data: dict) -> dict: def prepare_link(self, data: dict) -> dict:
"""Prepare the link value before it's validated.""" """Prepare the link value before it's validated."""
if 'link' not in data:
if "link" not in data:
return data return data
# if the value is empty, convert it to None # if the value is empty, convert it to None
if not data['link'] or data['link'].isspace():
data['link'] = None
if not data["link"] or data["link"].isspace():
data["link"] = None
return data return data
# prepend http:// to the link if it doesn't have a scheme # prepend http:// to the link if it doesn't have a scheme
parsed = urlparse(data['link'])
parsed = urlparse(data["link"])
if not parsed.scheme: if not parsed.scheme:
data['link'] = 'http://' + data['link']
data["link"] = "http://" + data["link"]
return data return data
@validates_schema @validates_schema
def link_or_markdown(self, data: dict) -> None: def link_or_markdown(self, data: dict) -> None:
"""Fail validation unless at least one of link or markdown were set.""" """Fail validation unless at least one of link or markdown were set."""
if 'link' not in data and 'markdown' not in data:
if "link" not in data and "markdown" not in data:
return return
link = data.get('link')
markdown = data.get('markdown')
link = data.get("link")
markdown = data.get("markdown")
if not (markdown or link): if not (markdown or link):
raise ValidationError(
'Topics must have either markdown or a link.')
raise ValidationError("Topics must have either markdown or a link.")
class Meta: class Meta:
"""Always use strict checking so error handlers are invoked.""" """Always use strict checking so error handlers are invoked."""

13
tildes/tildes/schemas/topic_listing.py

@ -17,25 +17,22 @@ class TopicListingSchema(Schema):
period = ShortTimePeriod(allow_none=True) period = ShortTimePeriod(allow_none=True)
after = ID36() after = ID36()
before = ID36() before = ID36()
per_page = Integer(
validate=Range(min=1, max=100),
missing=DEFAULT_TOPICS_PER_PAGE,
)
rank_start = Integer(load_from='n', validate=Range(min=1), missing=None)
per_page = Integer(validate=Range(min=1, max=100), missing=DEFAULT_TOPICS_PER_PAGE)
rank_start = Integer(load_from="n", validate=Range(min=1), missing=None)
tag = Ltree(missing=None) tag = Ltree(missing=None)
unfiltered = Boolean(missing=False) unfiltered = Boolean(missing=False)
@validates_schema @validates_schema
def either_after_or_before(self, data: dict) -> None: def either_after_or_before(self, data: dict) -> None:
"""Fail validation if both after and before were specified.""" """Fail validation if both after and before were specified."""
if data.get('after') and data.get('before'):
if data.get("after") and data.get("before"):
raise ValidationError("Can't specify both after and before.") raise ValidationError("Can't specify both after and before.")
@pre_load @pre_load
def reset_rank_start_on_first_page(self, data: dict) -> dict: def reset_rank_start_on_first_page(self, data: dict) -> dict:
"""Reset rank_start to 1 if this is a first page (no before/after).""" """Reset rank_start to 1 if this is a first page (no before/after)."""
if not (data.get('before') or data.get('after')):
data['rank_start'] = 1
if not (data.get("before") or data.get("after")):
data["rank_start"] = 1
return data return data

46
tildes/tildes/schemas/user.py

@ -2,13 +2,7 @@
import re import re
from marshmallow import (
post_dump,
pre_load,
Schema,
validates,
validates_schema,
)
from marshmallow import post_dump, pre_load, Schema, validates, validates_schema
from marshmallow.exceptions import ValidationError from marshmallow.exceptions import ValidationError
from marshmallow.fields import Boolean, DateTime, Email, String from marshmallow.fields import Boolean, DateTime, Email, String
from marshmallow.validate import Length, Regexp from marshmallow.validate import Length, Regexp
@ -26,11 +20,14 @@ USERNAME_MAX_LENGTH = 20
# more than one underscore/dash consecutively (this includes both "_-" and # more than one underscore/dash consecutively (this includes both "_-" and
# "-_" sequences being invalid) # "-_" sequences being invalid)
# Note: this regex does not contain any length checks, must be done separately # Note: this regex does not contain any length checks, must be done separately
# fmt: off
USERNAME_VALID_REGEX = re.compile( USERNAME_VALID_REGEX = re.compile(
"^[a-z0-9]" # start "^[a-z0-9]" # start
"([a-z0-9]|[_-](?![_-]))*" # middle "([a-z0-9]|[_-](?![_-]))*" # middle
"[a-z0-9]$", # end "[a-z0-9]$", # end
re.IGNORECASE)
re.IGNORECASE,
)
# fmt: on
PASSWORD_MIN_LENGTH = 8 PASSWORD_MIN_LENGTH = 8
@ -48,29 +45,26 @@ class UserSchema(Schema):
required=True, required=True,
) )
password = String( password = String(
validate=Length(min=PASSWORD_MIN_LENGTH),
required=True,
load_only=True,
validate=Length(min=PASSWORD_MIN_LENGTH), required=True, load_only=True
) )
email_address = Email(allow_none=True, load_only=True) email_address = Email(allow_none=True, load_only=True)
email_address_note = String(
validate=Length(max=EMAIL_ADDRESS_NOTE_MAX_LENGTH))
email_address_note = String(validate=Length(max=EMAIL_ADDRESS_NOTE_MAX_LENGTH))
created_time = DateTime(dump_only=True) created_time = DateTime(dump_only=True)
track_comment_visits = Boolean() track_comment_visits = Boolean()
@post_dump @post_dump
def anonymize_username(self, data: dict) -> dict: def anonymize_username(self, data: dict) -> dict:
"""Hide the username if the dumping context specifies to do so.""" """Hide the username if the dumping context specifies to do so."""
if 'username' in data and self.context.get('hide_username'):
data['username'] = '<unknown>'
if "username" in data and self.context.get("hide_username"):
data["username"] = "<unknown>"
return data return data
@validates_schema @validates_schema
def username_pass_not_substrings(self, data: dict) -> None: def username_pass_not_substrings(self, data: dict) -> None:
"""Ensure the username isn't in the password and vice versa.""" """Ensure the username isn't in the password and vice versa."""
username = data.get('username')
password = data.get('password')
username = data.get("username")
password = data.get("password")
if not (username and password): if not (username and password):
return return
@ -78,32 +72,32 @@ class UserSchema(Schema):
password = password.lower() password = password.lower()
if username in password: if username in password:
raise ValidationError('Password cannot contain username')
raise ValidationError("Password cannot contain username")
if password in username: if password in username:
raise ValidationError('Username cannot contain password')
raise ValidationError("Username cannot contain password")
@validates('password')
@validates("password")
def password_not_breached(self, value: str) -> None: def password_not_breached(self, value: str) -> None:
"""Validate that the password is not in the breached-passwords list. """Validate that the password is not in the breached-passwords list.
Requires check_breached_passwords be True in the schema's context. Requires check_breached_passwords be True in the schema's context.
""" """
if not self.context.get('check_breached_passwords'):
if not self.context.get("check_breached_passwords"):
return return
if is_breached_password(value): if is_breached_password(value):
raise ValidationError('That password exists in a data breach')
raise ValidationError("That password exists in a data breach")
@pre_load @pre_load
def prepare_email_address(self, data: dict) -> dict: def prepare_email_address(self, data: dict) -> dict:
"""Prepare the email address value before it's validated.""" """Prepare the email address value before it's validated."""
if 'email_address' not in data:
if "email_address" not in data:
return data return data
# if the value is empty, convert it to None # if the value is empty, convert it to None
if not data['email_address'] or data['email_address'].isspace():
data['email_address'] = None
if not data["email_address"] or data["email_address"].isspace():
data["email_address"] = None
return data return data
@ -122,7 +116,7 @@ def is_valid_username(username: str) -> bool:
""" """
schema = UserSchema(partial=True) schema = UserSchema(partial=True)
try: try:
schema.validate({'username': username})
schema.validate({"username": username})
except ValidationError: except ValidationError:
return False return False

2
tildes/tildes/views/__init__.py

@ -10,4 +10,4 @@ IC_NOOP_404 = Response(status_int=404)
# Because of the above, in order to deliberately cause Intercooler to replace # Because of the above, in order to deliberately cause Intercooler to replace
# an element with whitespace, the response needs to contain at least two spaces # an element with whitespace, the response needs to contain at least two spaces
IC_EMPTY = Response(' ')
IC_EMPTY = Response(" ")

2
tildes/tildes/views/api/v0/group.py

@ -6,7 +6,7 @@ from tildes.api import APIv0
from tildes.resources.group import group_by_path from tildes.resources.group import group_by_path
ONE = APIv0(name='group', path='/groups/{group_path}', factory=group_by_path)
ONE = APIv0(name="group", path="/groups/{group_path}", factory=group_by_path)
@ONE.get() @ONE.get()

4
tildes/tildes/views/api/v0/topic.py

@ -7,9 +7,7 @@ from tildes.resources.topic import topic_by_id36
ONE = APIv0( ONE = APIv0(
name='topic',
path='/groups/{group_path}/topics/{topic_id36}',
factory=topic_by_id36,
name="topic", path="/groups/{group_path}/topics/{topic_id36}", factory=topic_by_id36
) )

2
tildes/tildes/views/api/v0/user.py

@ -6,7 +6,7 @@ from tildes.api import APIv0
from tildes.resources.user import user_by_username from tildes.resources.user import user_by_username
ONE = APIv0(name='user', path='/users/{username}', factory=user_by_username)
ONE = APIv0(name="user", path="/users/{username}", factory=user_by_username)
@ONE.get() @ONE.get()

167
tildes/tildes/views/api/web/comment.py

@ -12,22 +12,14 @@ from zope.sqlalchemy import mark_changed
from tildes.enums import CommentNotificationType, CommentTagOption from tildes.enums import CommentNotificationType, CommentTagOption
from tildes.lib.datetime import utc_now from tildes.lib.datetime import utc_now
from tildes.models.comment import (
Comment,
CommentNotification,
CommentTag,
CommentVote,
)
from tildes.models.comment import Comment, CommentNotification, CommentTag, CommentVote
from tildes.models.topic import TopicVisit from tildes.models.topic import TopicVisit
from tildes.schemas.comment import CommentSchema, CommentTagSchema from tildes.schemas.comment import CommentSchema, CommentTagSchema
from tildes.views import IC_NOOP from tildes.views import IC_NOOP
from tildes.views.decorators import ic_view_config from tildes.views.decorators import ic_view_config
def _increment_topic_comments_seen(
request: Request,
comment: Comment,
) -> None:
def _increment_topic_comments_seen(request: Request, comment: Comment) -> None:
"""Increment the number of comments in a topic the user has viewed. """Increment the number of comments in a topic the user has viewed.
If the user has the "track comment visits" feature enabled, we want to If the user has the "track comment visits" feature enabled, we want to
@ -50,7 +42,7 @@ def _increment_topic_comments_seen(
) )
.on_conflict_do_update( .on_conflict_do_update(
constraint=TopicVisit.__table__.primary_key, constraint=TopicVisit.__table__.primary_key,
set_={'num_comments': TopicVisit.num_comments + 1},
set_={"num_comments": TopicVisit.num_comments + 1},
where=TopicVisit.visit_time < comment.created_time, where=TopicVisit.visit_time < comment.created_time,
) )
) )
@ -60,28 +52,22 @@ def _increment_topic_comments_seen(
@ic_view_config( @ic_view_config(
route_name='topic_comments',
request_method='POST',
renderer='single_comment.jinja2',
permission='comment',
route_name="topic_comments",
request_method="POST",
renderer="single_comment.jinja2",
permission="comment",
) )
@use_kwargs(CommentSchema(only=('markdown',)))
@use_kwargs(CommentSchema(only=("markdown",)))
def post_toplevel_comment(request: Request, markdown: str) -> dict: def post_toplevel_comment(request: Request, markdown: str) -> dict:
"""Post a new top-level comment on a topic with Intercooler.""" """Post a new top-level comment on a topic with Intercooler."""
topic = request.context topic = request.context
new_comment = Comment(
topic=topic,
author=request.user,
markdown=markdown,
)
new_comment = Comment(topic=topic, author=request.user, markdown=markdown)
request.db_session.add(new_comment) request.db_session.add(new_comment)
if topic.user != request.user and not topic.is_deleted: if topic.user != request.user and not topic.is_deleted:
notification = CommentNotification( notification = CommentNotification(
topic.user,
new_comment,
CommentNotificationType.TOPIC_REPLY,
topic.user, new_comment, CommentNotificationType.TOPIC_REPLY
) )
request.db_session.add(notification) request.db_session.add(notification)
@ -95,16 +81,16 @@ def post_toplevel_comment(request: Request, markdown: str) -> dict:
.one() .one()
) )
return {'comment': new_comment, 'topic': topic}
return {"comment": new_comment, "topic": topic}
@ic_view_config( @ic_view_config(
route_name='comment_replies',
request_method='POST',
renderer='single_comment.jinja2',
permission='reply',
route_name="comment_replies",
request_method="POST",
renderer="single_comment.jinja2",
permission="reply",
) )
@use_kwargs(CommentSchema(only=('markdown',)))
@use_kwargs(CommentSchema(only=("markdown",)))
def post_comment_reply(request: Request, markdown: str) -> dict: def post_comment_reply(request: Request, markdown: str) -> dict:
"""Post a reply to a comment with Intercooler.""" """Post a reply to a comment with Intercooler."""
parent_comment = request.context parent_comment = request.context
@ -118,9 +104,7 @@ def post_comment_reply(request: Request, markdown: str) -> dict:
if parent_comment.user != request.user: if parent_comment.user != request.user:
notification = CommentNotification( notification = CommentNotification(
parent_comment.user,
new_comment,
CommentNotificationType.COMMENT_REPLY,
parent_comment.user, new_comment, CommentNotificationType.COMMENT_REPLY
) )
request.db_session.add(notification) request.db_session.add(notification)
@ -134,67 +118,67 @@ def post_comment_reply(request: Request, markdown: str) -> dict:
.one() .one()
) )
return {'comment': new_comment}
return {"comment": new_comment}
@ic_view_config( @ic_view_config(
route_name='comment',
request_method='GET',
renderer='comment_contents.jinja2',
permission='view',
route_name="comment",
request_method="GET",
renderer="comment_contents.jinja2",
permission="view",
) )
def get_comment_contents(request: Request) -> dict: def get_comment_contents(request: Request) -> dict:
"""Get a comment's body with Intercooler.""" """Get a comment's body with Intercooler."""
return {'comment': request.context}
return {"comment": request.context}
@ic_view_config( @ic_view_config(
route_name='comment',
request_method='GET',
request_param='ic-trigger-name=edit',
renderer='comment_edit.jinja2',
permission='edit',
route_name="comment",
request_method="GET",
request_param="ic-trigger-name=edit",
renderer="comment_edit.jinja2",
permission="edit",
) )
def get_comment_edit(request: Request) -> dict: def get_comment_edit(request: Request) -> dict:
"""Get the edit form for a comment with Intercooler.""" """Get the edit form for a comment with Intercooler."""
return {'comment': request.context}
return {"comment": request.context}
@ic_view_config( @ic_view_config(
route_name='comment',
request_method='PATCH',
renderer='comment_contents.jinja2',
permission='edit',
route_name="comment",
request_method="PATCH",
renderer="comment_contents.jinja2",
permission="edit",
) )
@use_kwargs(CommentSchema(only=('markdown',)))
@use_kwargs(CommentSchema(only=("markdown",)))
def patch_comment(request: Request, markdown: str) -> dict: def patch_comment(request: Request, markdown: str) -> dict:
"""Update a comment with Intercooler.""" """Update a comment with Intercooler."""
comment = request.context comment = request.context
comment.markdown = markdown comment.markdown = markdown
return {'comment': comment}
return {"comment": comment}
@ic_view_config( @ic_view_config(
route_name='comment',
request_method='DELETE',
renderer='comment_contents.jinja2',
permission='delete',
route_name="comment",
request_method="DELETE",
renderer="comment_contents.jinja2",
permission="delete",
) )
def delete_comment(request: Request) -> dict: def delete_comment(request: Request) -> dict:
"""Delete a comment with Intercooler.""" """Delete a comment with Intercooler."""
comment = request.context comment = request.context
comment.is_deleted = True comment.is_deleted = True
return {'comment': comment}
return {"comment": comment}
@ic_view_config( @ic_view_config(
route_name='comment_vote',
request_method='PUT',
permission='vote',
renderer='comment_contents.jinja2',
route_name="comment_vote",
request_method="PUT",
permission="vote",
renderer="comment_contents.jinja2",
) )
def put_vote_comment(request: Request) -> dict: def put_vote_comment(request: Request) -> dict:
"""Vote on a comment with Intercooler.""" """Vote on a comment with Intercooler."""
@ -222,22 +206,21 @@ def put_vote_comment(request: Request) -> dict:
.one() .one()
) )
return {'comment': comment}
return {"comment": comment}
@ic_view_config( @ic_view_config(
route_name='comment_vote',
request_method='DELETE',
permission='vote',
renderer='comment_contents.jinja2',
route_name="comment_vote",
request_method="DELETE",
permission="vote",
renderer="comment_contents.jinja2",
) )
def delete_vote_comment(request: Request) -> dict: def delete_vote_comment(request: Request) -> dict:
"""Remove the user's vote from a comment with Intercooler.""" """Remove the user's vote from a comment with Intercooler."""
comment = request.context comment = request.context
request.query(CommentVote).filter( request.query(CommentVote).filter(
CommentVote.comment == comment,
CommentVote.user == request.user,
CommentVote.comment == comment, CommentVote.user == request.user
).delete(synchronize_session=False) ).delete(synchronize_session=False)
# manually commit the transaction so triggers will execute # manually commit the transaction so triggers will execute
@ -251,16 +234,16 @@ def delete_vote_comment(request: Request) -> dict:
.one() .one()
) )
return {'comment': comment}
return {"comment": comment}
@ic_view_config( @ic_view_config(
route_name='comment_tag',
request_method='PUT',
permission='tag',
renderer='comment_contents.jinja2',
route_name="comment_tag",
request_method="PUT",
permission="tag",
renderer="comment_contents.jinja2",
) )
@use_kwargs(CommentTagSchema(only=('name',)), locations=('matchdict',))
@use_kwargs(CommentTagSchema(only=("name",)), locations=("matchdict",))
def put_tag_comment(request: Request, name: CommentTagOption) -> Response: def put_tag_comment(request: Request, name: CommentTagOption) -> Response:
"""Add a tag to a comment.""" """Add a tag to a comment."""
comment = request.context comment = request.context
@ -286,16 +269,16 @@ def put_tag_comment(request: Request, name: CommentTagOption) -> Response:
.one() .one()
) )
return {'comment': comment}
return {"comment": comment}
@ic_view_config( @ic_view_config(
route_name='comment_tag',
request_method='DELETE',
permission='tag',
renderer='comment_contents.jinja2',
route_name="comment_tag",
request_method="DELETE",
permission="tag",
renderer="comment_contents.jinja2",
) )
@use_kwargs(CommentTagSchema(only=('name',)), locations=('matchdict',))
@use_kwargs(CommentTagSchema(only=("name",)), locations=("matchdict",))
def delete_tag_comment(request: Request, name: CommentTagOption) -> Response: def delete_tag_comment(request: Request, name: CommentTagOption) -> Response:
"""Remove a tag (that the user previously added) from a comment.""" """Remove a tag (that the user previously added) from a comment."""
comment = request.context comment = request.context
@ -316,19 +299,14 @@ def delete_tag_comment(request: Request, name: CommentTagOption) -> Response:
.one() .one()
) )
return {'comment': comment}
return {"comment": comment}
@ic_view_config( @ic_view_config(
route_name='comment_mark_read',
request_method='PUT',
permission='mark_read',
route_name="comment_mark_read", request_method="PUT", permission="mark_read"
) )
@use_kwargs({'mark_all_previous': Boolean(missing=False)})
def put_mark_comments_read(
request: Request,
mark_all_previous: bool,
) -> Response:
@use_kwargs({"mark_all_previous": Boolean(missing=False)})
def put_mark_comments_read(request: Request, mark_all_previous: bool) -> Response:
"""Mark comment(s) read, clearing notifications. """Mark comment(s) read, clearing notifications.
The "main" notification (request.context) will always be marked read, and The "main" notification (request.context) will always be marked read, and
@ -339,7 +317,8 @@ def put_mark_comments_read(
if mark_all_previous: if mark_all_previous:
prev_notifications = ( prev_notifications = (
request.query(CommentNotification).filter(
request.query(CommentNotification)
.filter(
CommentNotification.user == request.user, CommentNotification.user == request.user,
CommentNotification.is_unread == True, # noqa CommentNotification.is_unread == True, # noqa
CommentNotification.created_time <= notification.created_time, CommentNotification.created_time <= notification.created_time,
@ -351,16 +330,14 @@ def put_mark_comments_read(
# sort the notifications by created_time of their comment so that the # sort the notifications by created_time of their comment so that the
# INSERT ... ON CONFLICT DO UPDATE statements work as expected # INSERT ... ON CONFLICT DO UPDATE statements work as expected
prev_notifications = sorted( prev_notifications = sorted(
prev_notifications, key=lambda c: c.comment.created_time)
prev_notifications, key=lambda c: c.comment.created_time
)
for comment_notification in prev_notifications: for comment_notification in prev_notifications:
comment_notification.is_unread = False comment_notification.is_unread = False
_increment_topic_comments_seen(
request,
comment_notification.comment
)
_increment_topic_comments_seen(request, comment_notification.comment)
return Response('Your comment notifications have been cleared.')
return Response("Your comment notifications have been cleared.")
notification.is_unread = False notification.is_unread = False
_increment_topic_comments_seen(request, notification.comment) _increment_topic_comments_seen(request, notification.comment)

21
tildes/tildes/views/api/web/exceptions.py

@ -18,7 +18,7 @@ from tildes.views.decorators import ic_view_config
def _422_response_with_errors(errors: Sequence[str]) -> Response: def _422_response_with_errors(errors: Sequence[str]) -> Response:
response = Response('\n'.join(errors))
response = Response("\n".join(errors))
response.status_int = 422 response.status_int = 422
return response return response
@ -44,9 +44,9 @@ def unprocessable_entity(request: Request) -> Response:
error_strings = [] error_strings = []
for field, errors in errors_by_field.items(): for field, errors in errors_by_field.items():
joined_errors = ' '.join(errors)
if field != '_schema':
error_strings.append(f'{field}: {joined_errors}')
joined_errors = " ".join(errors)
if field != "_schema":
error_strings.append(f"{field}: {joined_errors}")
else: else:
error_strings.append(joined_errors) error_strings.append(joined_errors)
@ -65,11 +65,11 @@ def httpnotfound(request: Request) -> Response:
response = request.exception response = request.exception
if request.matched_route.factory == comment_by_id36: if request.matched_route.factory == comment_by_id36:
response.text = 'Comment not found (or it was deleted)'
response.text = "Comment not found (or it was deleted)"
elif request.matched_route.factory == topic_by_id36: elif request.matched_route.factory == topic_by_id36:
response.text = 'Topic not found (or it was deleted)'
response.text = "Topic not found (or it was deleted)"
else: else:
response.text = 'Not found'
response.text = "Not found"
return response return response
@ -79,10 +79,9 @@ def httptoomanyrequests(request: Request) -> Response:
"""Update a 429 error to show wait time info in the response text.""" """Update a 429 error to show wait time info in the response text."""
response = request.exception response = request.exception
retry_seconds = request.exception.headers['Retry-After']
retry_seconds = request.exception.headers["Retry-After"]
response.text = ( response.text = (
'Rate limit exceeded. '
f'Please wait {retry_seconds} seconds before retrying.'
f"Rate limit exceeded. Please wait {retry_seconds} seconds before retrying."
) )
return response return response
@ -99,4 +98,4 @@ def httpfound(request: Request) -> Response:
exception view will convert a 302 into a 200 with that header so it works exception view will convert a 302 into a 200 with that header so it works
as a redirect for both standard requests as well as Intercooler ones. as a redirect for both standard requests as well as Intercooler ones.
""" """
return Response(headers={'X-IC-Redirect': request.exception.location})
return Response(headers={"X-IC-Redirect": request.exception.location})

41
tildes/tildes/views/api/web/group.py

@ -17,10 +17,10 @@ from tildes.views.decorators import ic_view_config
@ic_view_config( @ic_view_config(
route_name='group_subscribe',
request_method='PUT',
permission='subscribe',
renderer='group_subscription_box.jinja2',
route_name="group_subscribe",
request_method="PUT",
permission="subscribe",
renderer="group_subscription_box.jinja2",
) )
def put_subscribe_group(request: Request) -> dict: def put_subscribe_group(request: Request) -> dict:
"""Subscribe to a group with Intercooler.""" """Subscribe to a group with Intercooler."""
@ -48,22 +48,21 @@ def put_subscribe_group(request: Request) -> dict:
.one() .one()
) )
return {'group': group}
return {"group": group}
@ic_view_config( @ic_view_config(
route_name='group_subscribe',
request_method='DELETE',
permission='subscribe',
renderer='group_subscription_box.jinja2',
route_name="group_subscribe",
request_method="DELETE",
permission="subscribe",
renderer="group_subscription_box.jinja2",
) )
def delete_subscribe_group(request: Request) -> dict: def delete_subscribe_group(request: Request) -> dict:
"""Remove the user's subscription from a group with Intercooler.""" """Remove the user's subscription from a group with Intercooler."""
group = request.context group = request.context
request.query(GroupSubscription).filter( request.query(GroupSubscription).filter(
GroupSubscription.group == group,
GroupSubscription.user == request.user,
GroupSubscription.group == group, GroupSubscription.user == request.user
).delete(synchronize_session=False) ).delete(synchronize_session=False)
# manually commit the transaction so triggers will execute # manually commit the transaction so triggers will execute
@ -77,27 +76,21 @@ def delete_subscribe_group(request: Request) -> dict:
.one() .one()
) )
return {'group': group}
return {"group": group}
@ic_view_config(
route_name='group_user_settings',
request_method='PATCH',
@ic_view_config(route_name="group_user_settings", request_method="PATCH")
@use_kwargs(
{"order": Enum(TopicSortOption), "period": ShortTimePeriod(allow_none=True)}
) )
@use_kwargs({
'order': Enum(TopicSortOption),
'period': ShortTimePeriod(allow_none=True),
})
def patch_group_user_settings( def patch_group_user_settings(
request: Request,
order: TopicSortOption,
period: Optional[ShortTimePeriod],
request: Request, order: TopicSortOption, period: Optional[ShortTimePeriod]
) -> dict: ) -> dict:
"""Set the user's default listing options.""" """Set the user's default listing options."""
if period: if period:
default_period = period.as_short_form() default_period = period.as_short_form()
else: else:
default_period = 'all'
default_period = "all"
statement = ( statement = (
insert(UserGroupSettings.__table__) insert(UserGroupSettings.__table__)
@ -109,7 +102,7 @@ def patch_group_user_settings(
) )
.on_conflict_do_update( .on_conflict_do_update(
constraint=UserGroupSettings.__table__.primary_key, constraint=UserGroupSettings.__table__.primary_key,
set_={'default_order': order, 'default_period': default_period},
set_={"default_order": order, "default_period": default_period},
) )
) )
request.db_session.execute(statement) request.db_session.execute(statement)

16
tildes/tildes/views/api/web/message.py

@ -9,19 +9,17 @@ from tildes.views.decorators import ic_view_config
@ic_view_config( @ic_view_config(
route_name='message_conversation_replies',
request_method='POST',
renderer='single_message.jinja2',
permission='reply',
route_name="message_conversation_replies",
request_method="POST",
renderer="single_message.jinja2",
permission="reply",
) )
@use_kwargs(MessageReplySchema(only=('markdown',)))
@use_kwargs(MessageReplySchema(only=("markdown",)))
def post_message_reply(request: Request, markdown: str) -> dict: def post_message_reply(request: Request, markdown: str) -> dict:
"""Post a reply to a message conversation with Intercooler.""" """Post a reply to a message conversation with Intercooler."""
conversation = request.context conversation = request.context
new_reply = MessageReply( new_reply = MessageReply(
conversation=conversation,
sender=request.user,
markdown=markdown,
conversation=conversation, sender=request.user, markdown=markdown
) )
request.db_session.add(new_reply) request.db_session.add(new_reply)
@ -35,4 +33,4 @@ def post_message_reply(request: Request, markdown: str) -> dict:
.one() .one()
) )
return {'message': new_reply}
return {"message": new_reply}

172
tildes/tildes/views/api/web/topic.py

@ -19,66 +19,63 @@ from tildes.views.decorators import ic_view_config
@ic_view_config( @ic_view_config(
route_name='topic',
request_method='GET',
request_param='ic-trigger-name=edit',
renderer='topic_edit.jinja2',
permission='edit',
route_name="topic",
request_method="GET",
request_param="ic-trigger-name=edit",
renderer="topic_edit.jinja2",
permission="edit",
) )
def get_topic_edit(request: Request) -> dict: def get_topic_edit(request: Request) -> dict:
"""Get the edit form for a topic with Intercooler.""" """Get the edit form for a topic with Intercooler."""
return {'topic': request.context}
return {"topic": request.context}
@ic_view_config( @ic_view_config(
route_name='topic',
request_method='GET',
renderer='topic_contents.jinja2',
permission='view',
route_name="topic",
request_method="GET",
renderer="topic_contents.jinja2",
permission="view",
) )
def get_topic_contents(request: Request) -> dict: def get_topic_contents(request: Request) -> dict:
"""Get a topic's body with Intercooler.""" """Get a topic's body with Intercooler."""
return {'topic': request.context}
return {"topic": request.context}
@ic_view_config( @ic_view_config(
route_name='topic',
request_method='PATCH',
renderer='topic_contents.jinja2',
permission='edit',
route_name="topic",
request_method="PATCH",
renderer="topic_contents.jinja2",
permission="edit",
) )
@use_kwargs(TopicSchema(only=('markdown',)))
@use_kwargs(TopicSchema(only=("markdown",)))
def patch_topic(request: Request, markdown: str) -> dict: def patch_topic(request: Request, markdown: str) -> dict:
"""Update a topic with Intercooler.""" """Update a topic with Intercooler."""
topic = request.context topic = request.context
topic.markdown = markdown topic.markdown = markdown
return {'topic': topic}
return {"topic": topic}
@ic_view_config(
route_name='topic',
request_method='DELETE',
permission='delete',
)
@ic_view_config(route_name="topic", request_method="DELETE", permission="delete")
def delete_topic(request: Request) -> Response: def delete_topic(request: Request) -> Response:
"""Delete a topic with Intercooler and redirect to its group.""" """Delete a topic with Intercooler and redirect to its group."""
topic = request.context topic = request.context
topic.is_deleted = True topic.is_deleted = True
response = Response() response = Response()
response.headers['X-IC-Redirect'] = request.route_url(
'group', group_path=topic.group.path)
response.headers["X-IC-Redirect"] = request.route_url(
"group", group_path=topic.group.path
)
return response return response
@ic_view_config( @ic_view_config(
route_name='topic_vote',
request_method='PUT',
renderer='topic_voting.jinja2',
permission='vote',
route_name="topic_vote",
request_method="PUT",
renderer="topic_voting.jinja2",
permission="vote",
) )
def put_topic_vote(request: Request) -> Response: def put_topic_vote(request: Request) -> Response:
"""Vote on a topic with Intercooler.""" """Vote on a topic with Intercooler."""
@ -106,22 +103,21 @@ def put_topic_vote(request: Request) -> Response:
.one() .one()
) )
return {'topic': topic}
return {"topic": topic}
@ic_view_config( @ic_view_config(
route_name='topic_vote',
request_method='DELETE',
renderer='topic_voting.jinja2',
permission='vote',
route_name="topic_vote",
request_method="DELETE",
renderer="topic_voting.jinja2",
permission="vote",
) )
def delete_topic_vote(request: Request) -> Response: def delete_topic_vote(request: Request) -> Response:
"""Remove the user's vote from a topic with Intercooler.""" """Remove the user's vote from a topic with Intercooler."""
topic = request.context topic = request.context
request.query(TopicVote).filter( request.query(TopicVote).filter(
TopicVote.topic == topic,
TopicVote.user == request.user,
TopicVote.topic == topic, TopicVote.user == request.user
).delete(synchronize_session=False) ).delete(synchronize_session=False)
# manually commit the transaction so triggers will execute # manually commit the transaction so triggers will execute
@ -135,34 +131,34 @@ def delete_topic_vote(request: Request) -> Response:
.one() .one()
) )
return {'topic': topic}
return {"topic": topic}
@ic_view_config( @ic_view_config(
route_name='topic_tags',
request_method='GET',
renderer='topic_tags_edit.jinja2',
permission='tag',
route_name="topic_tags",
request_method="GET",
renderer="topic_tags_edit.jinja2",
permission="tag",
) )
def get_topic_tags(request: Request) -> dict: def get_topic_tags(request: Request) -> dict:
"""Get the tagging form for a topic with Intercooler.""" """Get the tagging form for a topic with Intercooler."""
return {'topic': request.context}
return {"topic": request.context}
@ic_view_config( @ic_view_config(
route_name='topic_tags',
request_method='PUT',
renderer='topic_tags.jinja2',
permission='tag',
route_name="topic_tags",
request_method="PUT",
renderer="topic_tags.jinja2",
permission="tag",
) )
@use_kwargs({'tags': String()})
@use_kwargs({"tags": String()})
def put_tag_topic(request: Request, tags: str) -> dict: def put_tag_topic(request: Request, tags: str) -> dict:
"""Apply tags to a topic with Intercooler.""" """Apply tags to a topic with Intercooler."""
topic = request.context topic = request.context
if tags: if tags:
# split the tag string on commas # split the tag string on commas
new_tags = tags.split(',')
new_tags = tags.split(",")
else: else:
new_tags = [] new_tags = []
@ -171,7 +167,7 @@ def put_tag_topic(request: Request, tags: str) -> dict:
try: try:
topic.tags = new_tags topic.tags = new_tags
except ValidationError: except ValidationError:
raise ValidationError({'tags': ['Invalid tags']})
raise ValidationError({"tags": ["Invalid tags"]})
# if tags weren't changed, don't add a log entry or update page # if tags weren't changed, don't add a log entry or update page
if set(topic.tags) == set(old_tags): if set(topic.tags) == set(old_tags):
@ -182,42 +178,38 @@ def put_tag_topic(request: Request, tags: str) -> dict:
LogEventType.TOPIC_TAG, LogEventType.TOPIC_TAG,
request, request,
topic, topic,
info={'old': old_tags, 'new': topic.tags},
),
info={"old": old_tags, "new": topic.tags},
)
) )
return {'topic': topic}
return {"topic": topic}
@ic_view_config( @ic_view_config(
route_name='topic_group',
request_method='GET',
renderer='topic_group_edit.jinja2',
permission='move',
route_name="topic_group",
request_method="GET",
renderer="topic_group_edit.jinja2",
permission="move",
) )
def get_topic_group(request: Request) -> dict: def get_topic_group(request: Request) -> dict:
"""Get the form for moving a topic with Intercooler.""" """Get the form for moving a topic with Intercooler."""
return {'topic': request.context}
return {"topic": request.context}
@ic_view_config( @ic_view_config(
route_name='topic',
request_param='ic-trigger-name=topic-move',
request_method='PATCH',
permission='move',
route_name="topic",
request_param="ic-trigger-name=topic-move",
request_method="PATCH",
permission="move",
) )
@use_kwargs(GroupSchema(only=('path',)))
@use_kwargs(GroupSchema(only=("path",)))
def patch_move_topic(request: Request, path: str) -> dict: def patch_move_topic(request: Request, path: str) -> dict:
"""Move a topic to a different group with Intercooler.""" """Move a topic to a different group with Intercooler."""
topic = request.context topic = request.context
new_group = (
request.query(Group)
.filter(Group.path == path)
.one_or_none()
)
new_group = request.query(Group).filter(Group.path == path).one_or_none()
if not new_group: if not new_group:
raise HTTPNotFound('Group not found')
raise HTTPNotFound("Group not found")
old_group = topic.group old_group = topic.group
@ -231,18 +223,14 @@ def patch_move_topic(request: Request, path: str) -> dict:
LogEventType.TOPIC_MOVE, LogEventType.TOPIC_MOVE,
request, request,
topic, topic,
info={'old': str(old_group.path), 'new': str(topic.group.path)}
),
info={"old": str(old_group.path), "new": str(topic.group.path)},
)
) )
return Response('Moved')
return Response("Moved")
@ic_view_config(
route_name='topic_lock',
request_method='PUT',
permission='lock',
)
@ic_view_config(route_name="topic_lock", request_method="PUT", permission="lock")
def put_topic_lock(request: Request) -> Response: def put_topic_lock(request: Request) -> Response:
"""Lock a topic with Intercooler.""" """Lock a topic with Intercooler."""
topic = request.context topic = request.context
@ -250,14 +238,10 @@ def put_topic_lock(request: Request) -> Response:
topic.is_locked = True topic.is_locked = True
request.db_session.add(LogTopic(LogEventType.TOPIC_LOCK, request, topic)) request.db_session.add(LogTopic(LogEventType.TOPIC_LOCK, request, topic))
return Response('Locked')
return Response("Locked")
@ic_view_config(
route_name='topic_lock',
request_method='DELETE',
permission='lock',
)
@ic_view_config(route_name="topic_lock", request_method="DELETE", permission="lock")
def delete_topic_lock(request: Request) -> Response: def delete_topic_lock(request: Request) -> Response:
"""Unlock a topic with Intercooler.""" """Unlock a topic with Intercooler."""
topic = request.context topic = request.context
@ -265,27 +249,27 @@ def delete_topic_lock(request: Request) -> Response:
topic.is_locked = False topic.is_locked = False
request.db_session.add(LogTopic(LogEventType.TOPIC_UNLOCK, request, topic)) request.db_session.add(LogTopic(LogEventType.TOPIC_UNLOCK, request, topic))
return Response('Unlocked')
return Response("Unlocked")
@ic_view_config( @ic_view_config(
route_name='topic_title',
request_method='GET',
renderer='topic_title_edit.jinja2',
permission='edit_title',
route_name="topic_title",
request_method="GET",
renderer="topic_title_edit.jinja2",
permission="edit_title",
) )
def get_topic_title(request: Request) -> dict: def get_topic_title(request: Request) -> dict:
"""Get the form for editing a topic's title with Intercooler.""" """Get the form for editing a topic's title with Intercooler."""
return {'topic': request.context}
return {"topic": request.context}
@ic_view_config( @ic_view_config(
route_name='topic',
request_param='ic-trigger-name=topic-title-edit',
request_method='PATCH',
permission='edit_title',
route_name="topic",
request_param="ic-trigger-name=topic-title-edit",
request_method="PATCH",
permission="edit_title",
) )
@use_kwargs(TopicSchema(only=('title',)))
@use_kwargs(TopicSchema(only=("title",)))
def patch_topic_title(request: Request, title: str) -> dict: def patch_topic_title(request: Request, title: str) -> dict:
"""Edit a topic's title with Intercooler.""" """Edit a topic's title with Intercooler."""
topic = request.context topic = request.context
@ -298,8 +282,8 @@ def patch_topic_title(request: Request, title: str) -> dict:
LogEventType.TOPIC_TITLE_EDIT, LogEventType.TOPIC_TITLE_EDIT,
request, request,
topic, topic,
info={'old': topic.title, 'new': title}
),
info={"old": topic.title, "new": title},
)
) )
topic.title = title topic.title = title

139
tildes/tildes/views/api/web/user.py

@ -20,52 +20,48 @@ from tildes.views import IC_NOOP
from tildes.views.decorators import ic_view_config from tildes.views.decorators import ic_view_config
PASSWORD_FIELD = UserSchema(only=('password',)).fields['password']
PASSWORD_FIELD = UserSchema(only=("password",)).fields["password"]
@ic_view_config( @ic_view_config(
route_name='user',
request_method='PATCH',
request_param='ic-trigger-name=password-change',
permission='change_password',
route_name="user",
request_method="PATCH",
request_param="ic-trigger-name=password-change",
permission="change_password",
)
@use_kwargs(
{
"old_password": PASSWORD_FIELD,
"new_password": PASSWORD_FIELD,
"new_password_confirm": PASSWORD_FIELD,
}
) )
@use_kwargs({
'old_password': PASSWORD_FIELD,
'new_password': PASSWORD_FIELD,
'new_password_confirm': PASSWORD_FIELD,
})
def patch_change_password( def patch_change_password(
request: Request,
old_password: str,
new_password: str,
new_password_confirm: str,
request: Request, old_password: str, new_password: str, new_password_confirm: str
) -> Response: ) -> Response:
"""Change the logged-in user's password.""" """Change the logged-in user's password."""
user = request.context user = request.context
# enable checking the new password against the breached-passwords list # enable checking the new password against the breached-passwords list
user.schema.context['check_breached_passwords'] = True
user.schema.context["check_breached_passwords"] = True
if new_password != new_password_confirm: if new_password != new_password_confirm:
raise HTTPUnprocessableEntity(
'New password and confirmation do not match.')
raise HTTPUnprocessableEntity("New password and confirmation do not match.")
user.change_password(old_password, new_password) user.change_password(old_password, new_password)
return Response('Your password has been updated')
return Response("Your password has been updated")
@ic_view_config( @ic_view_config(
route_name='user',
request_method='PATCH',
request_param='ic-trigger-name=account-recovery-email',
permission='change_email_address',
route_name="user",
request_method="PATCH",
request_param="ic-trigger-name=account-recovery-email",
permission="change_email_address",
) )
@use_kwargs(UserSchema(only=('email_address', 'email_address_note')))
@use_kwargs(UserSchema(only=("email_address", "email_address_note")))
def patch_change_email_address( def patch_change_email_address(
request: Request,
email_address: str,
email_address_note: str
request: Request, email_address: str, email_address_note: str
) -> Response: ) -> Response:
"""Change the user's email address (and descriptive note).""" """Change the user's email address (and descriptive note)."""
user = request.context user = request.context
@ -77,46 +73,46 @@ def patch_change_email_address(
log_info = None log_info = None
if user.email_address_hash: if user.email_address_hash:
log_info = { log_info = {
'old_hash': user.email_address_hash,
'old_note': user.email_address_note,
"old_hash": user.email_address_hash,
"old_note": user.email_address_note,
} }
request.db_session.add(Log(LogEventType.USER_EMAIL_SET, request, log_info)) request.db_session.add(Log(LogEventType.USER_EMAIL_SET, request, log_info))
user.email_address = email_address user.email_address = email_address
user.email_address_note = email_address_note user.email_address_note = email_address_note
return Response('Your email address has been updated')
return Response("Your email address has been updated")
@ic_view_config( @ic_view_config(
route_name='user',
request_method='PATCH',
request_param='ic-trigger-name=auto-mark-notifications-read',
permission='change_auto_mark_notifications_read_setting',
route_name="user",
request_method="PATCH",
request_param="ic-trigger-name=auto-mark-notifications-read",
permission="change_auto_mark_notifications_read_setting",
) )
def patch_change_auto_mark_notifications(request: Request) -> Response: def patch_change_auto_mark_notifications(request: Request) -> Response:
"""Change the user's "automatically mark notifications read" setting.""" """Change the user's "automatically mark notifications read" setting."""
user = request.context user = request.context
auto_mark = bool(request.params.get('auto_mark_notifications_read'))
auto_mark = bool(request.params.get("auto_mark_notifications_read"))
user.auto_mark_notifications_read = auto_mark user.auto_mark_notifications_read = auto_mark
return IC_NOOP return IC_NOOP
@ic_view_config( @ic_view_config(
route_name='user',
request_method='PATCH',
request_param='ic-trigger-name=open-links-new-tab',
permission='change_open_links_new_tab_setting',
route_name="user",
request_method="PATCH",
request_param="ic-trigger-name=open-links-new-tab",
permission="change_open_links_new_tab_setting",
) )
def patch_change_open_links_new_tab(request: Request) -> Response: def patch_change_open_links_new_tab(request: Request) -> Response:
"""Change the user's "open links in new tabs" setting.""" """Change the user's "open links in new tabs" setting."""
user = request.context user = request.context
external = bool(request.params.get('open_new_tab_external'))
internal = bool(request.params.get('open_new_tab_internal'))
text = bool(request.params.get('open_new_tab_text'))
external = bool(request.params.get("open_new_tab_external"))
internal = bool(request.params.get("open_new_tab_internal"))
text = bool(request.params.get("open_new_tab_text"))
user.open_new_tab_external = external user.open_new_tab_external = external
user.open_new_tab_internal = internal user.open_new_tab_internal = internal
user.open_new_tab_text = text user.open_new_tab_text = text
@ -125,16 +121,16 @@ def patch_change_open_links_new_tab(request: Request) -> Response:
@ic_view_config( @ic_view_config(
route_name='user',
request_method='PATCH',
request_param='ic-trigger-name=comment-visits',
permission='change_comment_visits_setting',
route_name="user",
request_method="PATCH",
request_param="ic-trigger-name=comment-visits",
permission="change_comment_visits_setting",
) )
def patch_change_track_comment_visits(request: Request) -> Response: def patch_change_track_comment_visits(request: Request) -> Response:
"""Change the user's "track comment visits" setting.""" """Change the user's "track comment visits" setting."""
user = request.context user = request.context
track_comment_visits = bool(request.params.get('track_comment_visits'))
track_comment_visits = bool(request.params.get("track_comment_visits"))
user.track_comment_visits = track_comment_visits user.track_comment_visits = track_comment_visits
if track_comment_visits: if track_comment_visits:
@ -144,20 +140,20 @@ def patch_change_track_comment_visits(request: Request) -> Response:
@ic_view_config( @ic_view_config(
route_name='user_invite_code',
request_method='GET',
permission='view_invite_code',
renderer='invite_code.jinja2',
route_name="user_invite_code",
request_method="GET",
permission="view_invite_code",
renderer="invite_code.jinja2",
) )
def get_invite_code(request: Request) -> dict: def get_invite_code(request: Request) -> dict:
"""Generate a new invite code owned by the user.""" """Generate a new invite code owned by the user."""
user = request.context user = request.context
if request.user.invite_codes_remaining < 1: if request.user.invite_codes_remaining < 1:
raise HTTPForbidden('No invite codes remaining')
raise HTTPForbidden("No invite codes remaining")
# obtain a lock to prevent concurrent requests generating multiple codes # obtain a lock to prevent concurrent requests generating multiple codes
request.obtain_lock('generate_invite_code', user.user_id)
request.obtain_lock("generate_invite_code", user.user_id)
# it's possible to randomly generate an existing code, so we'll retry # it's possible to randomly generate an existing code, so we'll retry
# until we create a new one (will practically always be the first try) # until we create a new one (will practically always be the first try)
@ -179,22 +175,19 @@ def get_invite_code(request: Request) -> dict:
num_remaining = request.user.invite_codes_remaining - 1 num_remaining = request.user.invite_codes_remaining - 1
request.user.invite_codes_remaining = User.invite_codes_remaining - 1 request.user.invite_codes_remaining = User.invite_codes_remaining - 1
return {'code': code, 'num_remaining': num_remaining}
return {"code": code, "num_remaining": num_remaining}
@ic_view_config( @ic_view_config(
route_name='user_default_listing_options',
request_method='PUT',
permission='edit_default_listing_options',
route_name="user_default_listing_options",
request_method="PUT",
permission="edit_default_listing_options",
)
@use_kwargs(
{"order": Enum(TopicSortOption), "period": ShortTimePeriod(allow_none=True)}
) )
@use_kwargs({
'order': Enum(TopicSortOption),
'period': ShortTimePeriod(allow_none=True),
})
def put_default_listing_options( def put_default_listing_options(
request: Request,
order: TopicSortOption,
period: Optional[ShortTimePeriod],
request: Request, order: TopicSortOption, period: Optional[ShortTimePeriod]
) -> dict: ) -> dict:
"""Set the user's default listing options.""" """Set the user's default listing options."""
user = request.context user = request.context
@ -203,31 +196,31 @@ def put_default_listing_options(
if period: if period:
user.home_default_period = period.as_short_form() user.home_default_period = period.as_short_form()
else: else:
user.home_default_period = 'all'
user.home_default_period = "all"
return IC_NOOP return IC_NOOP
@ic_view_config( @ic_view_config(
route_name='user_filtered_topic_tags',
request_method='PUT',
permission='edit_filtered_topic_tags',
route_name="user_filtered_topic_tags",
request_method="PUT",
permission="edit_filtered_topic_tags",
) )
@use_kwargs({'tags': String()})
@use_kwargs({"tags": String()})
def put_filtered_topic_tags(request: Request, tags: str) -> dict: def put_filtered_topic_tags(request: Request, tags: str) -> dict:
"""Update a user's filtered topic tags list.""" """Update a user's filtered topic tags list."""
if not tags: if not tags:
request.user.filtered_topic_tags = [] request.user.filtered_topic_tags = []
return IC_NOOP return IC_NOOP
split_tags = tags.split(',')
split_tags = tags.split(",")
try: try:
schema = TopicSchema(only=('tags',))
result = schema.load({'tags': split_tags})
schema = TopicSchema(only=("tags",))
result = schema.load({"tags": split_tags})
except ValidationError: except ValidationError:
raise ValidationError({'tags': ['Invalid tags']})
raise ValidationError({"tags": ["Invalid tags"]})
request.user.filtered_topic_tags = result.data['tags']
request.user.filtered_topic_tags = result.data["tags"]
return IC_NOOP return IC_NOOP

16
tildes/tildes/views/decorators.py

@ -9,15 +9,15 @@ from pyramid.view import view_config
def ic_view_config(**kwargs: Any) -> Callable: def ic_view_config(**kwargs: Any) -> Callable:
"""Wrap the @view_config decorator for Intercooler views.""" """Wrap the @view_config decorator for Intercooler views."""
if 'route_name' in kwargs:
kwargs['route_name'] = 'ic_' + kwargs['route_name']
if "route_name" in kwargs:
kwargs["route_name"] = "ic_" + kwargs["route_name"]
if 'renderer' in kwargs:
kwargs['renderer'] = 'intercooler/' + kwargs['renderer']
if "renderer" in kwargs:
kwargs["renderer"] = "intercooler/" + kwargs["renderer"]
if 'header' in kwargs:
if "header" in kwargs:
raise ValueError("Can't add a header check to Intercooler view.") raise ValueError("Can't add a header check to Intercooler view.")
kwargs['header'] = 'X-IC-Request:true'
kwargs["header"] = "X-IC-Request:true"
return view_config(**kwargs) return view_config(**kwargs)
@ -32,6 +32,7 @@ def rate_limit_view(action_name: str) -> Callable:
response with appropriate headers will be raised instead of calling the response with appropriate headers will be raised instead of calling the
decorated view. decorated view.
""" """
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
def wrapper(*args: Any, **kwargs: Any) -> Any: def wrapper(*args: Any, **kwargs: Any) -> Any:
request = args[0] request = args[0]
@ -55,9 +56,10 @@ def not_logged_in(func: Callable) -> Callable:
such as the login page, registration page, etc. which only logged-out users such as the login page, registration page, etc. which only logged-out users
should be accessing. should be accessing.
""" """
def wrapper(request: Request, **kwargs: Any) -> Any: def wrapper(request: Request, **kwargs: Any) -> Any:
if request.user: if request.user:
raise HTTPFound(location=request.route_url('home'))
raise HTTPFound(location=request.route_url("home"))
return func(request, **kwargs) return func(request, **kwargs)

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save