
Apply Black code formatter

This commit contains only changes that were made automatically by Black
(except for some minor fixes to string un-wrapping and two
format-disabling blocks in the user and group schemas). Some manual
cleanup/adjustments will probably need to be made in a follow-up commit,
but this one contains the result of running Black on the codebase
without significant further manual tweaking.
Ref: merge-requests/26/head
Author: Deimos, 6 years ago
Commit: 09cf3c47f4
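
The two manual exceptions called out in the commit message are easier to see with a small illustration. The sketch below is hypothetical: the constant names and the dict are invented, and the real format-disabling blocks live in the user and group schemas (tildes/tildes/schemas/user.py and tildes/tildes/schemas/group.py), whose contents are not shown here. It shows a wrapped implicit string concatenation being joined by hand (mirroring the ALTER TYPE statement in the f1ecbf24c212 migration further down) and a "# fmt: off" / "# fmt: on" pair that tells Black to leave a hand-aligned block alone.

```python
# Hypothetical sketch only; names and values are invented for illustration.

# (a) String un-wrapping: Black does not join implicitly concatenated string
# literals, so wrapped strings that now fit on one line were joined by hand.
# Before:
ALTER_TYPE_WRAPPED = (
    "ALTER TYPE commentnotificationtype "
    "ADD VALUE IF NOT EXISTS 'USER_MENTION'"
)
# After (a single literal, as in the f1ecbf24c212 migration below):
ALTER_TYPE_UNWRAPPED = (
    "ALTER TYPE commentnotificationtype ADD VALUE IF NOT EXISTS 'USER_MENTION'"
)

# (b) Format-disabling block: Black skips everything between the fmt comments,
# preserving hand-aligned layout. This dict is a placeholder, not the actual
# contents of the user or group schema blocks.
# fmt: off
PLACEHOLDER_TABLE = {
    "short":        1,
    "longer_name":  2,
    "longest_name": 3,
}
# fmt: on
```

Everything else in the diff that follows is Black's own output.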

Changed files (lines changed in each):

1. tildes/alembic/env.py (22)
2. tildes/alembic/versions/2512581c91b3_add_setting_to_open_links_in_new_tab.py (26)
3. tildes/alembic/versions/de83b8750123_add_setting_to_open_text_links_in_new_.py (13)
4. tildes/alembic/versions/f1ecbf24c212_added_user_tag_type_comment_notification.py (35)
5. tildes/alembic/versions/fab922a8bb04_update_comment_triggers_for_removals.py (124)
6. tildes/consumers/comment_user_mentions_generator.py (21)
7. tildes/consumers/topic_metadata_generator.py (29)
8. tildes/scripts/breached_passwords.py (57)
9. tildes/scripts/clean_private_data.py (51)
10. tildes/scripts/initialize_db.py (50)
11. tildes/setup.py (6)
12. tildes/tests/conftest.py (71)
13. tildes/tests/fixtures.py (8)
14. tildes/tests/test_comment.py (64)
15. tildes/tests/test_comment_user_mentions.py (72)
16. tildes/tests/test_datetime.py (14)
17. tildes/tests/test_group.py (41)
18. tildes/tests/test_id.py (10)
19. tildes/tests/test_markdown.py (167)
20. tildes/tests/test_markdown_field.py (22)
21. tildes/tests/test_messages.py (36)
22. tildes/tests/test_metrics.py (2)
23. tildes/tests/test_ratelimit.py (36)
24. tildes/tests/test_simplestring_field.py (22)
25. tildes/tests/test_string.py (58)
26. tildes/tests/test_title.py (36)
27. tildes/tests/test_topic.py (44)
28. tildes/tests/test_topic_permissions.py (37)
29. tildes/tests/test_topic_tags.py (22)
30. tildes/tests/test_triggers_comments.py (6)
31. tildes/tests/test_url.py (22)
32. tildes/tests/test_user.py (51)
33. tildes/tests/test_username.py (16)
34. tildes/tests/test_webassets.py (10)
35. tildes/tests/webtests/test_user_page.py (6)
36. tildes/tildes/__init__.py (93)
37. tildes/tildes/api.py (6)
38. tildes/tildes/auth.py (46)
39. tildes/tildes/database.py (27)
40. tildes/tildes/enums.py (20)
41. tildes/tildes/jinja.py (26)
42. tildes/tildes/json.py (6)
43. tildes/tildes/lib/amqp.py (16)
44. tildes/tildes/lib/cmark.py (12)
45. tildes/tildes/lib/database.py (26)
46. tildes/tildes/lib/datetime.py (30)
47. tildes/tildes/lib/hash.py (3)
48. tildes/tildes/lib/id.py (10)
49. tildes/tildes/lib/markdown.py (181)
50. tildes/tildes/lib/message.py (2)
51. tildes/tildes/lib/password.py (11)
52. tildes/tildes/lib/ratelimit.py (95)
53. tildes/tildes/lib/string.py (45)
54. tildes/tildes/lib/url.py (4)
55. tildes/tildes/metrics.py (60)
56. tildes/tildes/models/comment/comment.py (97)
57. tildes/tildes/models/comment/comment_notification.py (71)
58. tildes/tildes/models/comment/comment_notification_query.py (10)
59. tildes/tildes/models/comment/comment_query.py (6)
60. tildes/tildes/models/comment/comment_tag.py (29)
61. tildes/tildes/models/comment/comment_tree.py (23)
62. tildes/tildes/models/comment/comment_vote.py (20)
63. tildes/tildes/models/database_model.py (34)
64. tildes/tildes/models/group/group.py (44)
65. tildes/tildes/models/group/group_query.py (6)
66. tildes/tildes/models/group/group_subscription.py (20)
67. tildes/tildes/models/log/log.py (135)
68. tildes/tildes/models/message/message.py (85)
69. tildes/tildes/models/model_query.py (37)
70. tildes/tildes/models/pagination.py (14)
71. tildes/tildes/models/topic/topic.py (174)
72. tildes/tildes/models/topic/topic_query.py (45)
73. tildes/tildes/models/topic/topic_visit.py (26)
74. tildes/tildes/models/topic/topic_vote.py (20)
75. tildes/tildes/models/user/user.py (70)
76. tildes/tildes/models/user/user_group_settings.py (16)
77. tildes/tildes/models/user/user_invite_code.py (37)
78. tildes/tildes/resources/__init__.py (6)
79. tildes/tildes/resources/comment.py (16)
80. tildes/tildes/resources/group.py (11)
81. tildes/tildes/resources/message.py (11)
82. tildes/tildes/resources/topic.py (9)
83. tildes/tildes/resources/user.py (5)
84. tildes/tildes/routes.py (160)
85. tildes/tildes/schemas/fields.py (45)
86. tildes/tildes/schemas/group.py (27)
87. tildes/tildes/schemas/topic.py (58)
88. tildes/tildes/schemas/topic_listing.py (13)
89. tildes/tildes/schemas/user.py (46)
90. tildes/tildes/views/__init__.py (2)
91. tildes/tildes/views/api/v0/group.py (2)
92. tildes/tildes/views/api/v0/topic.py (4)
93. tildes/tildes/views/api/v0/user.py (2)
94. tildes/tildes/views/api/web/comment.py (167)
95. tildes/tildes/views/api/web/exceptions.py (21)
96. tildes/tildes/views/api/web/group.py (41)
97. tildes/tildes/views/api/web/message.py (16)
98. tildes/tildes/views/api/web/topic.py (172)
99. tildes/tildes/views/api/web/user.py (139)
100. tildes/tildes/views/decorators.py (16)

tildes/alembic/env.py (22 lines changed)

@ -12,12 +12,7 @@ config = context.config
fileConfig(config.config_file_name)
# import all DatabaseModel subclasses here for autogenerate support
from tildes.models.comment import (
Comment,
CommentNotification,
CommentTag,
CommentVote,
)
from tildes.models.comment import Comment, CommentNotification, CommentTag, CommentVote
from tildes.models.group import Group, GroupSubscription
from tildes.models.log import Log
from tildes.models.message import MessageConversation, MessageReply
@ -25,6 +20,7 @@ from tildes.models.topic import Topic, TopicVisit, TopicVote
from tildes.models.user import User, UserGroupSettings, UserInviteCode
from tildes.models import DatabaseModel
target_metadata = DatabaseModel.metadata
# other values from the config, defined by the needs of env.py,
@ -46,8 +42,7 @@ def run_migrations_offline():
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url, target_metadata=target_metadata, literal_binds=True)
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
with context.begin_transaction():
context.run_migrations()
@ -62,18 +57,17 @@ def run_migrations_online():
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix='sqlalchemy.',
poolclass=pool.NullPool)
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata
)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:

tildes/alembic/versions/2512581c91b3_add_setting_to_open_links_in_new_tab.py (26 lines changed)

@ -10,17 +10,33 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2512581c91b3'
revision = "2512581c91b3"
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
op.add_column('users', sa.Column('open_new_tab_external', sa.Boolean(), server_default='false', nullable=False))
op.add_column('users', sa.Column('open_new_tab_internal', sa.Boolean(), server_default='false', nullable=False))
op.add_column(
"users",
sa.Column(
"open_new_tab_external",
sa.Boolean(),
server_default="false",
nullable=False,
),
)
op.add_column(
"users",
sa.Column(
"open_new_tab_internal",
sa.Boolean(),
server_default="false",
nullable=False,
),
)
def downgrade():
op.drop_column('users', 'open_new_tab_internal')
op.drop_column('users', 'open_new_tab_external')
op.drop_column("users", "open_new_tab_internal")
op.drop_column("users", "open_new_tab_external")

tildes/alembic/versions/de83b8750123_add_setting_to_open_text_links_in_new_.py (13 lines changed)

@ -10,15 +10,20 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'de83b8750123'
down_revision = '2512581c91b3'
revision = "de83b8750123"
down_revision = "2512581c91b3"
branch_labels = None
depends_on = None
def upgrade():
op.add_column('users', sa.Column('open_new_tab_text', sa.Boolean(), server_default='false', nullable=False))
op.add_column(
"users",
sa.Column(
"open_new_tab_text", sa.Boolean(), server_default="false", nullable=False
),
)
def downgrade():
op.drop_column('users', 'open_new_tab_text')
op.drop_column("users", "open_new_tab_text")

tildes/alembic/versions/f1ecbf24c212_added_user_tag_type_comment_notification.py (35 lines changed)

@ -9,8 +9,8 @@ from alembic import op
# revision identifiers, used by Alembic.
revision = 'f1ecbf24c212'
down_revision = 'de83b8750123'
revision = "f1ecbf24c212"
down_revision = "de83b8750123"
branch_labels = None
depends_on = None
@ -20,18 +20,18 @@ def upgrade():
connection = None
if not op.get_context().as_sql:
connection = op.get_bind()
connection.execution_options(isolation_level='AUTOCOMMIT')
connection.execution_options(isolation_level="AUTOCOMMIT")
op.execute(
"ALTER TYPE commentnotificationtype "
"ADD VALUE IF NOT EXISTS 'USER_MENTION'"
"ALTER TYPE commentnotificationtype ADD VALUE IF NOT EXISTS 'USER_MENTION'"
)
# re-activate the transaction for any future migrations
if connection is not None:
connection.execution_options(isolation_level='READ_COMMITTED')
connection.execution_options(isolation_level="READ_COMMITTED")
op.execute('''
op.execute(
"""
CREATE OR REPLACE FUNCTION send_rabbitmq_message_for_comment() RETURNS TRIGGER AS $$
DECLARE
affected_row RECORD;
@ -50,23 +50,28 @@ def upgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
''')
op.execute('''
"""
)
op.execute(
"""
CREATE TRIGGER send_rabbitmq_message_for_comment_insert
AFTER INSERT ON comments
FOR EACH ROW
EXECUTE PROCEDURE send_rabbitmq_message_for_comment('created');
''')
op.execute('''
"""
)
op.execute(
"""
CREATE TRIGGER send_rabbitmq_message_for_comment_edit
AFTER UPDATE ON comments
FOR EACH ROW
WHEN (OLD.markdown IS DISTINCT FROM NEW.markdown)
EXECUTE PROCEDURE send_rabbitmq_message_for_comment('edited');
''')
"""
)
def downgrade():
op.execute('DROP TRIGGER send_rabbitmq_message_for_comment_insert ON comments')
op.execute('DROP TRIGGER send_rabbitmq_message_for_comment_edit ON comments')
op.execute('DROP FUNCTION send_rabbitmq_message_for_comment')
op.execute("DROP TRIGGER send_rabbitmq_message_for_comment_insert ON comments")
op.execute("DROP TRIGGER send_rabbitmq_message_for_comment_edit ON comments")
op.execute("DROP FUNCTION send_rabbitmq_message_for_comment")

tildes/alembic/versions/fab922a8bb04_update_comment_triggers_for_removals.py (124 lines changed)

@ -10,8 +10,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fab922a8bb04'
down_revision = 'f1ecbf24c212'
revision = "fab922a8bb04"
down_revision = "f1ecbf24c212"
branch_labels = None
depends_on = None
@ -19,17 +19,20 @@ depends_on = None
def upgrade():
# comment_notifications
op.execute("DROP TRIGGER delete_comment_notifications_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_notifications_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN ((OLD.is_deleted = false AND NEW.is_deleted = true)
OR (OLD.is_removed = false AND NEW.is_removed = true))
EXECUTE PROCEDURE delete_comment_notifications();
""")
"""
)
# comments
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION set_comment_deleted_time() RETURNS TRIGGER AS $$
BEGIN
IF (NEW.is_deleted = TRUE) THEN
@ -41,17 +44,21 @@ def upgrade():
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER delete_comment_set_deleted_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_set_deleted_time_update
BEFORE UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
EXECUTE PROCEDURE set_comment_deleted_time();
""")
"""
)
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION set_comment_removed_time() RETURNS TRIGGER AS $$
BEGIN
IF (NEW.is_removed = TRUE) THEN
@ -63,19 +70,23 @@ def upgrade():
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""")
op.execute("""
"""
)
op.execute(
"""
CREATE TRIGGER remove_comment_set_removed_time_update
BEFORE UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_removed IS DISTINCT FROM NEW.is_removed)
EXECUTE PROCEDURE set_comment_removed_time();
""")
"""
)
# topic_visits
op.execute("DROP TRIGGER update_topic_visits_num_comments_update ON comments")
op.execute("DROP FUNCTION decrement_all_topic_visit_num_comments()")
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_all_topic_visit_num_comments() RETURNS TRIGGER AS $$
DECLARE
old_visible BOOLEAN := NOT (OLD.is_deleted OR OLD.is_removed);
@ -96,18 +107,22 @@ def upgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
op.execute("""
"""
)
op.execute(
"""
CREATE TRIGGER update_topic_visits_num_comments_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed))
EXECUTE PROCEDURE update_all_topic_visit_num_comments();
""")
"""
)
# topics
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_num_comments() RETURNS TRIGGER AS $$
BEGIN
IF (TG_OP = 'INSERT') THEN
@ -140,18 +155,22 @@ def upgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_num_comments_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_num_comments_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed))
EXECUTE PROCEDURE update_topics_num_comments();
""")
"""
)
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_last_activity_time() RETURNS TRIGGER AS $$
DECLARE
most_recent_comment RECORD;
@ -182,31 +201,37 @@ def upgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_last_activity_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_last_activity_time_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed))
EXECUTE PROCEDURE update_topics_last_activity_time();
""")
"""
)
def downgrade():
# comment_notifications
op.execute("DROP TRIGGER delete_comment_notifications_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_notifications_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted = false AND NEW.is_deleted = true)
EXECUTE PROCEDURE delete_comment_notifications();
""")
"""
)
# comments
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION set_comment_deleted_time() RETURNS TRIGGER AS $$
BEGIN
NEW.deleted_time := current_timestamp;
@ -214,15 +239,18 @@ def downgrade():
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER delete_comment_set_deleted_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_set_deleted_time_update
BEFORE UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted = false AND NEW.is_deleted = true)
EXECUTE PROCEDURE set_comment_deleted_time();
""")
"""
)
op.execute("DROP TRIGGER remove_comment_set_removed_time_update ON comments")
op.execute("DROP FUNCTION set_comment_removed_time()")
@ -230,7 +258,8 @@ def downgrade():
# topic_visits
op.execute("DROP TRIGGER update_topic_visits_num_comments_update ON comments")
op.execute("DROP FUNCTION update_all_topic_visit_num_comments()")
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION decrement_all_topic_visit_num_comments() RETURNS TRIGGER AS $$
BEGIN
UPDATE topic_visits
@ -241,17 +270,21 @@ def downgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
op.execute("""
"""
)
op.execute(
"""
CREATE TRIGGER update_topic_visits_num_comments_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted = false AND NEW.is_deleted = true)
EXECUTE PROCEDURE decrement_all_topic_visit_num_comments();
""")
"""
)
# topics
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_num_comments() RETURNS TRIGGER AS $$
BEGIN
IF (TG_OP = 'INSERT' AND NEW.is_deleted = FALSE) THEN
@ -277,17 +310,21 @@ def downgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_num_comments_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_num_comments_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
EXECUTE PROCEDURE update_topics_num_comments();
""")
"""
)
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_last_activity_time() RETURNS TRIGGER AS $$
DECLARE
most_recent_comment RECORD;
@ -317,12 +354,15 @@ def downgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_last_activity_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_last_activity_time_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
EXECUTE PROCEDURE update_topics_last_activity_time();
""")
"""
)

tildes/consumers/comment_user_mentions_generator.py (21 lines changed)

@ -13,7 +13,7 @@ class CommentUserMentionGenerator(PgsqlQueueConsumer):
"""Process a delivered message."""
comment = (
self.db_session.query(Comment)
.filter_by(comment_id=msg.body['comment_id'])
.filter_by(comment_id=msg.body["comment_id"])
.one()
)
@ -22,15 +22,16 @@ class CommentUserMentionGenerator(PgsqlQueueConsumer):
return
new_mentions = CommentNotification.get_mentions_for_comment(
self.db_session, comment)
self.db_session, comment
)
if msg.delivery_info['routing_key'] == 'comment.created':
if msg.delivery_info["routing_key"] == "comment.created":
for user_mention in new_mentions:
self.db_session.add(user_mention)
elif msg.delivery_info['routing_key'] == 'comment.edited':
to_delete, to_add = (
CommentNotification.prevent_duplicate_notifications(
self.db_session, comment, new_mentions))
elif msg.delivery_info["routing_key"] == "comment.edited":
to_delete, to_add = CommentNotification.prevent_duplicate_notifications(
self.db_session, comment, new_mentions
)
for user_mention in to_delete:
self.db_session.delete(user_mention)
@ -39,8 +40,8 @@ class CommentUserMentionGenerator(PgsqlQueueConsumer):
self.db_session.add(user_mention)
if __name__ == '__main__':
if __name__ == "__main__":
CommentUserMentionGenerator(
queue_name='comment_user_mentions_generator.q',
routing_keys=['comment.created', 'comment.edited'],
queue_name="comment_user_mentions_generator.q",
routing_keys=["comment.created", "comment.edited"],
).consume_queue()

tildes/consumers/topic_metadata_generator.py (29 lines changed)

@ -26,9 +26,7 @@ class TopicMetadataGenerator(PgsqlQueueConsumer):
def run(self, msg: Message) -> None:
"""Process a delivered message."""
topic = (
self.db_session.query(Topic)
.filter_by(topic_id=msg.body['topic_id'])
.one()
self.db_session.query(Topic).filter_by(topic_id=msg.body["topic_id"]).one()
)
if topic.is_text_type:
@ -42,22 +40,19 @@ class TopicMetadataGenerator(PgsqlQueueConsumer):
html_tree = HTMLParser().parseFragment(topic.rendered_html)
# extract the text from all of the HTML elements
extracted_text = ''.join(
[element_text for element_text in html_tree.itertext()])
extracted_text = "".join(
[element_text for element_text in html_tree.itertext()]
)
# sanitize unicode, remove leading/trailing whitespace, etc.
extracted_text = simplify_string(extracted_text)
# create a short excerpt by truncating the simplified string
excerpt = truncate_string(
extracted_text,
length=200,
truncate_at_chars=' ',
)
excerpt = truncate_string(extracted_text, length=200, truncate_at_chars=" ")
topic.content_metadata = {
'word_count': word_count(extracted_text),
'excerpt': excerpt,
"word_count": word_count(extracted_text),
"excerpt": excerpt,
}
def _generate_link_metadata(self, topic: Topic) -> None:
@ -68,13 +63,11 @@ class TopicMetadataGenerator(PgsqlQueueConsumer):
parsed_domain = get_domain_from_url(topic.link)
domain = self.public_suffix_list.get_public_suffix(parsed_domain)
topic.content_metadata = {
'domain': domain,
}
topic.content_metadata = {"domain": domain}
if __name__ == '__main__':
if __name__ == "__main__":
TopicMetadataGenerator(
queue_name='topic_metadata_generator.q',
routing_keys=['topic.created', 'topic.edited'],
queue_name="topic_metadata_generator.q",
routing_keys=["topic.created", "topic.edited"],
).consume_queue()

tildes/scripts/breached_passwords.py (57 lines changed)

@ -46,11 +46,11 @@ def generate_redis_protocol(*elements: Any) -> str:
Based on the example Ruby code from
https://redis.io/topics/mass-insert#generating-redis-protocol
"""
command = f'*{len(elements)}\r\n'
command = f"*{len(elements)}\r\n"
for element in elements:
element = str(element)
command += f'${len(element)}\r\n{element}\r\n'
command += f"${len(element)}\r\n{element}\r\n"
return command
@ -65,27 +65,27 @@ def validate_init_error_rate(ctx: Any, param: Any, value: Any) -> float:
"""Validate the --error-rate arg for the init command."""
# pylint: disable=unused-argument
if not 0 < value < 1:
raise click.BadParameter('error rate must be a float between 0 and 1')
raise click.BadParameter("error rate must be a float between 0 and 1")
return value
@cli.command(help='Initialize a new empty bloom filter')
@cli.command(help="Initialize a new empty bloom filter")
@click.option(
'--estimate',
"--estimate",
required=True,
type=int,
help='Expected number of passwords that will be added',
help="Expected number of passwords that will be added",
)
@click.option(
'--error-rate',
"--error-rate",
default=0.01,
show_default=True,
help='Bloom filter desired false positive ratio',
help="Bloom filter desired false positive ratio",
callback=validate_init_error_rate,
)
@click.confirmation_option(
prompt='Are you sure you want to clear any existing bloom filter?',
prompt="Are you sure you want to clear any existing bloom filter?"
)
def init(estimate: int, error_rate: float) -> None:
"""Initialize a new bloom filter (destroying any existing one).
@ -102,22 +102,16 @@ def init(estimate: int, error_rate: float) -> None:
REDIS.delete(BREACHED_PASSWORDS_BF_KEY)
# BF.RESERVE {key} {error_rate} {size}
REDIS.execute_command(
'BF.RESERVE',
BREACHED_PASSWORDS_BF_KEY,
error_rate,
estimate,
)
REDIS.execute_command("BF.RESERVE", BREACHED_PASSWORDS_BF_KEY, error_rate, estimate)
click.echo(
'Initialized bloom filter with expected size of {:,} and false '
'positive rate of {}%'
.format(estimate, error_rate * 100)
"Initialized bloom filter with expected size of {:,} and false "
"positive rate of {}%".format(estimate, error_rate * 100)
)
@cli.command(help='Add hashes from a file to the bloom filter')
@click.argument('filename', type=click.Path(exists=True, dir_okay=False))
@cli.command(help="Add hashes from a file to the bloom filter")
@click.argument("filename", type=click.Path(exists=True, dir_okay=False))
def addhashes(filename: str) -> None:
"""Add all hashes from a file to the bloom filter.
@ -127,26 +121,26 @@ def addhashes(filename: str) -> None:
"""
# make sure the key exists and is a bloom filter
try:
REDIS.execute_command('BF.DEBUG', BREACHED_PASSWORDS_BF_KEY)
REDIS.execute_command("BF.DEBUG", BREACHED_PASSWORDS_BF_KEY)
except ResponseError:
click.echo('Bloom filter is not set up properly - run init first.')
click.echo("Bloom filter is not set up properly - run init first.")
raise click.Abort
# call wc to count the number of lines in the file for the progress bar
click.echo('Determining hash count...')
result = subprocess.run(['wc', '-l', filename], stdout=subprocess.PIPE)
line_count = int(result.stdout.split(b' ')[0])
click.echo("Determining hash count...")
result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
line_count = int(result.stdout.split(b" ")[0])
progress_bar: Any = click.progressbar(length=line_count)
update_interval = 100_000
click.echo('Adding {:,} hashes to bloom filter...'.format(line_count))
click.echo("Adding {:,} hashes to bloom filter...".format(line_count))
redis_pipe = subprocess.Popen(
['redis-cli', '-s', BREACHED_PASSWORDS_REDIS_SOCKET, '--pipe'],
["redis-cli", "-s", BREACHED_PASSWORDS_REDIS_SOCKET, "--pipe"],
stdin=subprocess.PIPE,
stdout=subprocess.DEVNULL,
encoding='utf-8',
encoding="utf-8",
)
for count, line in enumerate(open(filename), start=1):
@ -155,10 +149,9 @@ def addhashes(filename: str) -> None:
# the Pwned Passwords hash lists now have a frequency count for each
# hash, which is separated from the hash with a colon, so we need to
# handle that if it's present
hashval = hashval.split(':')[0]
hashval = hashval.split(":")[0]
command = generate_redis_protocol(
'BF.ADD', BREACHED_PASSWORDS_BF_KEY, hashval)
command = generate_redis_protocol("BF.ADD", BREACHED_PASSWORDS_BF_KEY, hashval)
redis_pipe.stdin.write(command)
if count % update_interval == 0:
@ -173,5 +166,5 @@ def addhashes(filename: str) -> None:
progress_bar.render_finish()
if __name__ == '__main__':
if __name__ == "__main__":
cli()

tildes/scripts/clean_private_data.py (51 lines changed)

@ -33,22 +33,17 @@ def clean_all_data(config_path: str) -> None:
cleaner.clean_all()
class DataCleaner():
class DataCleaner:
"""Container class for all methods related to cleaning up old data."""
def __init__(
self,
db_session: Session,
retention_period: timedelta,
) -> None:
def __init__(self, db_session: Session, retention_period: timedelta) -> None:
"""Create a new DataCleaner."""
self.db_session = db_session
self.retention_cutoff = datetime.now() - retention_period
def clean_all(self) -> None:
"""Call all the cleanup functions."""
logging.info(
f'Cleaning up all data (retention cutoff {self.retention_cutoff})')
logging.info(f"Cleaning up all data (retention cutoff {self.retention_cutoff})")
self.delete_old_log_entries()
self.delete_old_topic_visits()
@ -68,7 +63,7 @@ class DataCleaner():
.delete(synchronize_session=False)
)
self.db_session.commit()
logging.info(f'Deleted {deleted} old log entries.')
logging.info(f"Deleted {deleted} old log entries.")
def delete_old_topic_visits(self) -> None:
"""Delete all topic visits older than the retention cutoff."""
@ -78,7 +73,7 @@ class DataCleaner():
.delete(synchronize_session=False)
)
self.db_session.commit()
logging.info(f'Deleted {deleted} old topic visits.')
logging.info(f"Deleted {deleted} old topic visits.")
def clean_old_deleted_comments(self) -> None:
"""Clean the data of old deleted comments.
@ -92,14 +87,13 @@ class DataCleaner():
Comment.deleted_time <= self.retention_cutoff, # type: ignore
Comment.user_id != 0,
)
.update({
'user_id': 0,
'markdown': '',
'rendered_html': '',
}, synchronize_session=False)
.update(
{"user_id": 0, "markdown": "", "rendered_html": ""},
synchronize_session=False,
)
)
self.db_session.commit()
logging.info(f'Cleaned {updated} old deleted comments.')
logging.info(f"Cleaned {updated} old deleted comments.")
def clean_old_deleted_topics(self) -> None:
"""Clean the data of old deleted topics.
@ -113,16 +107,19 @@ class DataCleaner():
Topic.deleted_time <= self.retention_cutoff, # type: ignore
Topic.user_id != 0,
)
.update({
'user_id': 0,
'title': '',
'topic_type': 'TEXT',
'markdown': None,
'rendered_html': None,
'link': None,
'content_metadata': None,
'_tags': [],
}, synchronize_session=False)
.update(
{
"user_id": 0,
"title": "",
"topic_type": "TEXT",
"markdown": None,
"rendered_html": None,
"link": None,
"content_metadata": None,
"_tags": [],
},
synchronize_session=False,
)
)
self.db_session.commit()
logging.info(f'Cleaned {updated} old deleted topics.')
logging.info(f"Cleaned {updated} old deleted topics.")

tildes/scripts/initialize_db.py (50 lines changed)

@ -15,27 +15,24 @@ from tildes.models.log import Log
from tildes.models.user import User
def initialize_db(
config_path: str,
alembic_config_path: Optional[str] = None,
) -> None:
def initialize_db(config_path: str, alembic_config_path: Optional[str] = None) -> None:
"""Load the app config and create the database tables."""
db_session = get_session_from_config(config_path)
engine = db_session.bind
create_tables(engine)
run_sql_scripts_in_dir('sql/init/', engine)
run_sql_scripts_in_dir("sql/init/", engine)
# if an Alembic config file wasn't specified, assume it's alembic.ini in
# the same directory
if not alembic_config_path:
path = os.path.split(config_path)[0]
alembic_config_path = os.path.join(path, 'alembic.ini')
alembic_config_path = os.path.join(path, "alembic.ini")
# mark current Alembic revision in db so migrations start from this point
alembic_cfg = Config(alembic_config_path)
command.stamp(alembic_cfg, 'head')
command.stamp(alembic_cfg, "head")
def create_tables(connectable: Connectable) -> None:
@ -44,7 +41,8 @@ def create_tables(connectable: Connectable) -> None:
excluded_tables = Log.INHERITED_TABLES
tables = [
table for table in DatabaseModel.metadata.tables.values()
table
for table in DatabaseModel.metadata.tables.values()
if table.name not in excluded_tables
]
DatabaseModel.metadata.create_all(connectable, tables=tables)
@ -53,29 +51,31 @@ def create_tables(connectable: Connectable) -> None:
def run_sql_scripts_in_dir(path: str, engine: Engine) -> None:
"""Run all sql scripts in a directory."""
for root, _, files in os.walk(path):
sql_files = [
filename for filename in files
if filename.endswith('.sql')
]
sql_files = [filename for filename in files if filename.endswith(".sql")]
for sql_file in sql_files:
subprocess.call([
'psql',
'-U', engine.url.username,
'-f', os.path.join(root, sql_file),
engine.url.database,
])
subprocess.call(
[
"psql",
"-U",
engine.url.username,
"-f",
os.path.join(root, sql_file),
engine.url.database,
]
)
def insert_dev_data(config_path: str) -> None:
"""Load the app config and insert some "starter" data for a dev version."""
session = get_session_from_config(config_path)
session.add_all([
User('TestUser', 'password'),
Group(
'testing',
'An automatically created group to use for testing purposes',
),
])
session.add_all(
[
User("TestUser", "password"),
Group(
"testing", "An automatically created group to use for testing purposes"
),
]
)
session.commit()

tildes/setup.py (6 lines changed)

@ -4,11 +4,11 @@ from setuptools import find_packages, setup
setup(
name='tildes',
version='0.1',
name="tildes",
version="0.1",
packages=find_packages(),
entry_points="""
[paste.app_factory]
main = tildes:main
"""
""",
)

tildes/tests/conftest.py (71 lines changed)

@ -18,7 +18,7 @@ from tildes.models.user import User
# include the fixtures defined in fixtures.py
pytest_plugins = ['tests.fixtures']
pytest_plugins = ["tests.fixtures"]
class NestedSessionWrapper(Session):
@ -40,25 +40,25 @@ class NestedSessionWrapper(Session):
super().rollback()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def pyramid_config():
"""Set up the Pyramid environment."""
settings = get_appsettings('development.ini')
settings = get_appsettings("development.ini")
config = testing.setUp(settings=settings)
config.include('tildes.auth')
config.include("tildes.auth")
yield config
testing.tearDown()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def overall_db_session(pyramid_config):
"""Handle setup and teardown of test database for testing session."""
# read the database url from the pyramid INI file, and replace the db name
sqlalchemy_url = pyramid_config.registry.settings['sqlalchemy.url']
sqlalchemy_url = pyramid_config.registry.settings["sqlalchemy.url"]
parsed_url = make_url(sqlalchemy_url)
parsed_url.database = 'tildes_test'
parsed_url.database = "tildes_test"
engine = create_engine(parsed_url)
session_factory = sessionmaker(bind=engine)
@ -69,12 +69,9 @@ def overall_db_session(pyramid_config):
# SQL init scripts need to be executed "manually" instead of using psql
# like the normal database init process does, since the tables only exist
# inside this transaction
init_scripts_dir = 'sql/init/'
init_scripts_dir = "sql/init/"
for root, _, files in os.walk(init_scripts_dir):
sql_files = [
filename for filename in files
if filename.endswith('.sql')
]
sql_files = [filename for filename in files if filename.endswith(".sql")]
for sql_file in sql_files:
with open(os.path.join(root, sql_file)) as current_file:
session.execute(current_file.read())
@ -90,7 +87,7 @@ def overall_db_session(pyramid_config):
session.rollback()
@fixture(scope='session')
@fixture(scope="session")
def sdb(overall_db_session):
"""Testing-session-level db session with a nested transaction."""
overall_db_session.begin_nested()
@ -100,7 +97,7 @@ def sdb(overall_db_session):
overall_db_session.rollback_all_nested()
@fixture(scope='function')
@fixture(scope="function")
def db(overall_db_session):
"""Function-level db session with a nested transaction."""
overall_db_session.begin_nested()
@ -110,25 +107,23 @@ def db(overall_db_session):
overall_db_session.rollback_all_nested()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def overall_redis_session():
"""Create a session-level connection to a temporary redis server."""
# list of redis modules that need to be loaded (would be much nicer to do
# this automatically somehow, maybe reading from the real redis.conf?)
redis_modules = [
'/opt/redis-cell/libredis_cell.so',
]
redis_modules = ["/opt/redis-cell/libredis_cell.so"]
with RedisServer() as temp_redis_server:
redis = StrictRedis(**temp_redis_server.dsn())
for module in redis_modules:
redis.execute_command('MODULE LOAD', module)
redis.execute_command("MODULE LOAD", module)
yield redis
@fixture(scope='function')
@fixture(scope="function")
def redis(overall_redis_session):
"""Create a function-level redis connection that wipes the db after use."""
yield overall_redis_session
@ -136,47 +131,47 @@ def redis(overall_redis_session):
overall_redis_session.flushdb()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def session_user(sdb):
"""Create a user named 'SessionUser' in the db for test session."""
# note that some tests may depend on this username/password having these
# specific values, so make sure to search for and update those tests if you
# change the username or password for any reason
user = User('SessionUser', 'session user password')
user = User("SessionUser", "session user password")
sdb.add(user)
sdb.commit()
yield user
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def session_user2(sdb):
"""Create a second user named 'OtherUser' in the db for test session.
This is useful for cases where two different users are needed, such as
when testing private messages.
"""
user = User('OtherUser', 'other user password')
user = User("OtherUser", "other user password")
sdb.add(user)
sdb.commit()
yield user
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def session_group(sdb):
"""Create a group named 'sessiongroup' in the db for test session."""
group = Group('sessiongroup')
group = Group("sessiongroup")
sdb.add(group)
sdb.commit()
yield group
@fixture(scope='session')
@fixture(scope="session")
def base_app(overall_redis_session, sdb):
"""Configure a base WSGI app that webtest can create TestApps based on."""
testing_app = get_app('development.ini')
testing_app = get_app("development.ini")
# replace the redis connection used by the redis-sessions library with a
# connection to the temporary server for this test session
@ -185,38 +180,38 @@ def base_app(overall_redis_session, sdb):
def redis_factory(request):
# pylint: disable=unused-argument
return overall_redis_session
testing_app.app.registry['redis_connection_factory'] = redis_factory
testing_app.app.registry["redis_connection_factory"] = redis_factory
# replace the session factory function with one that will return the
# testing db session (inside a nested transaction)
def session_factory():
return sdb
testing_app.app.registry['db_session_factory'] = session_factory
testing_app.app.registry["db_session_factory"] = session_factory
yield testing_app
@fixture(scope='session')
@fixture(scope="session")
def webtest(base_app):
"""Create a webtest TestApp and log in as the SessionUser account in it."""
# create the TestApp - note that specifying wsgi.url_scheme is necessary
# so that the secure cookies from the session library will work
app = TestApp(
base_app,
extra_environ={'wsgi.url_scheme': 'https'},
cookiejar=CookieJar(),
base_app, extra_environ={"wsgi.url_scheme": "https"}, cookiejar=CookieJar()
)
# fetch the login page, fill in the form, and submit it (sets the cookie)
login_page = app.get('/login')
login_page.form['username'] = 'SessionUser'
login_page.form['password'] = 'session user password'
login_page = app.get("/login")
login_page.form["username"] = "SessionUser"
login_page.form["password"] = "session user password"
login_page.form.submit()
yield app
@fixture(scope='session')
@fixture(scope="session")
def webtest_loggedout(base_app):
"""Create a logged-out webtest TestApp (no cookies retained)."""
yield TestApp(base_app)

tildes/tests/fixtures.py (8 lines changed)

@ -8,7 +8,8 @@ from tildes.models.topic import Topic
def text_topic(db, session_group, session_user):
"""Create a text topic, delete it as teardown (including comments)."""
new_topic = Topic.create_text_topic(
session_group, session_user, 'A Text Topic', 'the text')
session_group, session_user, "A Text Topic", "the text"
)
db.add(new_topic)
db.commit()
@ -23,7 +24,8 @@ def text_topic(db, session_group, session_user):
def link_topic(db, session_group, session_user):
"""Create a link topic, delete it as teardown (including comments)."""
new_topic = Topic.create_link_topic(
session_group, session_user, 'A Link Topic', 'http://example.com')
session_group, session_user, "A Link Topic", "http://example.com"
)
db.add(new_topic)
db.commit()
@ -43,7 +45,7 @@ def topic(text_topic):
@fixture
def comment(db, session_user, topic):
"""Create a comment in the database, delete it as teardown."""
new_comment = Comment(topic, session_user, 'A comment')
new_comment = Comment(topic, session_user, "A comment")
db.add(new_comment)
db.commit()

tildes/tests/test_comment.py (64 lines changed)

@ -1,81 +1,73 @@
from datetime import timedelta
from freezegun import freeze_time
from pyramid.security import (
Authenticated,
Everyone,
principals_allowed_by_permission,
)
from pyramid.security import Authenticated, Everyone, principals_allowed_by_permission
from tildes.enums import CommentSortOption
from tildes.lib.datetime import utc_now
from tildes.models.comment import (
Comment,
CommentTree,
EDIT_GRACE_PERIOD,
)
from tildes.models.comment import Comment, CommentTree, EDIT_GRACE_PERIOD
from tildes.schemas.comment import CommentSchema
from tildes.schemas.fields import Markdown
def test_comment_creation_validates_schema(mocker, session_user, topic):
"""Ensure that comment creation goes through schema validation."""
mocker.spy(CommentSchema, 'load')
mocker.spy(CommentSchema, "load")
Comment(topic, session_user, 'A test comment')
Comment(topic, session_user, "A test comment")
call_args = CommentSchema.load.call_args[0]
assert {'markdown': 'A test comment'} in call_args
assert {"markdown": "A test comment"} in call_args
def test_comment_creation_uses_markdown_field(mocker, session_user, topic):
"""Ensure the Markdown field class is validating new comments."""
mocker.spy(Markdown, '_validate')
mocker.spy(Markdown, "_validate")
Comment(topic, session_user, 'A test comment')
Comment(topic, session_user, "A test comment")
assert Markdown._validate.called
def test_comment_edit_uses_markdown_field(mocker, comment):
"""Ensure editing a comment is validated by the Markdown field class."""
mocker.spy(Markdown, '_validate')
mocker.spy(Markdown, "_validate")
comment.markdown = 'Some new text after edit'
comment.markdown = "Some new text after edit"
assert Markdown._validate.called
def test_edit_markdown_updates_html(comment):
"""Ensure editing a comment works and the markdown and HTML update."""
comment.markdown = 'Updated comment'
assert 'Updated' in comment.markdown
assert 'Updated' in comment.rendered_html
comment.markdown = "Updated comment"
assert "Updated" in comment.markdown
assert "Updated" in comment.rendered_html
def test_comment_viewing_permission(comment):
"""Ensure that anyone can view a comment by default."""
assert Everyone in principals_allowed_by_permission(comment, 'view')
assert Everyone in principals_allowed_by_permission(comment, "view")
def test_comment_editing_permission(comment):
"""Ensure that only the comment's author can edit it."""
principals = principals_allowed_by_permission(comment, 'edit')
principals = principals_allowed_by_permission(comment, "edit")
assert principals == {comment.user_id}
def test_comment_deleting_permission(comment):
"""Ensure that only the comment's author can delete it."""
principals = principals_allowed_by_permission(comment, 'delete')
principals = principals_allowed_by_permission(comment, "delete")
assert principals == {comment.user_id}
def test_comment_replying_permission(comment):
"""Ensure that any authenticated user can reply to a comment."""
assert Authenticated in principals_allowed_by_permission(comment, 'reply')
assert Authenticated in principals_allowed_by_permission(comment, "reply")
def test_comment_reply_locked_thread_permission(comment):
"""Ensure that only admins can reply in locked threads."""
comment.topic.is_locked = True
assert principals_allowed_by_permission(comment, 'reply') == {'admin'}
assert principals_allowed_by_permission(comment, "reply") == {"admin"}
def test_deleted_comment_permissions_removed(comment):
@ -90,8 +82,8 @@ def test_deleted_comment_permissions_removed(comment):
def test_removed_comment_view_permission(comment):
"""Ensure a removed comment can only be viewed by its author and admins."""
comment.is_removed = True
principals = principals_allowed_by_permission(comment, 'view')
assert principals == {'admin', comment.user_id}
principals = principals_allowed_by_permission(comment, "view")
assert principals == {"admin", comment.user_id}
def test_edit_grace_period(comment):
@ -100,7 +92,7 @@ def test_edit_grace_period(comment):
edit_time = comment.created_time + EDIT_GRACE_PERIOD - one_sec
with freeze_time(edit_time):
comment.markdown = 'some new markdown'
comment.markdown = "some new markdown"
assert not comment.last_edited_time
@ -111,7 +103,7 @@ def test_edit_after_grace_period(comment):
edit_time = comment.created_time + EDIT_GRACE_PERIOD + one_sec
with freeze_time(edit_time):
comment.markdown = 'some new markdown'
comment.markdown = "some new markdown"
assert comment.last_edited_time == utc_now()
@ -123,7 +115,7 @@ def test_multiple_edits_update_time(comment):
for minutes in range(0, 4):
edit_time = initial_time + timedelta(minutes=minutes)
with freeze_time(edit_time):
comment.markdown = f'edit #{minutes}'
comment.markdown = f"edit #{minutes}"
assert comment.last_edited_time == utc_now()
@ -134,8 +126,8 @@ def test_comment_tree(db, topic, session_user):
sort = CommentSortOption.POSTED
# add two root comments
root = Comment(topic, session_user, 'root')
root2 = Comment(topic, session_user, 'root2')
root = Comment(topic, session_user, "root")
root2 = Comment(topic, session_user, "root2")
all_comments.extend([root, root2])
db.add_all(all_comments)
db.commit()
@ -151,8 +143,8 @@ def test_comment_tree(db, topic, session_user):
assert tree == [root]
# add two replies to the remaining root comment
child = Comment(topic, session_user, '1', parent_comment=root)
child2 = Comment(topic, session_user, '2', parent_comment=root)
child = Comment(topic, session_user, "1", parent_comment=root)
child2 = Comment(topic, session_user, "2", parent_comment=root)
all_comments.extend([child, child2])
db.add_all(all_comments)
db.commit()
@ -165,8 +157,8 @@ def test_comment_tree(db, topic, session_user):
assert child2.replies == []
# add two more replies to the second depth-1 comment
subchild = Comment(topic, session_user, '2a', parent_comment=child2)
subchild2 = Comment(topic, session_user, '2b', parent_comment=child2)
subchild = Comment(topic, session_user, "2a", parent_comment=child2)
subchild2 = Comment(topic, session_user, "2b", parent_comment=child2)
all_comments.extend([subchild, subchild2])
db.add_all(all_comments)
db.commit()

tildes/tests/test_comment_user_mentions.py (72 lines changed)

@ -3,10 +3,7 @@ from pytest import fixture
from sqlalchemy import and_
from tildes.enums import CommentNotificationType
from tildes.models.comment import (
Comment,
CommentNotification,
)
from tildes.models.comment import Comment, CommentNotification
from tildes.models.topic import Topic
from tildes.models.user import User
@ -15,8 +12,8 @@ from tildes.models.user import User
def user_list(db):
"""Create several users."""
users = []
for name in ['foo', 'bar', 'baz']:
user = User(name, 'password')
for name in ["foo", "bar", "baz"]:
user = User(name, "password")
users.append(user)
db.add(user)
db.commit()
@ -30,44 +27,40 @@ def user_list(db):
def test_get_mentions_for_comment(db, user_list, comment):
"""Test that notifications are generated and returned."""
comment.markdown = '@foo @bar. @baz!'
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
comment.markdown = "@foo @bar. @baz!"
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 3
for index, user in enumerate(user_list):
assert mentions[index].user == user
def test_mention_filtering_parent_comment(
mocker, db, topic, user_list):
def test_mention_filtering_parent_comment(mocker, db, topic, user_list):
"""Test notification filtering for parent comments."""
parent_comment = Comment(topic, user_list[0], 'Comment content.')
parent_comment = Comment(topic, user_list[0], "Comment content.")
parent_comment.user_id = user_list[0].user_id
comment = mocker.Mock(
user_id=user_list[1].user_id,
markdown=f'@{user_list[0].username}',
markdown=f"@{user_list[0].username}",
parent_comment=parent_comment,
)
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert not mentions
def test_mention_filtering_self_mention(db, user_list, topic):
"""Test notification filtering for self-mentions."""
comment = Comment(topic, user_list[0], f'@{user_list[0]}')
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
comment = Comment(topic, user_list[0], f"@{user_list[0]}")
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert not mentions
def test_mention_filtering_top_level(db, user_list, session_group):
"""Test notification filtering for top-level comments."""
topic = Topic.create_text_topic(
session_group, user_list[0], 'Some title', 'some text')
comment = Comment(topic, user_list[1], f'@{user_list[0].username}')
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
session_group, user_list[0], "Some title", "some text"
)
comment = Comment(topic, user_list[1], f"@{user_list[0].username}")
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert not mentions
@ -82,36 +75,35 @@ def test_prevent_duplicate_notifications(db, user_list, topic):
4. The comment is deleted.
"""
# 1
comment = Comment(topic, user_list[0], f'@{user_list[1].username}')
comment = Comment(topic, user_list[0], f"@{user_list[1].username}")
db.add(comment)
db.commit()
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 1
assert mentions[0].user == user_list[1]
db.add_all(mentions)
db.commit()
# 2
comment.markdown = f'@{user_list[2].username}'
comment.markdown = f"@{user_list[2].username}"
db.commit()
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 1
to_delete, to_add = CommentNotification.prevent_duplicate_notifications(
db, comment, mentions)
db, comment, mentions
)
assert len(to_delete) == 1
assert mentions == to_add
assert to_delete[0].user.username == user_list[1].username
# 3
comment.markdown = f'@{user_list[1].username} @{user_list[2].username}'
comment.markdown = f"@{user_list[1].username} @{user_list[2].username}"
db.commit()
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 2
to_delete, to_add = CommentNotification.prevent_duplicate_notifications(
db, comment, mentions)
db, comment, mentions
)
assert not to_delete
assert len(to_add) == 1
@ -120,9 +112,13 @@ def test_prevent_duplicate_notifications(db, user_list, topic):
db.commit()
notifications = (
db.query(CommentNotification.user_id)
.filter(and_(
CommentNotification.comment_id == comment.comment_id,
CommentNotification.notification_type ==
CommentNotificationType.USER_MENTION,
)).all())
.filter(
and_(
CommentNotification.comment_id == comment.comment_id,
CommentNotification.notification_type
== CommentNotificationType.USER_MENTION,
)
)
.all()
)
assert not notifications

tildes/tests/test_datetime.py (14 lines changed)

@ -20,40 +20,40 @@ def test_utc_now_accurate():
def test_descriptive_timedelta_basic():
"""Ensure a simple descriptive timedelta works correctly."""
test_time = utc_now() - timedelta(hours=3)
assert descriptive_timedelta(test_time) == '3 hours ago'
assert descriptive_timedelta(test_time) == "3 hours ago"
def test_more_precise_longer_descriptive_timedelta():
"""Ensure a longer time period gets the extra precision level."""
test_time = utc_now() - timedelta(days=2, hours=5)
assert descriptive_timedelta(test_time) == '2 days, 5 hours ago'
assert descriptive_timedelta(test_time) == "2 days, 5 hours ago"
def test_no_small_precision_descriptive_timedelta():
"""Ensure the extra precision doesn't apply to small units."""
test_time = utc_now() - timedelta(days=6, minutes=10)
assert descriptive_timedelta(test_time) == '6 days ago'
assert descriptive_timedelta(test_time) == "6 days ago"
def test_single_precision_below_an_hour():
"""Ensure times under an hour only have one precision level."""
test_time = utc_now() - timedelta(minutes=59, seconds=59)
assert descriptive_timedelta(test_time) == '59 minutes ago'
assert descriptive_timedelta(test_time) == "59 minutes ago"
def test_more_precision_above_an_hour():
"""Ensure the second precision level gets added just above an hour."""
test_time = utc_now() - timedelta(hours=1, minutes=1)
assert descriptive_timedelta(test_time) == '1 hour, 1 minute ago'
assert descriptive_timedelta(test_time) == "1 hour, 1 minute ago"
def test_subsecond_descriptive_timedelta():
"""Ensure time less than a second returns the special phrase."""
test_time = utc_now() - timedelta(microseconds=100)
assert descriptive_timedelta(test_time) == 'a moment ago'
assert descriptive_timedelta(test_time) == "a moment ago"
def test_above_second_descriptive_timedelta():
"""Ensure it starts describing time in seconds above 1 second."""
test_time = utc_now() - timedelta(seconds=1, microseconds=100)
assert descriptive_timedelta(test_time) == '1 second ago'
assert descriptive_timedelta(test_time) == "1 second ago"

tildes/tests/test_group.py (41 lines changed)

@ -3,46 +3,43 @@ from sqlalchemy.exc import IntegrityError
from tildes.models.group import Group
from tildes.schemas.fields import Ltree, SimpleString
from tildes.schemas.group import (
GroupSchema,
is_valid_group_path,
)
from tildes.schemas.group import GroupSchema, is_valid_group_path
def test_empty_path_invalid():
"""Ensure empty group path is invalid."""
assert not is_valid_group_path('')
assert not is_valid_group_path("")
def test_typical_path_valid():
"""Ensure a "normal-looking" group path is valid."""
assert is_valid_group_path('games.video.nintendo_3ds')
assert is_valid_group_path("games.video.nintendo_3ds")
def test_start_with_underscore():
"""Ensure you can't start a path with an underscore."""
assert not is_valid_group_path('_x.y.z')
assert not is_valid_group_path("_x.y.z")
def test_middle_element_start_with_underscore():
"""Ensure a middle path element can't start with an underscore."""
assert not is_valid_group_path('x._y.z')
assert not is_valid_group_path("x._y.z")
def test_end_with_underscore():
"""Ensure you can't end a path with an underscore."""
assert not is_valid_group_path('x.y.z_')
assert not is_valid_group_path("x.y.z_")
def test_middle_element_end_with_underscore():
"""Ensure a middle path element can't end with an underscore."""
assert not is_valid_group_path('x.y_.z')
assert not is_valid_group_path("x.y_.z")
def test_uppercase_letters_invalid():
"""Ensure a group path can't contain uppercase chars."""
assert is_valid_group_path('comp.lang.c')
assert not is_valid_group_path('comp.lang.C')
assert is_valid_group_path("comp.lang.c")
assert not is_valid_group_path("comp.lang.C")
def test_paths_with_invalid_characters():
@ -50,34 +47,34 @@ def test_paths_with_invalid_characters():
invalid_chars = ' ~!@#$%^&*()+={}[]|\\:;"<>,?/'
for char in invalid_chars:
path = f'abc{char}xyz'
path = f"abc{char}xyz"
assert not is_valid_group_path(path)
def test_paths_with_unicode_characters():
"""Ensure that paths can't use unicode chars (not comprehensive)."""
for path in ('games.pokémon', 'ポケモン', 'bites.møøse'):
for path in ("games.pokémon", "ポケモン", "bites.møøse"):
assert not is_valid_group_path(path)
def test_creation_validates_schema(mocker):
"""Ensure that group creation goes through expected validation."""
mocker.spy(GroupSchema, 'load')
mocker.spy(Ltree, '_validate')
mocker.spy(SimpleString, '_validate')
mocker.spy(GroupSchema, "load")
mocker.spy(Ltree, "_validate")
mocker.spy(SimpleString, "_validate")
Group('testing', 'with a short description')
Group("testing", "with a short description")
assert GroupSchema.load.called
assert Ltree._validate.call_args[0][1] == 'testing'
assert SimpleString._validate.call_args[0][1] == 'with a short description'
assert Ltree._validate.call_args[0][1] == "testing"
assert SimpleString._validate.call_args[0][1] == "with a short description"
def test_duplicate_group(db):
"""Ensure groups with duplicate paths can't be created."""
original = Group('twins')
original = Group("twins")
db.add(original)
duplicate = Group('twins')
duplicate = Group("twins")
db.add(duplicate)
with raises(IntegrityError):

tildes/tests/test_id.py (10 lines changed)

@ -7,12 +7,12 @@ from tildes.lib.id import id_to_id36, id36_to_id
def test_id_to_id36():
"""Make sure an ID->ID36 conversion is correct."""
assert id_to_id36(571049189) == '9fzkdh'
assert id_to_id36(571049189) == "9fzkdh"
def test_id36_to_id():
"""Make sure an ID36->ID conversion is correct."""
assert id36_to_id('x48l4z') == 2002502915
assert id36_to_id("x48l4z") == 2002502915
def test_reversed_conversion_from_id():
@ -23,7 +23,7 @@ def test_reversed_conversion_from_id():
def test_reversed_conversion_from_id36():
"""Make sure an ID36->ID->ID36 conversion returns to original value."""
original = 'h2l4pe'
original = "h2l4pe"
assert id_to_id36(id36_to_id(original)) == original
@ -36,7 +36,7 @@ def test_zero_id_conversion_blocked():
def test_zero_id36_conversion_blocked():
"""Ensure the ID36 conversion function doesn't accept zero."""
with raises(ValueError):
id36_to_id('0')
id36_to_id("0")
def test_negative_id_conversion_blocked():
@ -48,4 +48,4 @@ def test_negative_id_conversion_blocked():
def test_negative_id36_conversion_blocked():
"""Ensure the ID36 conversion function doesn't accept negative numbers."""
with raises(ValueError):
id36_to_id('-1')
id36_to_id("-1")

tildes/tests/test_markdown.py (167 lines changed)

@ -3,29 +3,29 @@ from tildes.lib.markdown import convert_markdown_to_safe_html
def test_script_tag_escaped():
"""Ensure that a <script> tag can't get through."""
markdown = '<script>alert()</script>'
markdown = "<script>alert()</script>"
sanitized = convert_markdown_to_safe_html(markdown)
assert '<script>' not in sanitized
assert "<script>" not in sanitized
def test_basic_markdown_unescaped():
"""Test that some common markdown comes through without escaping."""
markdown = (
"# Here's a header.\n\n"
'This chunk of text has **some bold** and *some italics* in it.\n\n'
'A separator will be below this paragraph.\n\n'
'---\n\n'
'* An unordered list item\n'
'* Another list item\n\n'
'> This should be a quote.\n\n'
' And a code block\n\n'
'Also some `inline code` and [a link](http://example.com).\n\n'
'And a manual break \nbetween lines.\n\n'
"This chunk of text has **some bold** and *some italics* in it.\n\n"
"A separator will be below this paragraph.\n\n"
"---\n\n"
"* An unordered list item\n"
"* Another list item\n\n"
"> This should be a quote.\n\n"
" And a code block\n\n"
"Also some `inline code` and [a link](http://example.com).\n\n"
"And a manual break \nbetween lines.\n\n"
)
sanitized = convert_markdown_to_safe_html(markdown)
assert '&lt;' not in sanitized
assert "&lt;" not in sanitized
def test_strikethrough():
@ -33,23 +33,23 @@ def test_strikethrough():
markdown = "This ~should not~ should work"
processed = convert_markdown_to_safe_html(markdown)
assert '<del>' in processed
assert '<a' not in processed
assert "<del>" in processed
assert "<a" not in processed
def test_table():
"""Ensure table markdown works."""
markdown = (
'|Header 1|Header 2|Header 3|\n'
'|--------|-------:|:------:|\n'
'|1 - 1 |1 - 2 |1 - 3 |\n'
'|2 - 1|2 - 2|2 - 3|\n'
"|Header 1|Header 2|Header 3|\n"
"|--------|-------:|:------:|\n"
"|1 - 1 |1 - 2 |1 - 3 |\n"
"|2 - 1|2 - 2|2 - 3|\n"
)
processed = convert_markdown_to_safe_html(markdown)
assert '<table>' in processed
assert processed.count('<tr') == 3
assert processed.count('<td') == 6
assert "<table>" in processed
assert processed.count("<tr") == 3
assert processed.count("<td") == 6
assert 'align="right"' in processed
assert 'align="center"' in processed
@ -57,35 +57,35 @@ def test_table():
def test_deliberate_ordered_list():
"""Ensure a "deliberate" ordered list works."""
markdown = (
'My first line of text.\n\n'
'1. I want\n'
'2. An ordered\n'
'3. List here\n\n'
'A final line.'
"My first line of text.\n\n"
"1. I want\n"
"2. An ordered\n"
"3. List here\n\n"
"A final line."
)
html = convert_markdown_to_safe_html(markdown)
assert '<ol>' in html
assert "<ol>" in html
def test_accidental_ordered_list():
"""Ensure a common "accidental" ordered list gets escaped."""
markdown = (
'What year did this happen?\n\n'
'1975. It was a long time ago.\n\n'
'But I remember it like it was yesterday.'
"What year did this happen?\n\n"
"1975. It was a long time ago.\n\n"
"But I remember it like it was yesterday."
)
html = convert_markdown_to_safe_html(markdown)
assert '<ol' not in html
assert "<ol" not in html
def test_existing_newline_not_doubled():
"""Ensure that the standard markdown line break doesn't result in two."""
markdown = 'A deliberate line \nbreak'
markdown = "A deliberate line \nbreak"
html = convert_markdown_to_safe_html(markdown)
assert html.count('<br') == 1
assert html.count("<br") == 1
def test_newline_creates_br():
@ -93,36 +93,31 @@ def test_newline_creates_br():
markdown = "This wouldn't\nnormally work"
html = convert_markdown_to_safe_html(markdown)
assert '<br>' in html
assert "<br>" in html
def test_multiple_newlines():
"""Ensure markdown with multiple newlines has expected result."""
lines = ["One.", "Two.", "Three.", "Four.", "Five."]
markdown = '\n'.join(lines)
markdown = "\n".join(lines)
html = convert_markdown_to_safe_html(markdown)
assert html.count('<br') == len(lines) - 1
assert html.count("<br") == len(lines) - 1
assert all(line in html for line in lines)
def test_newline_in_code_block():
"""Ensure newlines in code blocks don't add a <br>."""
markdown = (
'```\n'
'def testing_for_newlines():\n'
' pass\n'
'```\n'
)
markdown = "```\ndef testing_for_newlines():\n pass\n```\n"
html = convert_markdown_to_safe_html(markdown)
assert '<br' not in html
assert "<br" not in html
def test_http_link_linkified():
"""Ensure that writing an http url results in a link."""
markdown = 'I like http://example.com as an example.'
markdown = "I like http://example.com as an example."
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com">' in processed
@ -130,7 +125,7 @@ def test_http_link_linkified():
def test_https_link_linkified():
"""Ensure that writing an https url results in a link."""
markdown = 'Also, https://example.com should work.'
markdown = "Also, https://example.com should work."
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="https://example.com">' in processed
@ -138,7 +133,7 @@ def test_https_link_linkified():
def test_bare_domain_linkified():
"""Ensure that a bare domain results in a link."""
markdown = 'I can just write example.com too.'
markdown = "I can just write example.com too."
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com">' in processed
@ -146,7 +141,7 @@ def test_bare_domain_linkified():
def test_link_with_path_linkified():
"""Ensure a link with a path results in a link."""
markdown = 'So http://example.com/a/b_c_d/e too?'
markdown = "So http://example.com/a/b_c_d/e too?"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com/a/b_c_d/e">' in processed
@ -154,7 +149,7 @@ def test_link_with_path_linkified():
def test_link_with_query_string_linkified():
"""Ensure a link with a query string results in a link."""
markdown = 'Also http://example.com?something=true works?'
markdown = "Also http://example.com?something=true works?"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com?something=true">' in processed
@ -162,21 +157,21 @@ def test_link_with_query_string_linkified():
def test_email_address_not_linkified():
"""Ensure that an email address does not get linkified."""
markdown = 'Please contact somebody@example.com about that.'
markdown = "Please contact somebody@example.com about that."
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_other_protocol_urls_not_linkified():
"""Ensure some other protocols don't linkify (not comprehensive)."""
protocols = ('data', 'ftp', 'irc', 'mailto', 'news', 'ssh', 'xmpp')
protocols = ("data", "ftp", "irc", "mailto", "news", "ssh", "xmpp")
for protocol in protocols:
markdown = f'Testing {protocol}://example.com for linking'
markdown = f"Testing {protocol}://example.com for linking"
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_html_attr_whitelist_violation():
@ -187,23 +182,23 @@ def test_html_attr_whitelist_violation():
)
processed = convert_markdown_to_safe_html(markdown)
assert processed == '<p>test link</p>\n'
assert processed == "<p>test link</p>\n"
def test_a_href_protocol_violation():
"""Ensure link to other protocols removes the link (not comprehensive)."""
protocols = ('data', 'ftp', 'irc', 'mailto', 'news', 'ssh', 'xmpp')
protocols = ("data", "ftp", "irc", "mailto", "news", "ssh", "xmpp")
for protocol in protocols:
markdown = f'Testing [a link]({protocol}://example.com) for linking'
markdown = f"Testing [a link]({protocol}://example.com) for linking"
processed = convert_markdown_to_safe_html(markdown)
assert 'href' not in processed
assert "href" not in processed
def test_group_reference_linkified():
"""Ensure a simple group reference gets linkified."""
markdown = 'Yeah, I saw that in ~books.fantasy yesterday.'
markdown = "Yeah, I saw that in ~books.fantasy yesterday."
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/~books.fantasy">' in processed
@ -212,14 +207,14 @@ def test_group_reference_linkified():
def test_multiple_group_references_linkified():
"""Ensure multiple group references are all linkified."""
markdown = (
'I like to keep an eye on:\n\n'
'* ~music.metal\n'
'* ~music.metal.progressive\n'
'* ~music.post_rock\n'
"I like to keep an eye on:\n\n"
"* ~music.metal\n"
"* ~music.metal.progressive\n"
"* ~music.post_rock\n"
)
processed = convert_markdown_to_safe_html(markdown)
assert processed.count('<a') == 3
assert processed.count("<a") == 3
def test_invalid_group_reference_not_linkified():
@ -230,20 +225,20 @@ def test_invalid_group_reference_not_linkified():
)
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_approximately_tilde_not_linkified():
"""Ensure a tilde in front of a number doesn't linkify."""
markdown = 'Mix in ~2 cups of flour and ~1.5 tbsp of sugar.'
markdown = "Mix in ~2 cups of flour and ~1.5 tbsp of sugar."
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_uppercase_group_ref_links_correctly():
"""Ensure using uppercase in a group ref works but links correctly."""
markdown = 'That was in ~Music.Metal.Progressive'
markdown = "That was in ~Music.Metal.Progressive"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/~music.metal.progressive' in processed
@ -260,29 +255,29 @@ def test_existing_link_group_ref_not_replaced():
def test_group_ref_inside_link_not_replaced():
"""Ensure a group ref inside a longer link doesn't get re-linked."""
markdown = 'Found [this band from a ~music.punk post](http://whitelung.ca)'
markdown = "Found [this band from a ~music.punk post](http://whitelung.ca)"
processed = convert_markdown_to_safe_html(markdown)
assert processed.count('<a') == 1
assert processed.count("<a") == 1
assert 'href="/~music.punk"' not in processed
def test_group_ref_inside_pre_ignored():
"""Ensure a group ref inside a <pre> tag doesn't get linked."""
markdown = (
'```\n'
'# This is a code block\n'
'# I found this code on ~comp.lang.python\n'
'```\n'
"```\n"
"# This is a code block\n"
"# I found this code on ~comp.lang.python\n"
"```\n"
)
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_group_ref_inside_other_tags_linkified():
"""Ensure a group ref inside non-ignored tags gets linked."""
markdown = '> Here is **a ~group.reference inside** other stuff'
markdown = "> Here is **a ~group.reference inside** other stuff"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/~group.reference">' in processed
@ -290,7 +285,7 @@ def test_group_ref_inside_other_tags_linkified():
def test_username_reference_linkified():
"""Ensure a basic username reference gets linkified."""
markdown = 'Hey @SomeUser, what do you think of this?'
markdown = "Hey @SomeUser, what do you think of this?"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">@SomeUser</a>' in processed
@ -298,7 +293,7 @@ def test_username_reference_linkified():
def test_u_style_username_ref_linked():
"""Ensure a /u/username reference gets linkified."""
markdown = 'Hey /u/SomeUser, what do you think of this?'
markdown = "Hey /u/SomeUser, what do you think of this?"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">/u/SomeUser</a>' in processed
@ -306,7 +301,7 @@ def test_u_style_username_ref_linked():
def test_u_alt_style_username_ref_linked():
"""Ensure a u/username reference gets linkified."""
markdown = 'Hey u/SomeUser, what do you think of this?'
markdown = "Hey u/SomeUser, what do you think of this?"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">u/SomeUser</a>' in processed
@ -314,15 +309,15 @@ def test_u_alt_style_username_ref_linked():
def test_accidental_u_alt_style_not_linked():
"""Ensure an "accidental" u/ usage won't get linked."""
markdown = 'I think those are caribou/reindeer.'
markdown = "I think those are caribou/reindeer."
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_username_and_group_refs_linked():
"""Ensure username and group references together get linkified."""
markdown = '@SomeUser makes the best posts in ~some.group for sure'
markdown = "@SomeUser makes the best posts in ~some.group for sure"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">@SomeUser</a>' in processed
@ -334,16 +329,12 @@ def test_invalid_username_not_linkified():
markdown = "You can't register a username like @_underscores_"
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_username_ref_inside_pre_ignored():
"""Ensure a username ref inside a <pre> tag doesn't get linked."""
markdown = (
'```\n'
'# Code blatantly stolen from @HelpfulGuy on StackOverflow\n'
'```\n'
)
markdown = "```\n# Code blatantly stolen from @HelpfulGuy on StackOverflow\n```\n"
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed

22
tildes/tests/test_markdown_field.py

@ -12,53 +12,53 @@ class MarkdownFieldTestSchema(Schema):
def validate_string(string):
"""Validate a string against a standard Markdown field."""
MarkdownFieldTestSchema(strict=True).validate({'markdown': string})
MarkdownFieldTestSchema(strict=True).validate({"markdown": string})
def test_normal_text_validates():
"""Ensure some "normal-looking" markdown validates."""
validate_string(
"Here's some markdown.\n\n"
'It has **a bit of bold**, [a link](http://example.com)\n'
'> And `some code` in a blockquote'
"It has **a bit of bold**, [a link](http://example.com)\n"
"> And `some code` in a blockquote"
)
def test_changing_max_length():
"""Ensure changing the max_length argument works."""
test_string = 'Just some text to try'
test_string = "Just some text to try"
# should normally validate
assert Markdown()._validate(test_string) is None
# but fails if you set a too-short max_length
with raises(ValidationError):
Markdown(max_length=len(test_string)-1)._validate(test_string)
Markdown(max_length=len(test_string) - 1)._validate(test_string)
def test_extremely_long_string():
"""Ensure an extremely long string fails validation."""
with raises(ValidationError):
validate_string('A' * 100_000)
validate_string("A" * 100_000)
def test_empty_string():
"""Ensure an empty string fails validation."""
with raises(ValidationError):
validate_string('')
validate_string("")
def test_all_whitespace_string():
"""Ensure a string that's all whitespace chars fails validation."""
with raises(ValidationError):
validate_string(' \n \n\r\n \t ')
validate_string(" \n \n\r\n \t ")
def test_carriage_returns_stripped():
"""Ensure loading a value strips out carriage returns from the string."""
test_string = 'some\r\nreturns\r\nin\nhere'
test_string = "some\r\nreturns\r\nin\nhere"
schema = MarkdownFieldTestSchema(strict=True)
result = schema.load({'markdown': test_string})
result = schema.load({"markdown": test_string})
assert '\r' not in result.data['markdown']
assert "\r" not in result.data["markdown"]

36
tildes/tests/test_messages.py

@ -4,17 +4,15 @@ from pytest import fixture, raises
from tildes.models.message import MessageConversation, MessageReply
from tildes.models.user import User
from tildes.schemas.fields import Markdown, SimpleString
from tildes.schemas.message import (
MessageConversationSchema,
MessageReplySchema,
)
from tildes.schemas.message import MessageConversationSchema, MessageReplySchema
@fixture
def conversation(db, session_user, session_user2):
"""Create a message conversation and delete it as teardown."""
new_conversation = MessageConversation(
session_user, session_user2, 'Subject', 'Message')
session_user, session_user2, "Subject", "Message"
)
db.add(new_conversation)
db.commit()
@ -30,31 +28,31 @@ def conversation(db, session_user, session_user2):
def test_message_conversation_validation(mocker, session_user, session_user2):
"""Ensure a new message conversation goes through expected validation."""
mocker.spy(MessageConversationSchema, 'load')
mocker.spy(SimpleString, '_validate')
mocker.spy(Markdown, '_validate')
mocker.spy(MessageConversationSchema, "load")
mocker.spy(SimpleString, "_validate")
mocker.spy(Markdown, "_validate")
MessageConversation(session_user, session_user2, 'Subject', 'Message')
MessageConversation(session_user, session_user2, "Subject", "Message")
assert MessageConversationSchema.load.called
assert SimpleString._validate.call_args[0][1] == 'Subject'
assert Markdown._validate.call_args[0][1] == 'Message'
assert SimpleString._validate.call_args[0][1] == "Subject"
assert Markdown._validate.call_args[0][1] == "Message"
def test_message_reply_validation(mocker, conversation, session_user2):
"""Ensure a new message reply goes through expected validation."""
mocker.spy(MessageReplySchema, 'load')
mocker.spy(Markdown, '_validate')
mocker.spy(MessageReplySchema, "load")
mocker.spy(Markdown, "_validate")
MessageReply(conversation, session_user2, 'A new reply')
MessageReply(conversation, session_user2, "A new reply")
assert MessageReplySchema.load.called
assert Markdown._validate.call_args[0][1] == 'A new reply'
assert Markdown._validate.call_args[0][1] == "A new reply"
def test_conversation_viewing_permission(conversation):
"""Ensure only the two involved users can view a message conversation."""
principals = principals_allowed_by_permission(conversation, 'view')
principals = principals_allowed_by_permission(conversation, "view")
users = {conversation.sender.user_id, conversation.recipient.user_id}
assert principals == users
@ -70,7 +68,7 @@ def test_conversation_other_user(conversation):
def test_conversation_other_user_invalid(conversation):
"""Ensure that "other user" method fails if the user isn't involved."""
new_user = User('SomeOutsider', 'super amazing password')
new_user = User("SomeOutsider", "super amazing password")
with raises(ValueError):
assert conversation.other_user(new_user)
@ -82,7 +80,7 @@ def test_replies_affect_num_replies(conversation, db):
# add replies and ensure each one increases the count
for num in range(5):
new_reply = MessageReply(conversation, conversation.recipient, 'hi')
new_reply = MessageReply(conversation, conversation.recipient, "hi")
db.add(new_reply)
db.commit()
db.refresh(conversation)
@ -94,7 +92,7 @@ def test_replies_update_activity_time(conversation, db):
assert conversation.last_activity_time == conversation.created_time
for _ in range(5):
new_reply = MessageReply(conversation, conversation.recipient, 'hi')
new_reply = MessageReply(conversation, conversation.recipient, "hi")
db.add(new_reply)
db.commit()

2
tildes/tests/test_metrics.py

@ -9,4 +9,4 @@ def test_all_metric_names_prefixed():
# this is ugly, but seems to be the "generic" way to get the name
metric_name = metric.describe()[0].name
assert metric_name.startswith('tildes_')
assert metric_name.startswith("tildes_")

36
tildes/tests/test_ratelimit.py

@ -24,13 +24,12 @@ def test_all_rate_limited_action_names_unique():
def test_action_with_all_types_disabled():
"""Ensure RateLimitedAction can't have both by_user and by_ip disabled."""
with raises(ValueError):
RateLimitedAction(
'test', timedelta(hours=1), 5, by_user=False, by_ip=False)
RateLimitedAction("test", timedelta(hours=1), 5, by_user=False, by_ip=False)
def test_check_by_user_id_disabled():
"""Ensure non-by_user RateLimitedAction can't be checked by user_id."""
action = RateLimitedAction('test', timedelta(hours=1), 5, by_user=False)
action = RateLimitedAction("test", timedelta(hours=1), 5, by_user=False)
with raises(RateLimitError):
action.check_for_user_id(1)
@ -38,10 +37,10 @@ def test_check_by_user_id_disabled():
def test_check_by_ip_disabled():
"""Ensure non-by_ip RateLimitedAction can't be checked by ip."""
action = RateLimitedAction('test', timedelta(hours=1), 5, by_ip=False)
action = RateLimitedAction("test", timedelta(hours=1), 5, by_ip=False)
with raises(RateLimitError):
action.check_for_ip('123.123.123.123')
action.check_for_ip("123.123.123.123")
def test_simple_rate_limiting_by_user_id(redis):
@ -51,7 +50,8 @@ def test_simple_rate_limiting_by_user_id(redis):
# define an action with max_burst equal to the full limit
action = RateLimitedAction(
'testaction', timedelta(hours=1), limit, max_burst=limit, redis=redis)
"testaction", timedelta(hours=1), limit, max_burst=limit, redis=redis
)
# run the action the full number of times, should all be allowed
for _ in range(limit):
@ -68,7 +68,7 @@ def test_different_user_ids_limited_separately(redis):
limit = 5
user_id = 1
action = RateLimitedAction('test', timedelta(hours=1), limit, redis=redis)
action = RateLimitedAction("test", timedelta(hours=1), limit, redis=redis)
# check the action for the first user_id until it's blocked
result = action.check_for_user_id(user_id)
@ -84,7 +84,7 @@ def test_max_burst_defaults_to_half(redis):
limit = 10
user_id = 1
action = RateLimitedAction('test', timedelta(days=1), limit, redis=redis)
action = RateLimitedAction("test", timedelta(days=1), limit, redis=redis)
# see how many times we can do the action until it gets blocked
count = 0
@ -107,7 +107,8 @@ def test_time_until_retry(redis):
# create an action with no burst allowed, which will force the actions to
# be spaced "evenly" across the limit
action = RateLimitedAction(
'test', period=period, limit=limit, max_burst=1, redis=redis)
"test", period=period, limit=limit, max_burst=1, redis=redis
)
# first usage should be fine
result = action.check_for_user_id(user_id)
@ -126,7 +127,8 @@ def test_remaining_limit(redis):
# create an action allowing the full limit as a burst
action = RateLimitedAction(
'test', timedelta(days=1), limit, max_burst=limit, redis=redis)
"test", timedelta(days=1), limit, max_burst=limit, redis=redis
)
for count in range(1, limit + 1):
result = action.check_for_user_id(user_id)
@ -136,11 +138,12 @@ def test_remaining_limit(redis):
def test_simple_rate_limiting_by_ip(redis):
"""Ensure simple rate-limiting by IP address is working."""
limit = 5
ip = '123.123.123.123'
ip = "123.123.123.123"
# define an action with max_burst equal to the full limit
action = RateLimitedAction(
'testaction', timedelta(hours=1), limit, max_burst=limit, redis=redis)
"testaction", timedelta(hours=1), limit, max_burst=limit, redis=redis
)
# run the action the full number of times, should all be allowed
for _ in range(limit):
@ -154,9 +157,9 @@ def test_simple_rate_limiting_by_ip(redis):
def test_check_for_ip_invalid_address():
"""Ensure RateLimitedAction.check_for_ip can't take an invalid IP."""
ip = '123.456.789.123'
ip = "123.456.789.123"
action = RateLimitedAction('testaction', timedelta(hours=1), 10)
action = RateLimitedAction("testaction", timedelta(hours=1), 10)
with raises(ValueError):
action.check_for_ip(ip)
@ -164,9 +167,9 @@ def test_check_for_ip_invalid_address():
def test_reset_for_ip_invalid_address():
"""Ensure RateLimitedAction.reset_for_ip can't take an invalid IP."""
ip = '123.456.789.123'
ip = "123.456.789.123"
action = RateLimitedAction('testaction', timedelta(hours=1), 10)
action = RateLimitedAction("testaction", timedelta(hours=1), 10)
with raises(ValueError):
action.reset_for_ip(ip)
@ -224,6 +227,7 @@ def test_merged_results():
def test_merged_all_allowed():
"""Ensure a merged result from all allowed results is also allowed."""
def random_allowed_result():
"""Return a RateLimitResult with is_allowed=True, otherwise random."""
return RateLimitResult(

22
tildes/tests/test_simplestring_field.py

@ -17,39 +17,39 @@ def process_string(string):
ValidationError if an invalid string is attempted.
"""
schema = SimpleStringTestSchema(strict=True)
result = schema.load({'subject': string})
result = schema.load({"subject": string})
return result.data['subject']
return result.data["subject"]
def test_changing_max_length():
"""Ensure changing the max_length argument works."""
test_string = 'Just some text to try'
test_string = "Just some text to try"
# should normally validate
assert SimpleString()._validate(test_string) is None
# but fails if you set a too-short max_length
with raises(ValidationError):
SimpleString(max_length=len(test_string)-1)._validate(test_string)
SimpleString(max_length=len(test_string) - 1)._validate(test_string)
def test_long_string():
"""Ensure a long string fails validation."""
with raises(ValidationError):
process_string('A' * 10_000)
process_string("A" * 10_000)
def test_empty_string():
"""Ensure an empty string fails validation."""
with raises(ValidationError):
process_string('')
process_string("")
def test_all_whitespace_string():
"""Ensure a string that's entirely whitespace fails validation."""
with raises(ValidationError):
process_string('\n \t \r\n ')
process_string("\n \t \r\n ")
def test_normal_string_untouched():
@ -76,11 +76,11 @@ def test_control_chars_removed():
def test_leading_trailing_spaces_removed():
"""Ensure leading/trailing spaces are removed from the string."""
original = ' Centered! '
assert process_string(original) == 'Centered!'
original = " Centered! "
assert process_string(original) == "Centered!"
def test_consecutive_spaces_collapsed():
"""Ensure runs of consecutive spaces are "collapsed" inside the string."""
original = 'I wanted to space this out'
assert process_string(original) == 'I wanted to space this out'
original = "I wanted to space this out"
assert process_string(original) == "I wanted to space this out"

58
tildes/tests/test_string.py

@ -8,71 +8,69 @@ from tildes.lib.string import (
def test_simple_truncate():
"""Ensure a simple truncation by length works correctly."""
truncated = truncate_string('123456789', 5, overflow_str=None)
assert truncated == '12345'
truncated = truncate_string("123456789", 5, overflow_str=None)
assert truncated == "12345"
def test_simple_truncate_with_overflow():
"""Ensure a simple truncation by length with an overflow string works."""
truncated = truncate_string('123456789', 5)
assert truncated == '12...'
truncated = truncate_string("123456789", 5)
assert truncated == "12..."
def test_truncate_same_length():
"""Ensure truncation doesn't happen if the string is the desired length."""
original = '123456789'
original = "123456789"
assert truncate_string(original, len(original)) == original
def test_truncate_at_char():
"""Ensure truncation at a particular character works."""
original = 'asdf zxcv'
assert truncate_string_at_char(original, ' ') == 'asdf'
original = "asdf zxcv"
assert truncate_string_at_char(original, " ") == "asdf"
def test_truncate_at_last_char():
"""Ensure truncation happens at the last occurrence of the character."""
original = 'as df zx cv'
assert truncate_string_at_char(original, ' ') == 'as df zx'
original = "as df zx cv"
assert truncate_string_at_char(original, " ") == "as df zx"
def test_truncate_at_nonexistent_char():
"""Ensure truncation-at-character doesn't apply if char isn't present."""
original = 'asdfzxcv'
assert truncate_string_at_char(original, ' ') == original
original = "asdfzxcv"
assert truncate_string_at_char(original, " ") == original
def test_truncate_at_multiple_chars():
"""Ensure truncation with multiple characters uses the rightmost one."""
original = 'as-df=zx_cv'
assert truncate_string_at_char(original, '-=') == 'as-df'
original = "as-df=zx_cv"
assert truncate_string_at_char(original, "-=") == "as-df"
def test_truncate_length_and_char():
"""Ensure combined length+char truncation works as expected."""
original = '12345-67890-12345'
truncated = truncate_string(
original, 8, truncate_at_chars='-', overflow_str=None)
assert truncated == '12345'
original = "12345-67890-12345"
truncated = truncate_string(original, 8, truncate_at_chars="-", overflow_str=None)
assert truncated == "12345"
def test_truncate_length_and_nonexistent_char():
"""Ensure length+char truncation works if the char isn't present."""
original = '1234567890-12345'
truncated = truncate_string(
original, 8, truncate_at_chars='-', overflow_str=None)
assert truncated == '12345678'
original = "1234567890-12345"
truncated = truncate_string(original, 8, truncate_at_chars="-", overflow_str=None)
assert truncated == "12345678"
def test_simple_url_slug_conversion():
"""Ensure that a simple url slug conversion works as expected."""
assert convert_to_url_slug("A Simple Test") == 'a_simple_test'
assert convert_to_url_slug("A Simple Test") == "a_simple_test"
def test_url_slug_with_punctuation():
"""Ensure url slug conversion with punctuation works as expected."""
original = "Here's a string. It has (some) punctuation!"
expected = 'heres_a_string_it_has_some_punctuation'
expected = "heres_a_string_it_has_some_punctuation"
assert convert_to_url_slug(original) == expected
@ -86,13 +84,13 @@ def test_url_slug_with_apostrophes():
def test_url_slug_truncation():
"""Ensure a simple url slug truncates as expected."""
original = "Here's another string to truncate."
assert convert_to_url_slug(original, 15) == 'heres_another'
assert convert_to_url_slug(original, 15) == "heres_another"
def test_multibyte_url_slug():
"""Ensure converting/truncating a slug with encoded characters works."""
original = 'Python ist eine üblicherweise höhere Programmiersprache'
expected = 'python_ist_eine_%C3%BCblicherweise'
original = "Python ist eine üblicherweise höhere Programmiersprache"
expected = "python_ist_eine_%C3%BCblicherweise"
assert convert_to_url_slug(original, 45) == expected
@ -101,7 +99,7 @@ def test_multibyte_conservative_truncation():
# this string has a comma as the 6th char which will be converted to an
# underscore, so if truncation amount isn't restricted, it would result in
# a 46-char slug instead of the full 100.
original = 'パイソンは、汎用のプログラミング言語である'
original = "パイソンは、汎用のプログラミング言語である"
assert len(convert_to_url_slug(original, 100)) == 100
@ -109,14 +107,14 @@ def test_multibyte_whole_character_truncation():
"""Ensure truncation happens at the edge of a multibyte character."""
# each of these characters url-encodes to 3 bytes = 9 characters each, so
# only the first character should be included for all lengths from 9 - 17
original = 'コード'
original = "コード"
for limit in range(9, 18):
assert convert_to_url_slug(original, limit) == '%E3%82%B3'
assert convert_to_url_slug(original, limit) == "%E3%82%B3"
def test_simple_word_count():
"""Ensure word-counting a simple string works as expected."""
string = 'Here is a simple string of words, nothing fancy.'
string = "Here is a simple string of words, nothing fancy."
assert word_count(string) == 9
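
The truncation helpers exercised earlier in this file follow a small set of rules: the overflow string counts toward the length limit, and the optional cut-at-character step keeps the text up to the rightmost matching character. A hypothetical standalone reconstruction from those test expectations (not the actual tildes.lib.string code):

from typing import Optional

def truncate_string_at_char_sketch(string: str, chars: str) -> str:
    """Truncate at the rightmost occurrence of any of the given characters."""
    cut = max(string.rfind(char) for char in chars)
    return string[:cut] if cut > 0 else string

def truncate_string_sketch(
    string: str,
    length: int,
    overflow_str: Optional[str] = "...",
    truncate_at_chars: Optional[str] = None,
) -> str:
    """Truncate to a maximum length, counting the overflow string toward it."""
    if overflow_str is None:
        overflow_str = ""
    if len(string) <= length:
        return string
    truncated = string[: length - len(overflow_str)]
    if truncate_at_chars:
        truncated = truncate_string_at_char_sketch(truncated, truncate_at_chars)
    return truncated + overflow_str

assert truncate_string_sketch("123456789", 5) == "12..."
assert truncate_string_sketch("12345-67890-12345", 8, overflow_str=None, truncate_at_chars="-") == "12345"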

36
tildes/tests/test_title.py

@ -7,57 +7,57 @@ from tildes.schemas.topic import TITLE_MAX_LENGTH, TopicSchema
@fixture
def title_schema():
"""Fixture for generating a title-only TopicSchema."""
return TopicSchema(only=('title',))
return TopicSchema(only=("title",))
def test_typical_title_valid(title_schema):
"""Test a "normal-looking" title to make sure it's valid."""
title = "[Something] Here's an article that I'm sure 100 people will like."
assert title_schema.validate({'title': title}) == {}
assert title_schema.validate({"title": title}) == {}
def test_too_long_title_invalid(title_schema):
"""Ensure a too-long title is invalid."""
title = 'x' * (TITLE_MAX_LENGTH + 1)
title = "x" * (TITLE_MAX_LENGTH + 1)
with raises(ValidationError):
title_schema.validate({'title': title})
title_schema.validate({"title": title})
def test_empty_title_invalid(title_schema):
"""Ensure an empty title is invalid."""
with raises(ValidationError):
title_schema.validate({'title': ''})
title_schema.validate({"title": ""})
def test_whitespace_only_title_invalid(title_schema):
"""Ensure a whitespace-only title is invalid."""
with raises(ValidationError):
title_schema.validate({'title': ' \n '})
title_schema.validate({"title": " \n "})
def test_whitespace_trimmed(title_schema):
"""Ensure leading/trailing whitespace on a title is removed."""
title = ' actual title '
result = title_schema.load({'title': title})
assert result.data['title'] == 'actual title'
title = " actual title "
result = title_schema.load({"title": title})
assert result.data["title"] == "actual title"
def test_consecutive_whitespace_removed(title_schema):
"""Ensure consecutive whitespace in a title is compressed."""
title = 'sure are \n a lot of spaces'
result = title_schema.load({'title': title})
assert result.data['title'] == 'sure are a lot of spaces'
title = "sure are \n a lot of spaces"
result = title_schema.load({"title": title})
assert result.data["title"] == "sure are a lot of spaces"
def test_unicode_spaces_normalized(title_schema):
"""Test that some unicode space characters are converted to normal ones."""
title = 'some\u2009weird\u00a0spaces\u205fin\u00a0here'
result = title_schema.load({'title': title})
assert result.data['title'] == 'some weird spaces in here'
title = "some\u2009weird\u00a0spaces\u205fin\u00a0here"
result = title_schema.load({"title": title})
assert result.data["title"] == "some weird spaces in here"
def test_unicode_control_chars_removed(title_schema):
"""Test that some unicode control characters are stripped from titles."""
title = 'nothing\u0000strange\u0085going\u009con\u007fhere'
result = title_schema.load({'title': title})
assert result.data['title'] == 'nothingstrangegoingonhere'
title = "nothing\u0000strange\u0085going\u009con\u007fhere"
result = title_schema.load({"title": title})
assert result.data["title"] == "nothingstrangegoingonhere"

44
tildes/tests/test_topic.py

@ -12,41 +12,37 @@ from tildes.schemas.topic import TopicSchema
def test_text_creation_validations(mocker, session_user, session_group):
"""Ensure that text topic creation goes through expected validation."""
mocker.spy(TopicSchema, 'load')
mocker.spy(Markdown, '_validate')
mocker.spy(SimpleString, '_validate')
mocker.spy(TopicSchema, "load")
mocker.spy(Markdown, "_validate")
mocker.spy(SimpleString, "_validate")
Topic.create_text_topic(
session_group, session_user, 'a title', 'the text')
Topic.create_text_topic(session_group, session_user, "a title", "the text")
assert TopicSchema.load.called
assert SimpleString._validate.call_args[0][1] == 'a title'
assert Markdown._validate.call_args[0][1] == 'the text'
assert SimpleString._validate.call_args[0][1] == "a title"
assert Markdown._validate.call_args[0][1] == "the text"
def test_link_creation_validations(mocker, session_user, session_group):
"""Ensure that link topic creation goes through expected validation."""
mocker.spy(TopicSchema, 'load')
mocker.spy(SimpleString, '_validate')
mocker.spy(URL, '_validate')
mocker.spy(TopicSchema, "load")
mocker.spy(SimpleString, "_validate")
mocker.spy(URL, "_validate")
Topic.create_link_topic(
session_group,
session_user,
'the title',
'http://example.com',
session_group, session_user, "the title", "http://example.com"
)
assert TopicSchema.load.called
assert SimpleString._validate.call_args[0][1] == 'the title'
assert URL._validate.call_args[0][1] == 'http://example.com'
assert SimpleString._validate.call_args[0][1] == "the title"
assert URL._validate.call_args[0][1] == "http://example.com"
def test_text_topic_edit_uses_markdown_field(mocker, text_topic):
"""Ensure editing a text topic is validated by the Markdown field class."""
mocker.spy(Markdown, '_validate')
mocker.spy(Markdown, "_validate")
text_topic.markdown = 'Some new text after edit'
text_topic.markdown = "Some new text after edit"
assert Markdown._validate.called
@ -82,19 +78,19 @@ def test_link_domain_errors_on_text_topic(text_topic):
def test_link_domain_on_link_topic(link_topic):
"""Ensure getting the domain of a link topic works."""
assert link_topic.link_domain == 'example.com'
assert link_topic.link_domain == "example.com"
def test_edit_markdown_errors_on_link_topic(link_topic):
"""Ensure trying to edit the markdown of a link topic is an error."""
with raises(AttributeError):
link_topic.markdown = 'Some new markdown'
link_topic.markdown = "Some new markdown"
def test_edit_markdown_on_text_topic(text_topic):
"""Ensure editing the markdown of a text topic works and updates html."""
original_html = text_topic.rendered_html
text_topic.markdown = 'Some new markdown'
text_topic.markdown = "Some new markdown"
assert text_topic.rendered_html != original_html
@ -104,7 +100,7 @@ def test_edit_grace_period(text_topic):
edit_time = text_topic.created_time + EDIT_GRACE_PERIOD - one_sec
with freeze_time(edit_time):
text_topic.markdown = 'some new markdown'
text_topic.markdown = "some new markdown"
assert not text_topic.last_edited_time
@ -115,7 +111,7 @@ def test_edit_after_grace_period(text_topic):
edit_time = text_topic.created_time + EDIT_GRACE_PERIOD + one_sec
with freeze_time(edit_time):
text_topic.markdown = 'some new markdown'
text_topic.markdown = "some new markdown"
assert text_topic.last_edited_time == utc_now()
@ -127,7 +123,7 @@ def test_multiple_edits_update_time(text_topic):
for minutes in range(0, 4):
edit_time = initial_time + timedelta(minutes=minutes)
with freeze_time(edit_time):
text_topic.markdown = f'edit #{minutes}'
text_topic.markdown = f"edit #{minutes}"
assert text_topic.last_edited_time == utc_now()

37
tildes/tests/test_topic_permissions.py

@ -1,13 +1,9 @@
from pyramid.security import (
Authenticated,
Everyone,
principals_allowed_by_permission,
)
from pyramid.security import Authenticated, Everyone, principals_allowed_by_permission
def test_topic_viewing_permission(text_topic):
"""Ensure that anyone can view a topic by default."""
principals = principals_allowed_by_permission(text_topic, 'view')
principals = principals_allowed_by_permission(text_topic, "view")
assert Everyone in principals
@ -15,71 +11,70 @@ def test_deleted_topic_permissions_removed(topic):
"""Ensure that deleted topics lose all permissions except "view"."""
topic.is_deleted = True
assert principals_allowed_by_permission(topic, 'view') == {Everyone}
assert principals_allowed_by_permission(topic, "view") == {Everyone}
all_permissions = [
perm for (_, _, perm) in topic.__acl__() if perm != 'view']
all_permissions = [perm for (_, _, perm) in topic.__acl__() if perm != "view"]
for permission in all_permissions:
assert not principals_allowed_by_permission(topic, permission)
def test_text_topic_editing_permission(text_topic):
"""Ensure a text topic's owner (and nobody else) is able to edit it."""
principals = principals_allowed_by_permission(text_topic, 'edit')
principals = principals_allowed_by_permission(text_topic, "edit")
assert principals == {text_topic.user.user_id}
def test_link_topic_editing_permission(link_topic):
"""Ensure that nobody has edit permission on a link topic."""
principals = principals_allowed_by_permission(link_topic, 'edit')
principals = principals_allowed_by_permission(link_topic, "edit")
assert not principals
def test_topic_deleting_permission(text_topic):
"""Ensure that the topic's owner (and nobody else) is able to delete it."""
principals = principals_allowed_by_permission(text_topic, 'delete')
principals = principals_allowed_by_permission(text_topic, "delete")
assert principals == {text_topic.user.user_id}
def test_topic_view_author_permission(text_topic):
"""Ensure anyone can view a topic's author normally."""
principals = principals_allowed_by_permission(text_topic, 'view_author')
principals = principals_allowed_by_permission(text_topic, "view_author")
assert Everyone in principals
def test_removed_topic_view_author_permission(topic):
"""Ensure only admins and the author can view a removed topic's author."""
topic.is_removed = True
principals = principals_allowed_by_permission(topic, 'view_author')
assert principals == {'admin', topic.user_id}
principals = principals_allowed_by_permission(topic, "view_author")
assert principals == {"admin", topic.user_id}
def test_topic_view_content_permission(text_topic):
"""Ensure anyone can view a topic's content normally."""
principals = principals_allowed_by_permission(text_topic, 'view_content')
principals = principals_allowed_by_permission(text_topic, "view_content")
assert Everyone in principals
def test_removed_topic_view_content_permission(topic):
"""Ensure only admins and the author can view a removed topic's content."""
topic.is_removed = True
principals = principals_allowed_by_permission(topic, 'view_content')
assert principals == {'admin', topic.user_id}
principals = principals_allowed_by_permission(topic, "view_content")
assert principals == {"admin", topic.user_id}
def test_topic_comment_permission(text_topic):
"""Ensure authed users have comment perms on a topic by default."""
principals = principals_allowed_by_permission(text_topic, 'comment')
principals = principals_allowed_by_permission(text_topic, "comment")
assert Authenticated in principals
def test_locked_topic_comment_permission(topic):
"""Ensure only admins can post (top-level) comments on locked topics."""
topic.is_locked = True
assert principals_allowed_by_permission(topic, 'comment') == {'admin'}
assert principals_allowed_by_permission(topic, "comment") == {"admin"}
def test_removed_topic_comment_permission(topic):
"""Ensure only admins can post (top-level) comments on removed topics."""
topic.is_removed = True
assert principals_allowed_by_permission(topic, 'comment') == {'admin'}
assert principals_allowed_by_permission(topic, "comment") == {"admin"}

22
tildes/tests/test_topic_tags.py

@ -1,34 +1,34 @@
def test_tags_whitespace_stripped(text_topic):
"""Ensure excess whitespace around tags gets stripped."""
text_topic.tags = [' one', 'two ', ' three ']
assert text_topic.tags == ['one', 'two', 'three']
text_topic.tags = [" one", "two ", " three "]
assert text_topic.tags == ["one", "two", "three"]
def test_tag_space_replacement(text_topic):
"""Ensure spaces in tags are converted to underscores internally."""
text_topic.tags = ['one two', 'three four five']
assert text_topic._tags == ['one_two', 'three_four_five']
text_topic.tags = ["one two", "three four five"]
assert text_topic._tags == ["one_two", "three_four_five"]
def test_tag_consecutive_spaces(text_topic):
"""Ensure consecutive spaces/underscores in tags are removed."""
text_topic.tags = ["one two", "three four", "five __ six"]
assert text_topic.tags == ['one two', 'three four', 'five six']
assert text_topic.tags == ["one two", "three four", "five six"]
def test_duplicate_tags_removed(text_topic):
"""Ensure duplicate tags are removed (case-insensitive)."""
text_topic.tags = ['one', 'one', 'One', 'ONE', 'two', 'TWO']
assert text_topic.tags == ['one', 'two']
text_topic.tags = ["one", "one", "One", "ONE", "two", "TWO"]
assert text_topic.tags == ["one", "two"]
def test_empty_tags_removed(text_topic):
"""Ensure empty tags are removed."""
text_topic.tags = ['', ' ', '_', 'one']
assert text_topic.tags == ['one']
text_topic.tags = ["", " ", "_", "one"]
assert text_topic.tags == ["one"]
def test_tags_lowercased(text_topic):
"""Ensure tags get converted to lowercase."""
text_topic.tags = ['ONE', 'Two', 'thRee']
assert text_topic.tags == ['one', 'two', 'three']
text_topic.tags = ["ONE", "Two", "thRee"]
assert text_topic.tags == ["one", "two", "three"]

6
tildes/tests/test_triggers_comments.py

@ -8,7 +8,7 @@ def test_comments_affect_topic_num_comments(session_user, topic, db):
# Insert some comments, ensure each one increments the count
comments = []
for num in range(0, 5):
new_comment = Comment(topic, session_user, 'comment')
new_comment = Comment(topic, session_user, "comment")
comments.append(new_comment)
db.add(new_comment)
db.commit()
@ -62,8 +62,8 @@ def test_remove_sets_removed_time(db, comment):
def test_remove_delete_single_decrement(db, topic, session_user):
"""Ensure that remove+delete doesn't double-decrement num_comments."""
# add 2 comments
comment1 = Comment(topic, session_user, 'Comment 1')
comment2 = Comment(topic, session_user, 'Comment 2')
comment1 = Comment(topic, session_user, "Comment 1")
comment2 = Comment(topic, session_user, "Comment 2")
db.add_all([comment1, comment2])
db.commit()
db.refresh(topic)

22
tildes/tests/test_url.py

@ -5,13 +5,13 @@ from tildes.lib.url import get_domain_from_url
def test_simple_get_domain():
"""Ensure getting the domain from a normal URL works."""
url = 'http://example.com/some/path?query=param&query2=val2'
assert get_domain_from_url(url) == 'example.com'
url = "http://example.com/some/path?query=param&query2=val2"
assert get_domain_from_url(url) == "example.com"
def test_get_domain_non_url():
"""Ensure attempting to get the domain for a non-url is an error."""
url = 'this is not a url'
url = "this is not a url"
with raises(ValueError):
get_domain_from_url(url)
@ -19,27 +19,27 @@ def test_get_domain_non_url():
def test_get_domain_no_scheme():
"""Ensure getting domain on a url with no scheme is an error."""
with raises(ValueError):
get_domain_from_url('example.com/something')
get_domain_from_url("example.com/something")
def test_get_domain_explicit_no_scheme():
"""Ensure getting domain works if url is explicit about lack of scheme."""
assert get_domain_from_url('//example.com/something') == 'example.com'
assert get_domain_from_url("//example.com/something") == "example.com"
def test_get_domain_strip_www():
"""Ensure stripping the "www." from the domain works as expected."""
url = 'http://www.example.com/a/path/to/something'
assert get_domain_from_url(url) == 'example.com'
url = "http://www.example.com/a/path/to/something"
assert get_domain_from_url(url) == "example.com"
def test_get_domain_no_strip_www():
"""Ensure stripping the "www." can be disabled."""
url = 'http://www.example.com/a/path/to/something'
assert get_domain_from_url(url, strip_www=False) == 'www.example.com'
url = "http://www.example.com/a/path/to/something"
assert get_domain_from_url(url, strip_www=False) == "www.example.com"
def test_get_domain_subdomain_not_stripped():
"""Ensure a non-www subdomain isn't stripped."""
url = 'http://something.example.com/path/x/y/z'
assert get_domain_from_url(url) == 'something.example.com'
url = "http://something.example.com/path/x/y/z"
assert get_domain_from_url(url) == "something.example.com"

51
tildes/tests/test_user.py

@ -8,55 +8,55 @@ from tildes.schemas.user import PASSWORD_MIN_LENGTH, UserSchema
def test_creation_validates_schema(mocker):
"""Ensure that model creation goes through schema validation."""
mocker.spy(UserSchema, 'validate')
User('testing', 'testpassword')
mocker.spy(UserSchema, "validate")
User("testing", "testpassword")
call_args = [call[0] for call in UserSchema.validate.call_args_list]
expected_args = {'username': 'testing', 'password': 'testpassword'}
expected_args = {"username": "testing", "password": "testpassword"}
assert any(expected_args in call for call in call_args)
def test_too_short_password():
"""Ensure a new user can't be created with a too-short password."""
password = 'x' * (PASSWORD_MIN_LENGTH - 1)
password = "x" * (PASSWORD_MIN_LENGTH - 1)
with raises(ValidationError):
User('ShortPasswordGuy', password)
User("ShortPasswordGuy", password)
def test_matching_password_and_username():
"""Ensure a new user can't be created with same username and password."""
with raises(ValidationError):
User('UnimaginativePassword', 'UnimaginativePassword')
User("UnimaginativePassword", "UnimaginativePassword")
def test_username_and_password_differ_in_casing():
"""Ensure a user can't be created with name/pass the same except case."""
with raises(ValidationError):
User('NobodyWillGuess', 'nobodywillguess')
User("NobodyWillGuess", "nobodywillguess")
def test_username_contained_in_password():
"""Ensure a user can't be created with the username in the password."""
with raises(ValidationError):
User('MyUsername', 'iputmyusernameinmypassword')
User("MyUsername", "iputmyusernameinmypassword")
def test_password_contained_in_username():
"""Ensure a user can't be created with the password in the username."""
with raises(ValidationError):
User('PasswordIsVeryGood', 'VeryGood')
User("PasswordIsVeryGood", "VeryGood")
def test_user_password_check():
"""Ensure checking the password for a new user works correctly."""
new_user = User('myusername', 'mypassword')
assert new_user.is_correct_password('mypassword')
new_user = User("myusername", "mypassword")
assert new_user.is_correct_password("mypassword")
def test_duplicate_username(db):
"""Ensure two users with the same name can't be created."""
original = User('Inimitable', 'securepassword')
original = User("Inimitable", "securepassword")
db.add(original)
duplicate = User('Inimitable', 'adifferentpassword')
duplicate = User("Inimitable", "adifferentpassword")
db.add(duplicate)
with raises(IntegrityError):
@ -65,10 +65,10 @@ def test_duplicate_username(db):
def test_duplicate_username_case_insensitive(db):
"""Ensure usernames only differing in casing can't be created."""
test_username = 'test_user'
original = User(test_username.lower(), 'hackproof')
test_username = "test_user"
original = User(test_username.lower(), "hackproof")
db.add(original)
duplicate = User(test_username.upper(), 'sosecure')
duplicate = User(test_username.upper(), "sosecure")
db.add(duplicate)
with raises(IntegrityError):
@ -77,20 +77,20 @@ def test_duplicate_username_case_insensitive(db):
def test_change_password():
"""Ensure changing a user password works as expected."""
new_user = User('A_New_User', 'lovesexsecretgod')
new_user = User("A_New_User", "lovesexsecretgod")
new_user.change_password('lovesexsecretgod', 'lovesexsecretgod1')
new_user.change_password("lovesexsecretgod", "lovesexsecretgod1")
# the old one shouldn't work
assert not new_user.is_correct_password('lovesexsecretgod')
assert not new_user.is_correct_password("lovesexsecretgod")
# the new one should
assert new_user.is_correct_password('lovesexsecretgod1')
assert new_user.is_correct_password("lovesexsecretgod1")
def test_change_password_to_same(session_user):
"""Ensure users can't "change" to the same password."""
password = 'session user password'
password = "session user password"
with raises(ValueError):
session_user.change_password(password, password)
@ -98,18 +98,17 @@ def test_change_password_to_same(session_user):
def test_change_password_wrong_old_one(session_user):
"""Ensure changing password doesn't work if the old one is wrong."""
with raises(ValueError):
session_user.change_password('definitely not right', 'some new one')
session_user.change_password("definitely not right", "some new one")
def test_change_password_too_short(session_user):
"""Ensure users can't change password to a too-short one."""
new_password = 'x' * (PASSWORD_MIN_LENGTH - 1)
new_password = "x" * (PASSWORD_MIN_LENGTH - 1)
with raises(ValidationError):
session_user.change_password('session user password', new_password)
session_user.change_password("session user password", new_password)
def test_change_password_to_username(session_user):
"""Ensure users can't change password to the same as their username."""
with raises(ValidationError):
session_user.change_password(
'session user password', session_user.username)
session_user.change_password("session user password", session_user.username)

16
tildes/tests/test_username.py

@ -10,7 +10,7 @@ from tildes.schemas.user import (
def test_too_short_invalid():
"""Ensure too-short username is invalid."""
length = USERNAME_MIN_LENGTH - 1
username = 'x' * length
username = "x" * length
assert not is_valid_username(username)
@ -18,7 +18,7 @@ def test_too_short_invalid():
def test_too_long_invalid():
"""Ensure too-long username is invalid."""
length = USERNAME_MAX_LENGTH + 1
username = 'x' * length
username = "x" * length
assert not is_valid_username(username)
@ -26,22 +26,22 @@ def test_too_long_invalid():
def test_valid_length_range():
"""Ensure the entire range of valid lengths work."""
for length in range(USERNAME_MIN_LENGTH, USERNAME_MAX_LENGTH + 1):
username = 'x' * length
username = "x" * length
assert is_valid_username(username)
def test_consecutive_spacer_chars_invalid():
"""Ensure that a username with consecutive "spacer chars" is invalid."""
spacer_chars = '_-'
spacer_chars = "_-"
for char1, char2 in product(spacer_chars, spacer_chars):
username = f'abc{char1}{char2}xyz'
username = f"abc{char1}{char2}xyz"
assert not is_valid_username(username)
def test_typical_username_valid():
"""Ensure a "normal-looking" username is considered valid."""
assert is_valid_username('someTypical_user-85')
assert is_valid_username("someTypical_user-85")
def test_invalid_characters():
@ -49,11 +49,11 @@ def test_invalid_characters():
invalid_chars = ' ~!@#$%^&*()+={}[]|\\:;"<>,.?/'
for char in invalid_chars:
username = f'abc{char}xyz'
username = f"abc{char}xyz"
assert not is_valid_username(username)
def test_unicode_characters():
"""Ensure that unicode chars can't be included (not comprehensive)."""
for username in ('pokémon', 'ポケモン', 'møøse'):
for username in ("pokémon", "ポケモン", "møøse"):
assert not is_valid_username(username)
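
The username rules exercised above (length bounds, ASCII letters and digits, single "_" or "-" separators that can't repeat, no unicode) fit in one regular expression. A hypothetical equivalent built from the tests, with made-up length constants standing in for USERNAME_MIN_LENGTH / USERNAME_MAX_LENGTH:

import re

# Assumed bounds for illustration only; the real values live in tildes.schemas.user.
MIN_LENGTH, MAX_LENGTH = 3, 20

# ASCII letters/digits, where "_" or "-" must be followed by another
# letter/digit, so separators can't repeat or end the name.
USERNAME_RE = re.compile(r"^[A-Za-z0-9]([A-Za-z0-9]|[_-](?=[A-Za-z0-9]))*$")

def is_valid_username_sketch(username: str) -> bool:
    """Return whether a username satisfies the sketched rules."""
    if not MIN_LENGTH <= len(username) <= MAX_LENGTH:
        return False
    return bool(USERNAME_RE.match(username))

assert is_valid_username_sketch("someTypical_user-85")
assert not is_valid_username_sketch("abc__xyz")
assert not is_valid_username_sketch("pokémon")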

10
tildes/tests/test_webassets.py

@ -1,22 +1,22 @@
from webassets.loaders import YAMLLoader
WEBASSETS_ENV = YAMLLoader('webassets.yaml').load_environment()
WEBASSETS_ENV = YAMLLoader("webassets.yaml").load_environment()
def test_scripts_file_first_in_bundle():
"""Ensure that the main scripts.js file will be at the top."""
js_bundle = WEBASSETS_ENV['javascript']
js_bundle = WEBASSETS_ENV["javascript"]
first_filename = js_bundle.resolve_contents()[0][0]
assert first_filename == 'js/scripts.js'
assert first_filename == "js/scripts.js"
def test_styles_file_last_in_bundle():
"""Ensure that the main styles.css file will be at the bottom."""
css_bundle = WEBASSETS_ENV['css']
css_bundle = WEBASSETS_ENV["css"]
last_filename = css_bundle.resolve_contents()[-1][0]
assert last_filename == 'css/styles.css'
assert last_filename == "css/styles.css"

6
tildes/tests/webtests/test_user_page.py

@ -6,8 +6,10 @@ def test_loggedout_username_leak(webtest_loggedout, session_user):
particular username exists or not.
"""
existing_user = webtest_loggedout.get(
'/user/' + session_user.username, expect_errors=True)
"/user/" + session_user.username, expect_errors=True
)
nonexistent_user = webtest_loggedout.get(
'/user/thisdoesntexist', expect_errors=True)
"/user/thisdoesntexist", expect_errors=True
)
assert existing_user.status == nonexistent_user.status

93
tildes/tildes/__init__.py

@ -16,51 +16,47 @@ def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware:
"""Configure and return a Pyramid WSGI application."""
config = Configurator(settings=settings)
config.include('cornice')
config.include('pyramid_session_redis')
config.include('pyramid_webassets')
config.include("cornice")
config.include("pyramid_session_redis")
config.include("pyramid_webassets")
# include database first so the session and querying are available
config.include('tildes.database')
config.include('tildes.auth')
config.include('tildes.jinja')
config.include('tildes.json')
config.include('tildes.routes')
config.include("tildes.database")
config.include("tildes.auth")
config.include("tildes.jinja")
config.include("tildes.json")
config.include("tildes.routes")
config.add_webasset('javascript', Bundle(output='js/tildes.js'))
config.add_webasset(
'javascript-third-party', Bundle(output='js/third_party.js'))
config.add_webasset('css', Bundle(output='css/tildes.css'))
config.add_webasset('site-icons-css', Bundle(output='css/site-icons.css'))
config.add_webasset("javascript", Bundle(output="js/tildes.js"))
config.add_webasset("javascript-third-party", Bundle(output="js/third_party.js"))
config.add_webasset("css", Bundle(output="css/tildes.css"))
config.add_webasset("site-icons-css", Bundle(output="css/site-icons.css"))
config.scan('tildes.views')
config.scan("tildes.views")
config.add_tween('tildes.http_method_tween_factory')
config.add_tween("tildes.http_method_tween_factory")
config.add_request_method(
is_safe_request_method, 'is_safe_method', reify=True)
config.add_request_method(is_safe_request_method, "is_safe_method", reify=True)
# Add the request.redis request method to access a redis connection. This
# is done in a bit of a strange way to support being overridden in tests.
config.registry['redis_connection_factory'] = get_redis_connection
config.registry["redis_connection_factory"] = get_redis_connection
# pylint: disable=unnecessary-lambda
config.add_request_method(
lambda request: config.registry['redis_connection_factory'](request),
'redis',
lambda request: config.registry["redis_connection_factory"](request),
"redis",
reify=True,
)
# pylint: enable=unnecessary-lambda
config.add_request_method(check_rate_limit, 'check_rate_limit')
config.add_request_method(check_rate_limit, "check_rate_limit")
config.add_request_method(
current_listing_base_url, 'current_listing_base_url')
config.add_request_method(
current_listing_normal_url, 'current_listing_normal_url')
config.add_request_method(current_listing_base_url, "current_listing_base_url")
config.add_request_method(current_listing_normal_url, "current_listing_normal_url")
app = config.make_wsgi_app()
force_port = global_config.get('prefixmiddleware_force_port')
force_port = global_config.get("prefixmiddleware_force_port")
if force_port:
prefixed_app = PrefixMiddleware(app, force_port=force_port)
else:
@ -69,19 +65,17 @@ def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware:
return prefixed_app
def http_method_tween_factory(
handler: Callable,
registry: Registry,
) -> Callable:
def http_method_tween_factory(handler: Callable, registry: Registry) -> Callable:
# pylint: disable=unused-argument
"""Return a tween function that can override the request's HTTP method."""
def method_override_tween(request: Request) -> Request:
"""Override HTTP method with one specified in header."""
valid_overrides_by_method = {'POST': ['DELETE', 'PATCH', 'PUT']}
valid_overrides_by_method = {"POST": ["DELETE", "PATCH", "PUT"]}
original_method = request.method.upper()
valid_overrides = valid_overrides_by_method.get(original_method, [])
override = request.headers.get('X-HTTP-Method-Override', '').upper()
override = request.headers.get("X-HTTP-Method-Override", "").upper()
if override in valid_overrides:
request.method = override
@ -93,13 +87,13 @@ def http_method_tween_factory(
def get_redis_connection(request: Request) -> StrictRedis:
"""Return a StrictRedis connection to the Redis server."""
socket = request.registry.settings['redis.unix_socket_path']
socket = request.registry.settings["redis.unix_socket_path"]
return StrictRedis(unix_socket_path=socket)
def is_safe_request_method(request: Request) -> bool:
"""Return whether the request method is "safe" (is GET or HEAD)."""
return request.method in {'GET', 'HEAD'}
return request.method in {"GET", "HEAD"}
def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
@ -107,7 +101,7 @@ def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
try:
action = RATE_LIMITED_ACTIONS[action_name]
except KeyError:
raise ValueError('Invalid action name: %s' % action_name)
raise ValueError("Invalid action name: %s" % action_name)
action.redis = request.redis
@ -127,8 +121,7 @@ def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
def current_listing_base_url(
request: Request,
query: Optional[Dict[str, Any]] = None,
request: Request, query: Optional[Dict[str, Any]] = None
) -> str:
"""Return the "base" url for the current listing route.
@ -137,14 +130,12 @@ def current_listing_base_url(
The `query` argument allows adding query variables to the generated url.
"""
if request.matched_route.name not in ('home', 'group', 'user'):
raise AttributeError('Current route is not supported.')
if request.matched_route.name not in ("home", "group", "user"):
raise AttributeError("Current route is not supported.")
base_view_vars = (
'order', 'period', 'per_page', 'tag', 'type', 'unfiltered')
base_view_vars = ("order", "period", "per_page", "tag", "type", "unfiltered")
query_vars = {
key: val for key, val in request.GET.copy().items()
if key in base_view_vars
key: val for key, val in request.GET.copy().items() if key in base_view_vars
}
if query:
query_vars.update(query)
@ -152,12 +143,11 @@ def current_listing_base_url(
url = request.current_route_url(_query=query_vars)
# Pyramid seems to %-encode tilde characters unnecessarily, fix that
return url.replace('%7E', '~')
return url.replace("%7E", "~")
def current_listing_normal_url(
request: Request,
query: Optional[Dict[str, Any]] = None,
request: Request, query: Optional[Dict[str, Any]] = None
) -> str:
"""Return the "normal" url for the current listing route.
@ -166,13 +156,12 @@ def current_listing_normal_url(
The `query` argument allows adding query variables to the generated url.
"""
if request.matched_route.name not in ('home', 'group', 'user'):
raise AttributeError('Current route is not supported.')
if request.matched_route.name not in ("home", "group", "user"):
raise AttributeError("Current route is not supported.")
normal_view_vars = ('order', 'period', 'per_page')
normal_view_vars = ("order", "period", "per_page")
query_vars = {
key: val for key, val in request.GET.copy().items()
if key in normal_view_vars
key: val for key, val in request.GET.copy().items() if key in normal_view_vars
}
if query:
query_vars.update(query)
@ -180,4 +169,4 @@ def current_listing_normal_url(
url = request.current_route_url(_query=query_vars)
# Pyramid seems to %-encode tilde characters unnecessarily, fix that
return url.replace('%7E', '~')
return url.replace("%7E", "~")
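For context on the method-override tween reformatted above: a POST request carrying an X-HTTP-Method-Override header is rewritten to DELETE, PATCH, or PUT before the view sees it. A minimal client-side sketch, assuming any HTTP client (the requests library and the URL here are purely illustrative, not part of this commit):

import requests

# Handled server-side as a DELETE, since "DELETE" is listed in
# valid_overrides_by_method["POST"] in the tween above.
response = requests.post(
    "https://tildes.example/some/resource",  # hypothetical endpoint
    headers={"X-HTTP-Method-Override": "DELETE"},
)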

6
tildes/tildes/api.py

@ -9,8 +9,8 @@ import venusian
class APIv0(Service):
"""Service wrapper class for v0 of the API."""
name_prefix = 'apiv0_'
base_path = '/api/v0'
name_prefix = "apiv0_"
base_path = "/api/v0"
def __init__(self, name: str, path: str, **kwargs: Any) -> None:
"""Create a new service."""
@ -28,4 +28,4 @@ class APIv0(Service):
# TEMP: disable API until I can fix the private-fields issue
# config.add_cornice_service(self)
info = venusian.attach(self, callback, category='pyramid')
info = venusian.attach(self, callback, category="pyramid")

46
tildes/tildes/auth.py

@ -7,13 +7,7 @@ from pyramid.authorization import ACLAuthorizationPolicy
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPFound
from pyramid.request import Request
from pyramid.security import (
ACLDenied,
ACLPermitsResult,
Allow,
Authenticated,
Everyone,
)
from pyramid.security import ACLDenied, ACLPermitsResult, Allow, Authenticated, Everyone
from tildes.models.user import User
@ -27,7 +21,7 @@ class DefaultRootFactory:
an __acl__ defined, they will not "fall back" to this one.
"""
__acl__ = ((Allow, Everyone, 'view'),)
__acl__ = ((Allow, Everyone, "view"),)
def __init__(self, request: Request) -> None:
"""Root factory constructor - must take a request argument."""
@ -40,10 +34,7 @@ def get_authenticated_user(request: Request) -> Optional[User]:
if not user_id:
return None
query = (
request.query(User)
.filter_by(user_id=user_id)
)
query = request.query(User).filter_by(user_id=user_id)
return query.one_or_none()
@ -60,15 +51,15 @@ def auth_callback(user_id: int, request: Request) -> Optional[Sequence[str]]:
# if the user is banned, log them out - is there a better place to do this?
if request.user.is_banned:
request.session.invalidate()
raise HTTPFound('/')
raise HTTPFound("/")
if user_id != request.user.user_id:
raise AssertionError('auth_callback called with different user_id')
raise AssertionError("auth_callback called with different user_id")
principals = []
if request.user.is_admin:
principals.append('admin')
principals.append("admin")
return principals
@ -76,7 +67,7 @@ def auth_callback(user_id: int, request: Request) -> Optional[Sequence[str]]:
def includeme(config: Configurator) -> None:
"""Config updates related to authentication/authorization."""
# make all views require "view" permission unless specifically overridden
config.set_default_permission('view')
config.set_default_permission("view")
# replace the default root factory with a custom one to more easily support
# the default permission
@ -89,32 +80,30 @@ def includeme(config: Configurator) -> None:
config.set_authorization_policy(AuthorizedOnlyPolicy())
config.set_authentication_policy(
SessionAuthenticationPolicy(callback=auth_callback))
SessionAuthenticationPolicy(callback=auth_callback)
)
# enable CSRF checking globally by default
config.set_default_csrf_options(require_csrf=True)
# make the logged-in User object available as request.user
config.add_request_method(get_authenticated_user, 'user', reify=True)
config.add_request_method(get_authenticated_user, "user", reify=True)
# add has_any_permission method for easily checking multiple permissions
config.add_request_method(has_any_permission, 'has_any_permission')
config.add_request_method(has_any_permission, "has_any_permission")
class AuthorizedOnlyPolicy(ACLAuthorizationPolicy):
"""ACLAuthorizationPolicy override that always denies logged-out users."""
def permits(
self,
context: Any,
principals: Sequence[Any],
permission: str,
self, context: Any, principals: Sequence[Any], permission: str
) -> ACLPermitsResult:
"""Deny logged-out users, otherwise pass up to normal policy."""
if Authenticated not in principals:
return ACLDenied(
'<authorized only>',
'<no ACLs checked yet>',
"<authorized only>",
"<no ACLs checked yet>",
permission,
principals,
context,
@ -124,12 +113,9 @@ class AuthorizedOnlyPolicy(ACLAuthorizationPolicy):
def has_any_permission(
request: Request,
permissions: Sequence[str],
context: Any,
request: Request, permissions: Sequence[str], context: Any
) -> bool:
"""Return whether the user has any of the permissions on the item."""
return any(
request.has_permission(permission, context)
for permission in permissions
request.has_permission(permission, context) for permission in permissions
)
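Because has_any_permission is registered as a request method above, a view can test several permissions in one call. A tiny sketch from inside a hypothetical view callable (the permission names match the comment ACL further down in this diff; the comment variable is assumed):

# True if the current user holds either permission on this particular comment.
can_moderate = request.has_any_permission(("edit", "delete"), comment)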

27
tildes/tildes/database.py

@ -28,10 +28,7 @@ def obtain_lock(request: Request, lock_space: str, lock_value: int) -> None:
obtain_transaction_lock(request.db_session, lock_space, lock_value)
def query_factory(
request: Request,
model_cls: Type[DatabaseModel],
) -> ModelQuery:
def query_factory(request: Request, model_cls: Type[DatabaseModel]) -> ModelQuery:
"""Return a ModelQuery or subclass depending on model_cls specified."""
if model_cls == Comment:
return CommentQuery(request)
@ -46,8 +43,7 @@ def query_factory(
def get_tm_session(
session_factory: Callable,
transaction_manager: ThreadTransactionManager,
session_factory: Callable, transaction_manager: ThreadTransactionManager
) -> Session:
"""Return a db session being managed by the transaction manager."""
db_session = session_factory()
@ -74,26 +70,27 @@ def includeme(config: Configurator) -> None:
# transaction if the response code starts with 4 or 5. The main benefit of
# this is to avoid aborting on exceptions that don't actually indicate a
# problem, such as a HTTPFound 302 redirect.
settings['tm.commit_veto'] = 'pyramid_tm.default_commit_veto'
settings["tm.commit_veto"] = "pyramid_tm.default_commit_veto"
config.include('pyramid_tm')
config.include("pyramid_tm")
# disable SQLAlchemy connection pooling since pgbouncer will handle it
settings['sqlalchemy.poolclass'] = NullPool
settings["sqlalchemy.poolclass"] = NullPool
engine = engine_from_config(settings, 'sqlalchemy.')
engine = engine_from_config(settings, "sqlalchemy.")
session_factory = sessionmaker(bind=engine, expire_on_commit=False)
config.registry['db_session_factory'] = session_factory
config.registry["db_session_factory"] = session_factory
# attach the session to each request as request.db_session
config.add_request_method(
lambda request: get_tm_session(
config.registry['db_session_factory'], request.tm),
'db_session',
config.registry["db_session_factory"], request.tm
),
"db_session",
reify=True,
)
config.add_request_method(query_factory, 'query')
config.add_request_method(query_factory, "query")
config.add_request_method(obtain_lock, 'obtain_lock')
config.add_request_method(obtain_lock, "obtain_lock")
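The query_factory above is what backs request.query, handing back a model-specific query class. A short sketch, assuming the import paths follow the file layout and that this runs inside a view callable:

from tildes.models.comment import Comment, CommentQuery

query = request.query(Comment)  # Comment maps to CommentQuery per query_factory
assert isinstance(query, CommentQuery)
visible = query.filter_by(is_deleted=False)  # hypothetical filter, mirroring the User lookup earlier in this diff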

20
tildes/tildes/enums.py

@ -21,12 +21,12 @@ class CommentSortOption(enum.Enum):
@property
def description(self) -> str:
"""Describe this sort option."""
if self.name == 'NEWEST':
return 'newest first'
elif self.name == 'POSTED':
return 'order posted'
if self.name == "NEWEST":
return "newest first"
elif self.name == "POSTED":
return "order posted"
return 'most {}'.format(self.name.lower()) # noqa
return "most {}".format(self.name.lower()) # noqa
class CommentTagOption(enum.Enum):
@ -72,12 +72,12 @@ class TopicSortOption(enum.Enum):
using that sort in descending order means that topics with the most
votes will be listed first.
"""
if self.name == 'NEW':
return 'newest'
elif self.name == 'ACTIVITY':
return 'activity'
if self.name == "NEW":
return "newest"
elif self.name == "ACTIVITY":
return "activity"
return 'most {}'.format(self.name.lower()) # noqa
return "most {}".format(self.name.lower()) # noqa
class TopicType(enum.Enum):

26
tildes/tildes/jinja.py

@ -29,28 +29,26 @@ def includeme(config: Configurator) -> None:
"""Configure Jinja2 template renderer."""
settings = config.get_settings()
settings['jinja2.lstrip_blocks'] = True
settings['jinja2.trim_blocks'] = True
settings['jinja2.undefined'] = 'strict'
settings["jinja2.lstrip_blocks"] = True
settings["jinja2.trim_blocks"] = True
settings["jinja2.undefined"] = "strict"
# add custom jinja filters
settings['jinja2.filters'] = {
'ago': descriptive_timedelta,
}
settings["jinja2.filters"] = {"ago": descriptive_timedelta}
# add custom jinja tests
settings['jinja2.tests'] = {
'comment': is_comment,
'group': is_group,
'topic': is_topic,
settings["jinja2.tests"] = {
"comment": is_comment,
"group": is_group,
"topic": is_topic,
}
config.include('pyramid_jinja2')
config.include("pyramid_jinja2")
config.add_jinja2_search_path('tildes:templates/')
config.add_jinja2_search_path("tildes:templates/")
config.add_jinja2_extension('jinja2.ext.do')
config.add_jinja2_extension('webassets.ext.jinja2.AssetsExtension')
config.add_jinja2_extension("jinja2.ext.do")
config.add_jinja2_extension("webassets.ext.jinja2.AssetsExtension")
# attach webassets to jinja2 environment (via scheduled action)
def attach_webassets_to_jinja2() -> None:

6
tildes/tildes/json.py

@ -23,8 +23,8 @@ def serialize_model(model_item: DatabaseModel, request: Request) -> dict:
def serialize_topic(topic: Topic, request: Request) -> dict:
"""Return serializable data for a Topic."""
context = {}
if not request.has_permission('view_author', topic):
context['hide_username'] = True
if not request.has_permission("view_author", topic):
context["hide_username"] = True
return topic.schema_class(context=context).dump(topic)
@ -40,4 +40,4 @@ def includeme(config: Configurator) -> None:
# add specific adapters
json_renderer.add_adapter(Topic, serialize_topic)
config.add_renderer('json', json_renderer)
config.add_renderer("json", json_renderer)

16
tildes/tildes/lib/amqp.py

@ -24,30 +24,24 @@ class PgsqlQueueConsumer(AbstractConsumer):
JSON format.
"""
PGSQL_EXCHANGE_NAME = 'pgsql_events'
PGSQL_EXCHANGE_NAME = "pgsql_events"
def __init__(
self,
queue_name: str,
routing_keys: Sequence[str],
uses_db: bool = True,
self, queue_name: str, routing_keys: Sequence[str], uses_db: bool = True
) -> None:
"""Initialize a new queue, bindings, and consumer for it."""
self.connection = Connection()
self.channel = self.connection.channel()
self.channel.queue_declare(
queue_name, durable=True, auto_delete=False)
self.channel.queue_declare(queue_name, durable=True, auto_delete=False)
for routing_key in routing_keys:
self.channel.queue_bind(
queue_name,
exchange=self.PGSQL_EXCHANGE_NAME,
routing_key=routing_key,
queue_name, exchange=self.PGSQL_EXCHANGE_NAME, routing_key=routing_key
)
if uses_db:
self.db_session = get_session_from_config(os.environ['INI_FILE'])
self.db_session = get_session_from_config(os.environ["INI_FILE"])
else:
self.db_session = None

12
tildes/tildes/lib/cmark.py

@ -4,14 +4,14 @@
from ctypes import CDLL, c_char_p, c_int, c_size_t, c_void_p
CMARK_DLL = CDLL('/usr/local/lib/libcmark-gfm.so')
CMARK_EXT_DLL = CDLL('/usr/local/lib/libcmark-gfmextensions.so')
CMARK_DLL = CDLL("/usr/local/lib/libcmark-gfm.so")
CMARK_EXT_DLL = CDLL("/usr/local/lib/libcmark-gfmextensions.so")
# enables the --hardbreaks option for cmark
# (can I import this? it's defined in cmark.h as CMARK_OPT_HARDBREAKS)
CMARK_OPTS = 4
CMARK_EXTENSIONS = (b'strikethrough', b'table')
CMARK_EXTENSIONS = (b"strikethrough", b"table")
cmark_parser_new = CMARK_DLL.cmark_parser_new
cmark_parser_new.restype = c_void_p
@ -25,13 +25,11 @@ cmark_parser_finish = CMARK_DLL.cmark_parser_finish
cmark_parser_finish.restype = c_void_p
cmark_parser_finish.argtypes = (c_void_p,)
cmark_parser_attach_syntax_extension = (
CMARK_DLL.cmark_parser_attach_syntax_extension)
cmark_parser_attach_syntax_extension = CMARK_DLL.cmark_parser_attach_syntax_extension
cmark_parser_attach_syntax_extension.restype = c_int
cmark_parser_attach_syntax_extension.argtypes = (c_void_p, c_void_p)
cmark_parser_get_syntax_extensions = (
CMARK_DLL.cmark_parser_get_syntax_extensions)
cmark_parser_get_syntax_extensions = CMARK_DLL.cmark_parser_get_syntax_extensions
cmark_parser_get_syntax_extensions.restype = c_void_p
cmark_parser_get_syntax_extensions.argtypes = (c_void_p,)

26
tildes/tildes/lib/database.py

@ -20,7 +20,7 @@ NOT_NULL_ERROR_CODE = 23502
def get_session_from_config(config_path: str) -> Session:
"""Get a database session from a config file (specified by path)."""
env = bootstrap(config_path)
session_factory = env['registry']['db_session_factory']
session_factory = env["registry"]["db_session_factory"]
return session_factory()
@ -31,9 +31,7 @@ class LockSpaces(enum.Enum):
def obtain_transaction_lock(
session: Session,
lock_space: Optional[str],
lock_value: int,
session: Session, lock_space: Optional[str], lock_value: int
) -> None:
"""Obtain a transaction-level advisory lock from PostgreSQL.
@ -45,11 +43,9 @@ def obtain_transaction_lock(
try:
lock_space_value = LockSpaces[lock_space.upper()].value
except KeyError:
raise ValueError('Invalid lock space: %s' % lock_space)
raise ValueError("Invalid lock space: %s" % lock_space)
session.query(
func.pg_advisory_xact_lock(lock_space_value, lock_value)
).one()
session.query(func.pg_advisory_xact_lock(lock_space_value, lock_value)).one()
else:
session.query(func.pg_advisory_xact_lock(lock_value)).one()
@ -66,10 +62,11 @@ class CIText(UserDefinedType):
def get_col_spec(self, **kw: Any) -> str:
"""Return the type name (for creating columns and so on)."""
# pylint: disable=no-self-use,unused-argument
return 'CITEXT'
return "CITEXT"
def bind_processor(self, dialect: Dialect) -> Callable:
"""Return a conversion function for processing bind values."""
def process(value: Any) -> Any:
return value
@ -77,6 +74,7 @@ class CIText(UserDefinedType):
def result_processor(self, dialect: Dialect, coltype: Any) -> Callable:
"""Return a conversion function for processing result row values."""
def process(value: Any) -> Any:
return value
@ -103,8 +101,8 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors
super_rp = super().result_processor(dialect, coltype)
def handle_raw_string(value: str) -> List[str]:
if not (value.startswith('{') and value.endswith('}')):
raise ValueError('%s is not an array value' % value)
if not (value.startswith("{") and value.endswith("}")):
raise ValueError("%s is not an array value" % value)
# trim off the surrounding braces
value = value[1:-1]
@ -113,7 +111,7 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors
if not value:
return []
return value.split(',')
return value.split(",")
def process(value: Optional[str]) -> Optional[List[str]]:
if value is None:
@ -133,8 +131,8 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors
def ancestor_of(self, other): # type: ignore
"""Return whether the array contains any ancestor of `other`."""
return self.op('@>')(other)
return self.op("@>")(other)
def descendant_of(self, other): # type: ignore
"""Return whether the array contains any descendant of `other`."""
return self.op('<@')(other)
return self.op("<@")(other)
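handle_raw_string above parses the textual "{...}" form that Postgres returns for ltree[] columns. A standalone restatement with made-up group paths, just to show the expected behaviour:

def handle_raw_string(value: str):
    """Parse a Postgres "{a,b}" array literal into a list (mirrors the code above)."""
    if not (value.startswith("{") and value.endswith("}")):
        raise ValueError("%s is not an array value" % value)
    value = value[1:-1]  # trim off the surrounding braces
    if not value:
        return []
    return value.split(",")

assert handle_raw_string("{comp,comp.programming}") == ["comp", "comp.programming"]
assert handle_raw_string("{}") == []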

30
tildes/tildes/lib/datetime.py

@ -10,32 +10,32 @@ from ago import human
class SimpleHoursPeriod:
"""A simple class that represents a time period of hours or days."""
_SHORT_FORM_REGEX = re.compile(r'\d+[hd]', re.IGNORECASE)
_SHORT_FORM_REGEX = re.compile(r"\d+[hd]", re.IGNORECASE)
def __init__(self, hours: int) -> None:
"""Initialize a SimpleHoursPeriod from a number of hours."""
if hours <= 0:
raise ValueError('Period must be at least 1 hour.')
raise ValueError("Period must be at least 1 hour.")
self.hours = hours
try:
self.timedelta = timedelta(hours=hours)
except OverflowError:
raise ValueError('Time period is too large')
raise ValueError("Time period is too large")
@classmethod
def from_short_form(cls, short_form: str) -> 'SimpleHoursPeriod':
def from_short_form(cls, short_form: str) -> "SimpleHoursPeriod":
"""Initialize a period from a "short form" string (e.g. "2h", "4d")."""
if not cls._SHORT_FORM_REGEX.match(short_form):
raise ValueError('Invalid time period')
raise ValueError("Invalid time period")
unit = short_form[-1].lower()
count = int(short_form[:-1])
if unit == 'h':
if unit == "h":
hours = count
elif unit == 'd':
elif unit == "d":
hours = count * 24
return cls(hours=hours)
@ -47,9 +47,9 @@ class SimpleHoursPeriod:
for the special case of exactly "1 day", which is replaced with "24
hours".
"""
string = human(self.timedelta, past_tense='{}')
if string == '1 day':
string = '24 hours'
string = human(self.timedelta, past_tense="{}")
if string == "1 day":
string = "24 hours"
return string
@ -67,9 +67,9 @@ class SimpleHoursPeriod:
24 hours (except for 24 hours itself).
"""
if self.hours % 24 == 0 and self.hours != 24:
return '{}d'.format(self.hours // 24)
return "{}d".format(self.hours // 24)
return f'{self.hours}h'
return f"{self.hours}h"
def utc_now() -> datetime:
@ -93,7 +93,7 @@ def descriptive_timedelta(target: datetime, abbreviate: bool = False) -> str:
"""
seconds_ago = (utc_now() - target).total_seconds()
if seconds_ago < 1:
return 'a moment ago'
return "a moment ago"
# determine whether one or two precision levels is appropriate
if seconds_ago < 3600:
@ -103,7 +103,7 @@ def descriptive_timedelta(target: datetime, abbreviate: bool = False) -> str:
# try a precision=2 version, and check the units it ends up with
result = human(target, precision=2)
units = ('year', 'day', 'hour', 'minute', 'second')
units = ("year", "day", "hour", "minute", "second")
unit_indices = [i for (i, unit) in enumerate(units) if unit in result]
# if there was only one unit in it, or they're adjacent, this is fine
@ -117,6 +117,6 @@ def descriptive_timedelta(target: datetime, abbreviate: bool = False) -> str:
# remove commas if abbreviating ("3d 2h ago", not "3d, 2h ago")
if abbreviate:
result = result.replace(',', '')
result = result.replace(",", "")
return result
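A quick check of the short-form parsing reformatted above, where "h" maps straight to hours and "d" multiplies by 24 (import path inferred from the file layout):

from tildes.lib.datetime import SimpleHoursPeriod

assert SimpleHoursPeriod.from_short_form("12h").hours == 12
assert SimpleHoursPeriod.from_short_form("4d").hours == 96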

3
tildes/tildes/lib/hash.py

@ -11,7 +11,8 @@ ARGON2_TIME_COST = 4
ARGON2_MEMORY_COST = 8092
ARGON2_HASHER = PasswordHasher(
time_cost=ARGON2_TIME_COST, memory_cost=ARGON2_MEMORY_COST)
time_cost=ARGON2_TIME_COST, memory_cost=ARGON2_MEMORY_COST
)
def hash_string(string: str) -> str:

10
tildes/tildes/lib/id.py

@ -4,13 +4,13 @@ import re
import string
ID36_REGEX = re.compile('^[a-z0-9]+$', re.IGNORECASE)
ID36_REGEX = re.compile("^[a-z0-9]+$", re.IGNORECASE)
def id_to_id36(id_val: int) -> str:
"""Convert an integer ID to the string ID36 representation."""
if id_val < 1:
raise ValueError('ID values should never be zero or negative')
raise ValueError("ID values should never be zero or negative")
reversed_chars = []
@ -29,13 +29,13 @@ def id_to_id36(id_val: int) -> str:
reversed_chars.append(alphabet[index])
# join the characters in reversed order and return as the result
return ''.join(reversed(reversed_chars))
return "".join(reversed(reversed_chars))
def id36_to_id(id36_val: str) -> int:
"""Convert a string ID36 to the integer ID representation."""
if id36_val.startswith('-') or id36_val == '0':
raise ValueError('ID values should never be zero or negative')
if id36_val.startswith("-") or id36_val == "0":
raise ValueError("ID values should never be zero or negative")
# Python's stdlib can handle this, much simpler in this direction
return int(id36_val, 36)
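The two helpers above are inverses, and since id36_to_id leans on int(value, 36), the alphabet is the usual 0-9 followed by a-z. A round-trip sketch (import path inferred from the file layout):

from tildes.lib.id import id_to_id36, id36_to_id

assert id_to_id36(1234) == "ya"  # 1234 == 34 * 36 + 10, i.e. "y" then "a"
assert id36_to_id("ya") == 1234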

181
tildes/tildes/lib/markdown.py

@ -40,51 +40,51 @@ from .cmark import (
HTML_TAG_WHITELIST = (
'a',
'b',
'blockquote',
'br',
'code',
'del',
'em',
'h1',
'h2',
'h3',
'h4',
'h5',
'h6',
'hr',
'i',
'ins',
'li',
'ol',
'p',
'pre',
'strong',
'sub',
'sup',
'table',
'tbody',
'td',
'th',
'thead',
'tr',
'ul',
"a",
"b",
"blockquote",
"br",
"code",
"del",
"em",
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"hr",
"i",
"ins",
"li",
"ol",
"p",
"pre",
"strong",
"sub",
"sup",
"table",
"tbody",
"td",
"th",
"thead",
"tr",
"ul",
)
HTML_ATTRIBUTE_WHITELIST = {
'a': ['href', 'title'],
'ol': ['start'],
'td': ['align'],
'th': ['align'],
"a": ["href", "title"],
"ol": ["start"],
"td": ["align"],
"th": ["align"],
}
PROTOCOL_WHITELIST = ('http', 'https')
PROTOCOL_WHITELIST = ("http", "https")
# Regex that finds ordered list markdown that was probably accidental - ones
# being initiated by anything except "1."
BAD_ORDERED_LIST_REGEX = re.compile(
r'((?:\A|\n\n)' # Either the start of the entire text, or a new paragraph
r'(?!1\.)\d+)' # A number that isn't "1"
r'\.\s', # Followed by a period and a space
r"((?:\A|\n\n)" # Either the start of the entire text, or a new paragraph
r"(?!1\.)\d+)" # A number that isn't "1"
r"\.\s" # Followed by a period and a space
)
# Type alias for the "namespaced attr dict" used inside bleach.linkify
@ -95,12 +95,11 @@ NamespacedAttrDict = Dict[Union[Tuple[Optional[str], str], str], str] # noqa
def linkify_protocol_whitelist(
attrs: NamespacedAttrDict,
new: bool = False,
attrs: NamespacedAttrDict, new: bool = False
) -> Optional[NamespacedAttrDict]:
"""bleach.linkify callback: prevent links to non-whitelisted protocols."""
# pylint: disable=unused-argument
href = attrs.get((None, 'href'))
href = attrs.get((None, "href"))
if not href:
return attrs
@ -112,13 +111,13 @@ def linkify_protocol_whitelist(
return attrs
@histogram_timer('markdown_processing')
@histogram_timer("markdown_processing")
def convert_markdown_to_safe_html(markdown: str) -> str:
"""Convert markdown to sanitized HTML."""
# apply custom pre-processing to markdown
markdown = preprocess_markdown(markdown)
markdown_bytes = markdown.encode('utf8')
markdown_bytes = markdown.encode("utf8")
parser = cmark_parser_new(CMARK_OPTS)
for name in CMARK_EXTENSIONS:
@ -134,7 +133,7 @@ def convert_markdown_to_safe_html(markdown: str) -> str:
cmark_parser_free(parser)
cmark_node_free(doc)
html = html_bytes.decode('utf8')
html = html_bytes.decode("utf8")
# apply custom post-processing to HTML
html = postprocess_markdown_html(html)
@ -148,7 +147,7 @@ def preprocess_markdown(markdown: str) -> str:
markdown = escape_accidental_ordered_lists(markdown)
# fix the "shrug" emoji ¯\_(ツ)_/¯ to prevent markdown mangling it
markdown = markdown.replace(r'¯\_(ツ)_/¯', r'¯\\\_(ツ)\_/¯')
markdown = markdown.replace(r"¯\_(ツ)_/¯", r"¯\\\_(ツ)\_/¯")
return markdown
@ -166,19 +165,17 @@ def escape_accidental_ordered_lists(markdown: str) -> str:
numbered list except for "1. ". This will cause a few other edge cases, but
I believe they're less common/important than fixing this common error.
"""
return BAD_ORDERED_LIST_REGEX.sub(r'\1\\. ', markdown)
return BAD_ORDERED_LIST_REGEX.sub(r"\1\\. ", markdown)
def postprocess_markdown_html(html: str) -> str:
"""Apply post-processing to HTML generated by markdown parser."""
# list of tag names to exclude from linkification
linkify_skipped_tags = ['pre']
linkify_skipped_tags = ["pre"]
# search for text that looks like urls and convert to actual links
html = bleach.linkify(
html,
callbacks=[linkify_protocol_whitelist],
skip_tags=linkify_skipped_tags,
html, callbacks=[linkify_protocol_whitelist], skip_tags=linkify_skipped_tags
)
# run the HTML through our custom linkification process as well
@ -187,20 +184,17 @@ def postprocess_markdown_html(html: str) -> str:
return html
def apply_linkification(
html: str,
skip_tags: Optional[List[str]] = None,
) -> str:
def apply_linkification(html: str, skip_tags: Optional[List[str]] = None) -> str:
"""Apply custom linkification filter to convert text patterns to links."""
parser = HTMLParser(namespaceHTMLElements=False)
html_tree = parser.parseFragment(html)
walker_stream = html5lib.getTreeWalker('etree')(html_tree)
walker_stream = html5lib.getTreeWalker("etree")(html_tree)
filtered_html_tree = LinkifyFilter(walker_stream, skip_tags)
serializer = HTMLSerializer(
quote_attr_values='always',
quote_attr_values="always",
omit_optional_tags=False,
sanitize=False,
alphabetical_attributes=False,
@ -224,17 +218,15 @@ class LinkifyFilter(Filter):
# Note: currently specifically excludes paths immediately followed by a
# tilde, but this may be possible to remove once strikethrough is
# implemented (since that's probably what they were trying to do)
GROUP_REFERENCE_REGEX = re.compile(r'(?<!\w)~([\w.]+)\b(?!~)')
GROUP_REFERENCE_REGEX = re.compile(r"(?<!\w)~([\w.]+)\b(?!~)")
# Regex that finds probable references to users. As above, this isn't
# "perfect" either but works as an initial pass with the validity of
# the username checked more carefully later.
USERNAME_REFERENCE_REGEX = re.compile(r'(?<!\w)(?:/?u/|@)([\w-]+)\b')
USERNAME_REFERENCE_REGEX = re.compile(r"(?<!\w)(?:/?u/|@)([\w-]+)\b")
def __init__(
self,
source: NonRecursiveTreeWalker,
skip_tags: Optional[List[str]] = None,
self, source: NonRecursiveTreeWalker, skip_tags: Optional[List[str]] = None
) -> None:
"""Initialize a linkification filter to apply to HTML.
@ -245,28 +237,30 @@ class LinkifyFilter(Filter):
self.skip_tags = skip_tags or []
# always skip the contents of <a> tags in addition to any others
self.skip_tags.append('a')
self.skip_tags.append("a")
def __iter__(self) -> Iterator[dict]:
"""Iterate over the tree, modifying it as necessary before yielding."""
inside_skipped_tags = []
for token in super().__iter__():
if (token['type'] in ('StartTag', 'EmptyTag') and
token['name'] in self.skip_tags):
if (
token["type"] in ("StartTag", "EmptyTag")
and token["name"] in self.skip_tags
):
# if this is the start of a tag we want to skip, add it to the
# list of skipped tags that we're currently inside
inside_skipped_tags.append(token['name'])
inside_skipped_tags.append(token["name"])
elif inside_skipped_tags:
# if we're currently inside any skipped tags, the only thing we
# want to do is look for all the end tags we need to be able to
# finish skipping
if token['type'] == 'EndTag':
if token["type"] == "EndTag":
try:
inside_skipped_tags.remove(token['name'])
inside_skipped_tags.remove(token["name"])
except ValueError:
pass
elif token['type'] == 'Characters':
elif token["type"] == "Characters":
# this is only reachable if inside_skipped_tags is empty, so
# this is a text token not inside a skipped tag - do the actual
# linkification replacements
@ -300,9 +294,7 @@ class LinkifyFilter(Filter):
@staticmethod
def _linkify_tokens(
tokens: List[dict],
filter_regex: Pattern,
linkify_function: Callable,
tokens: List[dict], filter_regex: Pattern, linkify_function: Callable
) -> List[dict]:
"""Check tokens for text that matches a regex and linkify it.
@ -316,21 +308,23 @@ class LinkifyFilter(Filter):
for token in tokens:
# we don't want to touch any tokens other than character ones
if token['type'] != 'Characters':
if token["type"] != "Characters":
new_tokens.append(token)
continue
original_text = token['data']
original_text = token["data"]
current_index = 0
for match in filter_regex.finditer(original_text):
# if there were some characters between the previous match and
# this one, add a token containing those first
if match.start() > current_index:
new_tokens.append({
'type': 'Characters',
'data': original_text[current_index:match.start()],
})
new_tokens.append(
{
"type": "Characters",
"data": original_text[current_index : match.start()],
}
)
# call the linkify function to convert this match into tokens
linkified_tokens = linkify_function(match)
@ -342,10 +336,9 @@ class LinkifyFilter(Filter):
# if there's still some text left over, add one more token for it
# (this will be the entire thing if there weren't any matches)
if current_index < len(original_text):
new_tokens.append({
'type': 'Characters',
'data': original_text[current_index:],
})
new_tokens.append(
{"type": "Characters", "data": original_text[current_index:]}
)
return new_tokens
@ -360,22 +353,22 @@ class LinkifyFilter(Filter):
# things like "~10" or "~4.5" since that's just going to be someone
# using it in the "approximately" sense. So if the path consists of
# only numbers and/or periods, we won't linkify it
is_numeric = all(char in '0123456789.' for char in group_path)
is_numeric = all(char in "0123456789." for char in group_path)
# if it's a valid group path and not totally numeric, convert to <a>
if is_valid_group_path(group_path) and not is_numeric:
return [
{
'type': 'StartTag',
'name': 'a',
'data': {(None, 'href'): f'/~{group_path}'},
"type": "StartTag",
"name": "a",
"data": {(None, "href"): f"/~{group_path}"},
},
{'type': 'Characters', 'data': match[0]},
{'type': 'EndTag', 'name': 'a'},
{"type": "Characters", "data": match[0]},
{"type": "EndTag", "name": "a"},
]
# one of the checks failed, so just keep it as the original text
return [{'type': 'Characters', 'data': match[0]}]
return [{"type": "Characters", "data": match[0]}]
@staticmethod
def _tokenize_username_match(match: Match) -> List[dict]:
@ -384,16 +377,16 @@ class LinkifyFilter(Filter):
if is_valid_username(match[1]):
return [
{
'type': 'StartTag',
'name': 'a',
'data': {(None, 'href'): f'/user/{match[1]}'},
"type": "StartTag",
"name": "a",
"data": {(None, "href"): f"/user/{match[1]}"},
},
{'type': 'Characters', 'data': match[0]},
{'type': 'EndTag', 'name': 'a'},
{"type": "Characters", "data": match[0]},
{"type": "EndTag", "name": "a"},
]
# the username wasn't valid, so just keep it as the original text
return [{'type': 'Characters', 'data': match[0]}]
return [{"type": "Characters", "data": match[0]}]
def sanitize_html(html: str) -> str:
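The accidental-ordered-list escaping referenced above leaves lists that start at "1." untouched and escapes the period on anything else, so cmark won't turn a lone "2. something" paragraph into a renumbered list. A behaviour sketch (import path inferred from the file layout):

from tildes.lib.markdown import escape_accidental_ordered_lists

assert escape_accidental_ordered_lists("2. not meant as a list") == "2\\. not meant as a list"
assert escape_accidental_ordered_lists("1. a real list") == "1. a real list"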

2
tildes/tildes/lib/message.py

@ -1,6 +1,6 @@
"""Functions/constants related to messages."""
WELCOME_MESSAGE_SUBJECT = 'Welcome to the Tildes alpha'
WELCOME_MESSAGE_SUBJECT = "Welcome to the Tildes alpha"
# pylama:ignore=E501
WELCOME_MESSAGE_TEXT = """

11
tildes/tildes/lib/password.py

@ -5,21 +5,22 @@ from hashlib import sha1
from redis import ConnectionError, ResponseError, StrictRedis # noqa
# unix socket path for redis server with the breached passwords bloom filter
BREACHED_PASSWORDS_REDIS_SOCKET = '/run/redis_breached_passwords/socket'
BREACHED_PASSWORDS_REDIS_SOCKET = "/run/redis_breached_passwords/socket"
# Key where the bloom filter of password hashes from data breaches is stored
BREACHED_PASSWORDS_BF_KEY = 'breached_passwords_bloom'
BREACHED_PASSWORDS_BF_KEY = "breached_passwords_bloom"
def is_breached_password(password: str) -> bool:
"""Return whether the password is in the breached-passwords list."""
redis = StrictRedis(unix_socket_path=BREACHED_PASSWORDS_REDIS_SOCKET)
hashed = sha1(password.encode('utf-8')).hexdigest()
hashed = sha1(password.encode("utf-8")).hexdigest()
try:
return bool(redis.execute_command(
'BF.EXISTS', BREACHED_PASSWORDS_BF_KEY, hashed))
return bool(
redis.execute_command("BF.EXISTS", BREACHED_PASSWORDS_BF_KEY, hashed)
)
except (ConnectionError, ResponseError):
# server isn't running, bloom filter doesn't exist or the key is a
# different data type

95
tildes/tildes/lib/ratelimit.py

@ -25,18 +25,17 @@ class RateLimitResult:
"""
def __init__(
self,
is_allowed: bool,
total_limit: int,
remaining_limit: int,
time_until_max: timedelta,
time_until_retry: Optional[timedelta] = None,
self,
is_allowed: bool,
total_limit: int,
remaining_limit: int,
time_until_max: timedelta,
time_until_retry: Optional[timedelta] = None,
) -> None:
"""Initialize a RateLimitResult."""
# pylint: disable=too-many-arguments
if is_allowed and time_until_retry is not None:
raise ValueError(
'time_until_retry must be None if is_allowed is True')
raise ValueError("time_until_retry must be None if is_allowed is True")
self.is_allowed = is_allowed
self.total_limit = total_limit
@ -58,7 +57,7 @@ class RateLimitResult:
)
@classmethod
def unlimited_result(cls) -> 'RateLimitResult':
def unlimited_result(cls) -> "RateLimitResult":
"""Return a "blank" result representing an unlimited action."""
return cls(
is_allowed=True,
@ -68,7 +67,7 @@ class RateLimitResult:
)
@classmethod
def from_redis_cell_result(cls, result: List[int]) -> 'RateLimitResult':
def from_redis_cell_result(cls, result: List[int]) -> "RateLimitResult":
"""Convert the response from CL.THROTTLE command to a RateLimitResult.
CL.THROTTLE responds with an array of 5 integers:
@ -98,10 +97,7 @@ class RateLimitResult:
)
@classmethod
def merged_result(
cls,
results: Sequence['RateLimitResult'],
) -> 'RateLimitResult':
def merged_result(cls, results: Sequence["RateLimitResult"]) -> "RateLimitResult":
"""Merge any number of RateLimitResults into a single result.
Basically, the merged result should be the "most restrictive"
@ -125,7 +121,8 @@ class RateLimitResult:
time_until_retry = None
else:
time_until_retry = max(
r.time_until_retry for r in results if r.time_until_retry)
r.time_until_retry for r in results if r.time_until_retry
)
return cls(
is_allowed=all(r.is_allowed for r in results),
@ -140,18 +137,18 @@ class RateLimitResult:
# Retry-After: seconds the client should wait until retrying
if self.time_until_retry:
retry_seconds = int(self.time_until_retry.total_seconds())
response.headers['Retry-After'] = str(retry_seconds)
response.headers["Retry-After"] = str(retry_seconds)
# X-RateLimit-Limit: the total action limit (including used)
response.headers['X-RateLimit-Limit'] = str(self.total_limit)
response.headers["X-RateLimit-Limit"] = str(self.total_limit)
# X-RateLimit-Remaining: remaining actions before client hits the limit
response.headers['X-RateLimit-Remaining'] = str(self.remaining_limit)
response.headers["X-RateLimit-Remaining"] = str(self.remaining_limit)
# X-RateLimit-Reset: epoch timestamp when limit will be back to full
reset_time = utc_now() + self.time_until_max
reset_timestamp = int(reset_time.timestamp())
response.headers['X-RateLimit-Reset'] = str(reset_timestamp)
response.headers["X-RateLimit-Reset"] = str(reset_timestamp)
return response
@ -165,14 +162,14 @@ class RateLimitedAction:
"""
def __init__(
self,
name: str,
period: timedelta,
limit: int,
max_burst: Optional[int] = None,
by_user: bool = True,
by_ip: bool = True,
redis: Optional[StrictRedis] = None,
self,
name: str,
period: timedelta,
limit: int,
max_burst: Optional[int] = None,
by_user: bool = True,
by_ip: bool = True,
redis: Optional[StrictRedis] = None,
) -> None:
"""Initialize the limits on a particular action.
@ -187,10 +184,10 @@ class RateLimitedAction:
"""
# pylint: disable=too-many-arguments
if max_burst and not 1 <= max_burst <= limit:
raise ValueError('max_burst must be at least 1 and <= limit')
raise ValueError("max_burst must be at least 1 and <= limit")
if not (by_user or by_ip):
raise ValueError('At least one of by_user or by_ip must be True')
raise ValueError("At least one of by_user or by_ip must be True")
self.name = name
self.period = period
@ -213,7 +210,7 @@ class RateLimitedAction:
def redis(self) -> StrictRedis:
"""Return the redis connection."""
if not self._redis:
raise RateLimitError('No redis connection set')
raise RateLimitError("No redis connection set")
return self._redis
@ -224,19 +221,14 @@ class RateLimitedAction:
def _build_redis_key(self, by_type: str, value: Any) -> str:
"""Build the Redis key where this rate limit is maintained."""
parts = [
'ratelimit',
self.name,
by_type,
str(value),
]
parts = ["ratelimit", self.name, by_type, str(value)]
return ':'.join(parts)
return ":".join(parts)
def _call_redis_command(self, key: str) -> List[int]:
"""Call the redis-cell CL.THROTTLE command for this action."""
return self.redis.execute_command(
'CL.THROTTLE',
"CL.THROTTLE",
key,
self.max_burst - 1,
self.limit,
@ -246,10 +238,9 @@ class RateLimitedAction:
def check_for_user_id(self, user_id: int) -> RateLimitResult:
"""Check whether a particular user_id can perform this action."""
if not self.by_user:
raise RateLimitError(
'check_for_user_id called on non-user-limited action')
raise RateLimitError("check_for_user_id called on non-user-limited action")
key = self._build_redis_key('user', user_id)
key = self._build_redis_key("user", user_id)
result = self._call_redis_command(key)
return RateLimitResult.from_redis_cell_result(result)
@ -257,22 +248,20 @@ class RateLimitedAction:
def reset_for_user_id(self, user_id: int) -> None:
"""Reset the ratelimit on this action for a particular user_id."""
if not self.by_user:
raise RateLimitError(
'reset_for_user_id called on non-user-limited action')
raise RateLimitError("reset_for_user_id called on non-user-limited action")
key = self._build_redis_key('user', user_id)
key = self._build_redis_key("user", user_id)
self.redis.delete(key)
def check_for_ip(self, ip_str: str) -> RateLimitResult:
"""Check whether a particular IP can perform this action."""
if not self.by_ip:
raise RateLimitError(
'check_for_ip called on non-IP-limited action')
raise RateLimitError("check_for_ip called on non-IP-limited action")
# check if ip_str is a valid address, will ValueError if not
ip_address(ip_str)
key = self._build_redis_key('ip', ip_str)
key = self._build_redis_key("ip", ip_str)
result = self._call_redis_command(key)
return RateLimitResult.from_redis_cell_result(result)
@ -280,23 +269,21 @@ class RateLimitedAction:
def reset_for_ip(self, ip_str: str) -> None:
"""Reset the ratelimit on this action for a particular IP."""
if not self.by_ip:
raise RateLimitError(
'reset_for_ip called on non-user-limited action')
raise RateLimitError("reset_for_ip called on non-user-limited action")
# check if ip_str is a valid address, will ValueError if not
ip_address(ip_str)
key = self._build_redis_key('ip', ip_str)
key = self._build_redis_key("ip", ip_str)
self.redis.delete(key)
# the actual list of actions with rate-limit restrictions
# each action must have a unique name to prevent key collisions
_RATE_LIMITED_ACTIONS = (
RateLimitedAction('login', timedelta(hours=1), 20),
RateLimitedAction('register', timedelta(hours=1), 50),
RateLimitedAction("login", timedelta(hours=1), 20),
RateLimitedAction("register", timedelta(hours=1), 50),
)
# (public) dict to be able to look up the actions by name
RATE_LIMITED_ACTIONS = {
action.name: action for action in _RATE_LIMITED_ACTIONS}
RATE_LIMITED_ACTIONS = {action.name: action for action in _RATE_LIMITED_ACTIONS}
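Per _build_redis_key above, each limit lives under a colon-joined key namespaced by action name and limit type. Poking at the private helper purely for illustration:

from tildes.lib.ratelimit import RATE_LIMITED_ACTIONS

login = RATE_LIMITED_ACTIONS["login"]
assert login._build_redis_key("user", 123) == "ratelimit:login:user:123"
assert login._build_redis_key("ip", "10.0.0.1") == "ratelimit:login:ip:10.0.0.1"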

45
tildes/tildes/lib/string.py

@ -20,16 +20,16 @@ def convert_to_url_slug(original: str, max_length: int = 100) -> str:
slug = original.lower()
# remove apostrophes so contractions don't get broken up by underscores
slug = re.sub("['’]", '', slug)
slug = re.sub("['’]", "", slug)
# replace all remaining non-word characters with underscores
slug = re.sub(r'\W+', '_', slug)
slug = re.sub(r"\W+", "_", slug)
# remove any consecutive underscores
slug = re.sub('_{2,}', '_', slug)
slug = re.sub("_{2,}", "_", slug)
# remove "hanging" underscores on the start and/or end
slug = slug.strip('_')
slug = slug.strip("_")
# url-encode the slug
encoded_slug = quote(slug)
@ -42,7 +42,7 @@ def convert_to_url_slug(original: str, max_length: int = 100) -> str:
# Truncating a url-encoded slug can be tricky if there are any multi-byte
# unicode characters, since the %-encoded forms of them can be quite long.
# Check to see if the slug looks like it might contain any of those.
maybe_multi_bytes = bool(re.search('%..%', encoded_slug))
maybe_multi_bytes = bool(re.search("%..%", encoded_slug))
# if that matched, we need to take a more complicated approach
if maybe_multi_bytes:
@ -50,10 +50,7 @@ def convert_to_url_slug(original: str, max_length: int = 100) -> str:
# simple truncate - break at underscore if possible, no overflow string
return truncate_string(
encoded_slug,
max_length,
truncate_at_chars='_',
overflow_str=None,
encoded_slug, max_length, truncate_at_chars="_", overflow_str=None
)
@ -62,7 +59,7 @@ def _truncate_multibyte_slug(original: str, max_length: int) -> str:
# instead of the normal method of truncating "backwards" from the end of
# the string, build it up one encoded character at a time from the start
# until it's too long
encoded_slug = ''
encoded_slug = ""
for character in original:
encoded_char = quote(character)
@ -82,7 +79,7 @@ def _truncate_multibyte_slug(original: str, max_length: int) -> str:
# determining the word edges is not simple.
acceptable_truncation = 0.7
truncated_slug = truncate_string_at_char(encoded_slug, '_')
truncated_slug = truncate_string_at_char(encoded_slug, "_")
if len(truncated_slug) / len(encoded_slug) >= acceptable_truncation:
return truncated_slug
@ -91,10 +88,10 @@ def _truncate_multibyte_slug(original: str, max_length: int) -> str:
def truncate_string(
original: str,
length: int,
truncate_at_chars: Optional[str] = None,
overflow_str: Optional[str] = '...',
original: str,
length: int,
truncate_at_chars: Optional[str] = None,
overflow_str: Optional[str] = "...",
) -> str:
"""Truncate a string to be no longer than a specified length.
@ -109,7 +106,7 @@ def truncate_string(
string will be kept.
"""
if overflow_str is None:
overflow_str = ''
overflow_str = ""
# no need to do anything if the string is already short enough
if len(original) <= length:
@ -117,7 +114,7 @@ def truncate_string(
# cut the string down to the max desired length (leaving space for the
# overflow string if one is specified)
truncated = original[:length - len(overflow_str)]
truncated = original[: length - len(overflow_str)]
# if we don't want to truncate at particular characters, we're done
if not truncate_at_chars:
@ -167,7 +164,7 @@ def simplify_string(original: str) -> str:
simplified = _sanitize_characters(original)
# replace consecutive spaces with a single space
simplified = re.sub(r'\s{2,}', ' ', simplified)
simplified = re.sub(r"\s{2,}", " ", simplified)
# remove any remaining leading/trailing whitespace
simplified = simplified.strip()
@ -182,16 +179,16 @@ def _sanitize_characters(original: str) -> str:
for char in original:
category = unicodedata.category(char)
if category.startswith('Z'):
if category.startswith("Z"):
# "separator" chars - replace with a normal space
final_characters.append(' ')
elif category.startswith('C'):
final_characters.append(" ")
elif category.startswith("C"):
# "other" chars (control, formatting, etc.) - filter them out
# except for newlines, which are replaced with normal spaces
if char == '\n':
final_characters.append(' ')
if char == "\n":
final_characters.append(" ")
else:
# any other type of character, just keep it
final_characters.append(char)
return ''.join(final_characters)
return "".join(final_characters)

4
tildes/tildes/lib/url.py

@ -8,9 +8,9 @@ def get_domain_from_url(url: str, strip_www: bool = True) -> str:
domain = urlparse(url).netloc
if not domain:
raise ValueError('Invalid url or domain could not be determined')
raise ValueError("Invalid url or domain could not be determined")
if strip_www and domain.startswith('www.'):
if strip_www and domain.startswith("www."):
domain = domain[4:]
return domain
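And the matching behaviour for the domain helper above:

from tildes.lib.url import get_domain_from_url

assert get_domain_from_url("https://www.example.com/a/path") == "example.com"
assert get_domain_from_url("https://docs.example.com/", strip_www=False) == "docs.example.com"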

60
tildes/tildes/metrics.py

@ -11,50 +11,32 @@ from prometheus_client.core import _LabelWrapper
_COUNTERS = {
'votes': Counter(
'tildes_votes_total',
'Votes',
labelnames=['target_type'],
),
'comments': Counter('tildes_comments_total', 'Comments'),
'invite_code_failures': Counter(
'tildes_invite_code_failures_total',
'Invite Code Failures',
),
'logins': Counter('tildes_logins_total', 'Login Attempts'),
'login_failures': Counter(
'tildes_login_failures_total',
'Login Failures',
),
'messages': Counter(
'tildes_messages_total',
'Messages',
labelnames=['type'],
),
'registrations': Counter(
'tildes_registrations_total',
'User Registrations',
),
'topics': Counter('tildes_topics_total', 'Topics', labelnames=['type']),
'subscriptions': Counter('tildes_subscriptions_total', 'Subscriptions'),
'unsubscriptions': Counter(
'tildes_unsubscriptions_total',
'Unsubscriptions',
"votes": Counter("tildes_votes_total", "Votes", labelnames=["target_type"]),
"comments": Counter("tildes_comments_total", "Comments"),
"invite_code_failures": Counter(
"tildes_invite_code_failures_total", "Invite Code Failures"
),
"logins": Counter("tildes_logins_total", "Login Attempts"),
"login_failures": Counter("tildes_login_failures_total", "Login Failures"),
"messages": Counter("tildes_messages_total", "Messages", labelnames=["type"]),
"registrations": Counter("tildes_registrations_total", "User Registrations"),
"topics": Counter("tildes_topics_total", "Topics", labelnames=["type"]),
"subscriptions": Counter("tildes_subscriptions_total", "Subscriptions"),
"unsubscriptions": Counter("tildes_unsubscriptions_total", "Unsubscriptions"),
}
_HISTOGRAMS = {
'markdown_processing': Histogram(
'tildes_markdown_processing_seconds',
'Markdown processing',
"markdown_processing": Histogram(
"tildes_markdown_processing_seconds",
"Markdown processing",
buckets=[.001, .0025, .005, .01, 0.025, .05, .1, .5, 1.0],
),
'comment_tree_sorting': Histogram(
'tildes_comment_tree_sorting_seconds',
'Comment tree sorting time',
labelnames=['num_comments_range', 'order'],
"comment_tree_sorting": Histogram(
"tildes_comment_tree_sorting_seconds",
"Comment tree sorting time",
labelnames=["num_comments_range", "order"],
buckets=[.00001, .0001, .001, .01, .05, .1, .5, 1.0],
)
),
}
@ -63,7 +45,7 @@ def incr_counter(name: str, amount: int = 1, **labels: str) -> None:
try:
counter = _COUNTERS[name]
except KeyError:
raise ValueError('Invalid counter name')
raise ValueError("Invalid counter name")
if isinstance(counter, _LabelWrapper):
counter = counter.labels(**labels)
@ -76,7 +58,7 @@ def get_histogram(name: str, **labels: str) -> Histogram:
try:
hist = _HISTOGRAMS[name]
except KeyError:
raise ValueError('Invalid histogram name')
raise ValueError("Invalid histogram name")
if isinstance(hist, _LabelWrapper):
hist = hist.labels(**labels)

97
tildes/tildes/models/comment/comment.py

@ -5,14 +5,7 @@ from datetime import datetime, timedelta
from typing import Any, Optional, Sequence, Tuple
from pyramid.security import Allow, Authenticated, Deny, DENY_ALL, Everyone
from sqlalchemy import (
Boolean,
Column,
ForeignKey,
Integer,
Text,
TIMESTAMP,
)
from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text, TIMESTAMP
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import deferred, relationship
from sqlalchemy.sql.expression import text
@ -60,47 +53,40 @@ class Comment(DatabaseModel):
schema_class = CommentSchema
__tablename__ = 'comments'
__tablename__ = "comments"
comment_id: int = Column(Integer, primary_key=True)
topic_id: int = Column(
Integer,
ForeignKey('topics.topic_id'),
nullable=False,
index=True,
Integer, ForeignKey("topics.topic_id"), nullable=False, index=True
)
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
parent_comment_id: Optional[int] = Column(
Integer,
ForeignKey('comments.comment_id'),
index=True,
Integer, ForeignKey("comments.comment_id"), index=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
is_deleted: bool = Column(
Boolean, nullable=False, server_default='false', index=True)
Boolean, nullable=False, server_default="false", index=True
)
deleted_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
is_removed: bool = Column(Boolean, nullable=False, server_default='false')
is_removed: bool = Column(Boolean, nullable=False, server_default="false")
removed_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
last_edited_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
_markdown: str = deferred(Column('markdown', Text, nullable=False))
_markdown: str = deferred(Column("markdown", Text, nullable=False))
rendered_html: str = Column(Text, nullable=False)
num_votes: int = Column(
Integer, nullable=False, server_default='0', index=True)
num_votes: int = Column(Integer, nullable=False, server_default="0", index=True)
user: User = relationship('User', lazy=False, innerjoin=True)
topic: Topic = relationship('Topic', innerjoin=True)
parent_comment: Optional['Comment'] = relationship(
'Comment', uselist=False, remote_side=[comment_id])
user: User = relationship("User", lazy=False, innerjoin=True)
topic: Topic = relationship("Topic", innerjoin=True)
parent_comment: Optional["Comment"] = relationship(
"Comment", uselist=False, remote_side=[comment_id]
)
@hybrid_property
def markdown(self) -> str:
@ -117,20 +103,19 @@ class Comment(DatabaseModel):
self._markdown = new_markdown
self.rendered_html = convert_markdown_to_safe_html(new_markdown)
if (self.created_time and
utc_now() - self.created_time > EDIT_GRACE_PERIOD):
if self.created_time and utc_now() - self.created_time > EDIT_GRACE_PERIOD:
self.last_edited_time = utc_now()
def __repr__(self) -> str:
"""Display the comment's ID as its repr format."""
return f'<Comment ({self.comment_id})>'
return f"<Comment ({self.comment_id})>"
def __init__(
self,
topic: Topic,
author: User,
markdown: str,
parent_comment: Optional['Comment'] = None,
self,
topic: Topic,
author: User,
markdown: str,
parent_comment: Optional["Comment"] = None,
) -> None:
"""Create a new comment."""
self.topic = topic
@ -142,7 +127,7 @@ class Comment(DatabaseModel):
self.markdown = markdown
incr_counter('comments')
incr_counter("comments")
def __acl__(self) -> Sequence[Tuple[str, Any, str]]:
"""Pyramid security ACL."""
@ -156,49 +141,49 @@ class Comment(DatabaseModel):
# - removed comments can only be viewed by admins and the author
# - otherwise, everyone can view
if self.is_removed:
acl.append((Allow, 'admin', 'view'))
acl.append((Allow, self.user_id, 'view'))
acl.append((Deny, Everyone, 'view'))
acl.append((Allow, "admin", "view"))
acl.append((Allow, self.user_id, "view"))
acl.append((Deny, Everyone, "view"))
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# vote:
# - removed comments can't be voted on by anyone
# - otherwise, logged-in users except the author can vote
if self.is_removed:
acl.append((Deny, Everyone, 'vote'))
acl.append((Deny, Everyone, "vote"))
acl.append((Deny, self.user_id, 'vote'))
acl.append((Allow, Authenticated, 'vote'))
acl.append((Deny, self.user_id, "vote"))
acl.append((Allow, Authenticated, "vote"))
# tag:
# - temporary: nobody can tag comments
acl.append((Deny, Everyone, 'tag'))
acl.append((Deny, Everyone, "tag"))
# reply:
# - removed comments can't be replied to by anyone
# - if the topic is locked, only admins can reply
# - otherwise, logged-in users can reply
if self.is_removed:
acl.append((Deny, Everyone, 'reply'))
acl.append((Deny, Everyone, "reply"))
if self.topic.is_locked:
acl.append((Allow, 'admin', 'reply'))
acl.append((Deny, Everyone, 'reply'))
acl.append((Allow, "admin", "reply"))
acl.append((Deny, Everyone, "reply"))
acl.append((Allow, Authenticated, 'reply'))
acl.append((Allow, Authenticated, "reply"))
# edit:
# - only the author can edit
acl.append((Allow, self.user_id, 'edit'))
acl.append((Allow, self.user_id, "edit"))
# delete:
# - only the author can delete
acl.append((Allow, self.user_id, 'delete'))
acl.append((Allow, self.user_id, "delete"))
# mark_read:
# - logged-in users can mark comments read
acl.append((Allow, Authenticated, 'mark_read'))
acl.append((Allow, Authenticated, "mark_read"))
acl.append(DENY_ALL)
@ -220,7 +205,7 @@ class Comment(DatabaseModel):
@property
def permalink(self) -> str:
"""Return the permalink for this comment."""
return f'{self.topic.permalink}#comment-{self.comment_id36}'
return f"{self.topic.permalink}#comment-{self.comment_id36}"
@property
def parent_comment_permalink(self) -> str:
@ -228,7 +213,7 @@ class Comment(DatabaseModel):
if not self.parent_comment_id:
raise AttributeError
return f'{self.topic.permalink}#comment-{self.parent_comment_id36}'
return f"{self.topic.permalink}#comment-{self.parent_comment_id36}"
@property
def tag_counts(self) -> Counter:

71
tildes/tildes/models/comment/comment_notification.py

@ -29,38 +29,27 @@ class CommentNotification(DatabaseModel):
decrement num_unread_notifications for the relevant user.
"""
__tablename__ = 'comment_notifications'
__tablename__ = "comment_notifications"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
comment_id: int = Column(
Integer,
ForeignKey('comments.comment_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("comments.comment_id"), nullable=False, primary_key=True
)
notification_type: CommentNotificationType = Column(
ENUM(CommentNotificationType), nullable=False)
ENUM(CommentNotificationType), nullable=False
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=text('NOW()'),
TIMESTAMP(timezone=True), nullable=False, server_default=text("NOW()")
)
is_unread: bool = Column(
Boolean, nullable=False, server_default='true', index=True)
is_unread: bool = Column(Boolean, nullable=False, server_default="true", index=True)
user: User = relationship('User', innerjoin=True)
comment: Comment = relationship('Comment', innerjoin=True)
user: User = relationship("User", innerjoin=True)
comment: Comment = relationship("Comment", innerjoin=True)
def __init__(
self,
user: User,
comment: Comment,
notification_type: CommentNotificationType,
self, user: User, comment: Comment, notification_type: CommentNotificationType
) -> None:
"""Create a new notification for a user from a comment."""
self.user = user
@ -70,7 +59,7 @@ class CommentNotification(DatabaseModel):
def __acl__(self) -> Sequence[Tuple[str, Any, str]]:
"""Pyramid security ACL."""
acl = []
acl.append((Allow, self.user_id, 'mark_read'))
acl.append((Allow, self.user_id, "mark_read"))
acl.append(DENY_ALL)
return acl
@ -91,17 +80,12 @@ class CommentNotification(DatabaseModel):
@classmethod
def get_mentions_for_comment(
cls,
db_session: Session,
comment: Comment,
) -> List['CommentNotification']:
cls, db_session: Session, comment: Comment
) -> List["CommentNotification"]:
"""Get a list of notifications for user mentions in the comment."""
notifications = []
raw_names = re.findall(
LinkifyFilter.USERNAME_REFERENCE_REGEX,
comment.markdown,
)
raw_names = re.findall(LinkifyFilter.USERNAME_REFERENCE_REGEX, comment.markdown)
users_to_mention = (
db_session.query(User)
.filter(User.username.in_(raw_names)) # type: ignore
@ -124,17 +108,18 @@ class CommentNotification(DatabaseModel):
continue
mention_notification = cls(
user, comment, CommentNotificationType.USER_MENTION)
user, comment, CommentNotificationType.USER_MENTION
)
notifications.append(mention_notification)
return notifications
@staticmethod
def prevent_duplicate_notifications(
db_session: Session,
comment: Comment,
new_notifications: List['CommentNotification'],
) -> Tuple[List['CommentNotification'], List['CommentNotification']]:
db_session: Session,
comment: Comment,
new_notifications: List["CommentNotification"],
) -> Tuple[List["CommentNotification"], List["CommentNotification"]]:
"""Filter new notifications for edited comments.
Protect against sending a notification for the same comment to
@@ -149,13 +134,13 @@ class CommentNotification(DatabaseModel):
that need to be added, as they're new.
"""
previous_notifications = (
db_session
.query(CommentNotification)
db_session.query(CommentNotification)
.filter(
CommentNotification.comment_id == comment.comment_id,
CommentNotification.notification_type ==
CommentNotificationType.USER_MENTION,
).all()
CommentNotification.notification_type
== CommentNotificationType.USER_MENTION,
)
.all()
)
new_mention_user_ids = [
@@ -167,12 +152,14 @@ class CommentNotification(DatabaseModel):
]
to_delete = [
notification for notification in previous_notifications
notification
for notification in previous_notifications
if notification.user.user_id not in new_mention_user_ids
]
to_add = [
notification for notification in new_notifications
notification
for notification in new_notifications
if notification.user.user_id not in previous_mention_user_ids
]
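The prevent_duplicate_notifications docstring above describes diffing the previous and new mention sets when a comment is edited. A minimal sketch of that filtering, not part of the commit, using plain user-id lists instead of CommentNotification rows (the helper name and sample ids are illustrative):

from typing import List, Tuple

def split_mention_changes(
    previous_user_ids: List[int], new_user_ids: List[int]
) -> Tuple[List[int], List[int]]:
    """Return (to_delete, to_add) user ids for an edited comment's mentions."""
    to_delete = [uid for uid in previous_user_ids if uid not in new_user_ids]
    to_add = [uid for uid in new_user_ids if uid not in previous_user_ids]
    return to_delete, to_add

# the mention of user 1 was removed by the edit, and user 4 was newly mentioned
assert split_mention_changes([1, 2, 3], [2, 3, 4]) == ([1], [4])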

10
tildes/tildes/models/comment/comment_notification_query.py

@@ -17,7 +17,7 @@ class CommentNotificationQuery(ModelQuery):
"""Initialize a CommentNotificationQuery for the request."""
super().__init__(CommentNotification, request)
def _attach_extra_data(self) -> 'CommentNotificationQuery':
def _attach_extra_data(self) -> "CommentNotificationQuery":
"""Attach the user's comment votes to the query."""
vote_subquery = (
self.request.query(CommentVote)
@@ -26,16 +16,16 @@ class CommentNotificationQuery(ModelQuery):
CommentVote.user == self.request.user,
)
.exists()
.label('user_voted')
.label("user_voted")
)
return self.add_columns(vote_subquery)
def join_all_relationships(self) -> 'CommentNotificationQuery':
def join_all_relationships(self) -> "CommentNotificationQuery":
"""Eagerly join the comment, topic, and group to the notification."""
self = self.options(
joinedload(CommentNotification.comment)
.joinedload('topic')
.joinedload('group')
.joinedload("topic")
.joinedload("group")
)
return self

6
tildes/tildes/models/comment/comment_query.py

@@ -21,14 +21,14 @@ class CommentQuery(PaginatedQuery):
"""
super().__init__(Comment, request)
def _attach_extra_data(self) -> 'CommentQuery':
def _attach_extra_data(self) -> "CommentQuery":
"""Attach the extra user data to the query."""
if not self.request.user:
return self
return self._attach_vote_data()
def _attach_vote_data(self) -> 'CommentQuery':
def _attach_vote_data(self) -> "CommentQuery":
"""Add a subquery to include whether the user has voted."""
vote_subquery = (
self.request.query(CommentVote)
@@ -37,7 +37,7 @@ class CommentQuery(PaginatedQuery):
CommentVote.user_id == self.request.user.user_id,
)
.exists()
.label('user_voted')
.label("user_voted")
)
return self.add_columns(vote_subquery)
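The _attach_vote_data method above attaches a labelled EXISTS subquery as an extra column on the main query. A self-contained sketch of the same idea, not part of the commit, assuming SQLAlchemy 1.x and an in-memory SQLite database (the two model definitions are simplified stand-ins):

from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()

class Comment(Base):
    __tablename__ = "comments"
    comment_id = Column(Integer, primary_key=True)

class CommentVote(Base):
    __tablename__ = "comment_votes"
    user_id = Column(Integer, primary_key=True)
    comment_id = Column(Integer, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add_all(
    [Comment(comment_id=1), Comment(comment_id=2), CommentVote(user_id=7, comment_id=1)]
)
session.commit()

# correlated EXISTS: "did user 7 vote on this comment?", exposed as a column
vote_subquery = (
    session.query(CommentVote)
    .filter(
        CommentVote.comment_id == Comment.comment_id,
        CommentVote.user_id == 7,
    )
    .exists()
    .label("user_voted")
)
results = session.query(Comment, vote_subquery).order_by(Comment.comment_id).all()
assert [bool(voted) for comment, voted in results] == [True, False]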

29
tildes/tildes/models/comment/comment_tag.py

@@ -16,37 +16,24 @@ from .comment import Comment
class CommentTag(DatabaseModel):
"""Model for the tags attached to comments by users."""
__tablename__ = 'comment_tags'
__tablename__ = "comment_tags"
comment_id: int = Column(
Integer,
ForeignKey('comments.comment_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("comments.comment_id"), nullable=False, primary_key=True
)
tag: CommentTagOption = Column(
ENUM(CommentTagOption), nullable=False, primary_key=True)
ENUM(CommentTagOption), nullable=False, primary_key=True
)
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=text('NOW()'),
TIMESTAMP(timezone=True), nullable=False, server_default=text("NOW()")
)
comment: Comment = relationship(
Comment, backref=backref('tags', lazy=False))
comment: Comment = relationship(Comment, backref=backref("tags", lazy=False))
def __init__(
self,
comment: Comment,
user: User,
tag: CommentTagOption,
) -> None:
def __init__(self, comment: Comment, user: User, tag: CommentTagOption) -> None:
"""Add a new tag to a comment."""
self.comment_id = comment.comment_id
self.user_id = user.user_id

23
tildes/tildes/models/comment/comment_tree.py

@@ -18,11 +18,7 @@ class CommentTree:
descendants (if not, it can be pruned from the tree)
"""
def __init__(
self,
comments: Sequence[Comment],
sort: CommentSortOption,
) -> None:
def __init__(self, comments: Sequence[Comment], sort: CommentSortOption) -> None:
"""Create a sorted CommentTree from a flat list of Comments."""
self.tree: List[Comment] = []
self.sort = sort
@@ -76,10 +72,7 @@ class CommentTree:
self.tree.append(comment)
@staticmethod
def _sort_tree(
tree: List[Comment],
sort: CommentSortOption,
) -> List[Comment]:
def _sort_tree(tree: List[Comment], sort: CommentSortOption) -> List[Comment]:
"""Sort the tree by the desired ordering (recursively).
Because Python's sorted() function is stable, the ordering of any
@@ -149,18 +142,18 @@ class CommentTree:
# make an "order of magnitude" label based on the number of comments
if num_comments == 0:
raise ValueError('Attempting to time an empty comment tree sort')
raise ValueError("Attempting to time an empty comment tree sort")
if num_comments < 10:
num_comments_range = '1 - 9'
num_comments_range = "1 - 9"
elif num_comments < 100:
num_comments_range = '10 - 99'
num_comments_range = "10 - 99"
elif num_comments < 1000:
num_comments_range = '100 - 999'
num_comments_range = "100 - 999"
else:
num_comments_range = '1000+'
num_comments_range = "1000+"
return get_histogram(
'comment_tree_sorting',
"comment_tree_sorting",
num_comments_range=num_comments_range,
order=self.sort.name,
)
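_sort_tree relies on sorted() being stable, as its docstring notes: comments that tie on the sort key keep their existing relative order, so re-sorting by a new key never scrambles equal elements. A two-line illustration, not part of the commit, with made-up sample data:

comments = [("a", 3), ("b", 5), ("c", 3)]  # (comment id, votes)
by_votes = sorted(comments, key=lambda c: c[1], reverse=True)
assert by_votes == [("b", 5), ("a", 3), ("c", 3)]  # "a" stays ahead of "c"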

20
tildes/tildes/models/comment/comment_vote.py

@@ -21,33 +21,27 @@ class CommentVote(DatabaseModel):
column for the relevant comment.
"""
__tablename__ = 'comment_votes'
__tablename__ = "comment_votes"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
comment_id: int = Column(
Integer,
ForeignKey('comments.comment_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("comments.comment_id"), nullable=False, primary_key=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
user: User = relationship('User', innerjoin=True)
comment: Comment = relationship('Comment', innerjoin=True)
user: User = relationship("User", innerjoin=True)
comment: Comment = relationship("Comment", innerjoin=True)
def __init__(self, user: User, comment: Comment) -> None:
"""Create a new vote on a comment."""
self.user = user
self.comment = comment
incr_counter('votes', target_type='comment')
incr_counter("votes", target_type="comment")

34
tildes/tildes/models/database_model.py

@@ -11,36 +11,31 @@ from sqlalchemy.schema import MetaData
from sqlalchemy.sql.schema import Table
ModelType = TypeVar('ModelType') # pylint: disable=invalid-name
ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
# SQLAlchemy naming convention for constraints and indexes
NAMING_CONVENTION = {
'pk': 'pk_%(table_name)s',
'fk': 'fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s',
'ix': 'ix_%(table_name)s_%(column_0_name)s',
'ck': 'ck_%(table_name)s_%(constraint_name)s',
'uq': 'uq_%(table_name)s_%(column_0_name)s',
"pk": "pk_%(table_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"ix": "ix_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
}
def attach_set_listener(
class_: Type['DatabaseModelBase'],
attribute: str,
instance: 'DatabaseModelBase',
class_: Type["DatabaseModelBase"], attribute: str, instance: "DatabaseModelBase"
) -> None:
"""Attach the SQLAlchemy ORM "set" attribute listener."""
# pylint: disable=unused-argument
def set_handler(
target: 'DatabaseModelBase',
value: Any,
oldvalue: Any,
initiator: Any,
target: "DatabaseModelBase", value: Any, oldvalue: Any, initiator: Any
) -> Any:
"""Handle an SQLAlchemy ORM "set" attribute event."""
# pylint: disable=protected-access
return target._validate_new_value(attribute, value)
event.listen(instance, 'set', set_handler, retval=True)
event.listen(instance, "set", set_handler, retval=True)
class DatabaseModelBase:
@@ -71,8 +66,7 @@ class DatabaseModelBase:
key columns used in __eq__, as recommended in the Python documentation.
"""
primary_key_values = tuple(
getattr(self, column.name)
for column in self.__table__.primary_key
getattr(self, column.name) for column in self.__table__.primary_key
)
return hash(primary_key_values)
@@ -82,7 +76,7 @@ class DatabaseModelBase:
if not self.schema_class:
raise AttributeError
if not hasattr(self, '_schema'):
if not hasattr(self, "_schema"):
self._schema = self.schema_class(partial=True) # noqa
return self._schema
@@ -112,7 +106,7 @@ class DatabaseModelBase:
# set starts with an underscore, assume that it's due to being set up
# as a hybrid property, and remove the underscore prefix when looking
# for a field to validate against.
if attribute.startswith('_'):
if attribute.startswith("_"):
attribute = attribute[1:]
field = self.schema.fields.get(attribute)
@@ -126,13 +120,13 @@ class DatabaseModelBase:
DatabaseModel = declarative_base( # pylint: disable=invalid-name
cls=DatabaseModelBase,
name='DatabaseModel',
name="DatabaseModel",
metadata=MetaData(naming_convention=NAMING_CONVENTION),
)
# attach the listener for SQLAlchemy ORM attribute "set" events to all models
event.listen(DatabaseModel, 'attribute_instrument', attach_set_listener)
event.listen(DatabaseModel, "attribute_instrument", attach_set_listener)
# associate JSONB columns with MutableDict so value changes are detected
MutableDict.associate_with(JSONB)
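The NAMING_CONVENTION templates above are plain %-format strings; SQLAlchemy fills them in when naming constraints and indexes, and the log-table listener further down expands them by hand. A quick illustration of that expansion, not part of the commit, using values chosen to match the log_topics example:

NAMING_CONVENTION = {
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    "ix": "ix_%(table_name)s_%(column_0_name)s",
}

fk_name = NAMING_CONVENTION["fk"] % {
    "table_name": "log_topics",
    "column_0_name": "topic_id",
    "referred_table_name": "topics",
}
assert fk_name == "fk_log_topics_topic_id_topics"

ix_name = NAMING_CONVENTION["ix"] % {
    "table_name": "log_topics",
    "column_0_name": "topic_id",
}
assert ix_name == "ix_log_topics_topic_id"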

44
tildes/tildes/models/group/group.py

@@ -4,15 +4,7 @@ from datetime import datetime
from typing import Any, Optional, Sequence, Tuple
from pyramid.security import Allow, Authenticated, Deny, DENY_ALL, Everyone
from sqlalchemy import (
Boolean,
CheckConstraint,
Column,
Index,
Integer,
Text,
TIMESTAMP,
)
from sqlalchemy import Boolean, CheckConstraint, Column, Index, Integer, Text, TIMESTAMP
from sqlalchemy.sql.expression import text
from sqlalchemy_utils import Ltree, LtreeType
@@ -31,7 +23,7 @@ class Group(DatabaseModel):
schema_class = GroupSchema
__tablename__ = 'groups'
__tablename__ = "groups"
group_id: int = Column(Integer, primary_key=True)
path: Ltree = Column(LtreeType, nullable=False, index=True, unique=True)
@@ -39,36 +31,34 @@ class Group(DatabaseModel):
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
short_description: Optional[str] = Column(
Text,
CheckConstraint(
f'LENGTH(short_description) <= {SHORT_DESCRIPTION_MAX_LENGTH}',
name='short_description_length',
)
f"LENGTH(short_description) <= {SHORT_DESCRIPTION_MAX_LENGTH}",
name="short_description_length",
),
)
num_subscriptions: int = Column(
Integer, nullable=False, server_default='0')
num_subscriptions: int = Column(Integer, nullable=False, server_default="0")
is_admin_posting_only: bool = Column(
Boolean, nullable=False, server_default='false')
Boolean, nullable=False, server_default="false"
)
# Create a GiST index on path as well as the btree one that will be created
# by the index=True/unique=True keyword args to Column above. The GiST
# index supports additional operators for ltree queries: @>, <@, @, ~, ?
__table_args__ = (
Index('ix_groups_path_gist', path, postgresql_using='gist'),
)
__table_args__ = (Index("ix_groups_path_gist", path, postgresql_using="gist"),)
def __repr__(self) -> str:
"""Display the group's path and ID as its repr format."""
return f'<Group {self.path} ({self.group_id})>'
return f"<Group {self.path} ({self.group_id})>"
def __str__(self) -> str:
"""Use the group path for the string representation."""
return str(self.path)
def __lt__(self, other: 'Group') -> bool:
def __lt__(self, other: "Group") -> bool:
"""Order groups by their string representation."""
return str(self) < str(other)
@@ -83,20 +73,20 @@ class Group(DatabaseModel):
# view:
# - all groups can be viewed by everyone
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# subscribe:
# - all groups can be subscribed to by logged-in users
acl.append((Allow, Authenticated, 'subscribe'))
acl.append((Allow, Authenticated, "subscribe"))
# post_topic:
# - only admins can post in admin-posting-only groups
# - otherwise, all logged-in users can post
if self.is_admin_posting_only:
acl.append((Allow, 'admin', 'post_topic'))
acl.append((Deny, Everyone, 'post_topic'))
acl.append((Allow, "admin", "post_topic"))
acl.append((Deny, Everyone, "post_topic"))
acl.append((Allow, Authenticated, 'post_topic'))
acl.append((Allow, Authenticated, "post_topic"))
acl.append(DENY_ALL)
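In the __acl__ above, ordering matters: Pyramid checks the entries top-down and stops at the first match, which is why the admin Allow for post_topic precedes the blanket Deny in admin-posting-only groups. A simplified first-match evaluator, not part of the commit and not Pyramid's actual implementation (which also handles principal lookup and ALL_PERMISSIONS):

ALLOW, DENY = "Allow", "Deny"

def permits(acl, principals, permission):
    """Return whether the first matching ACE allows the permission."""
    for action, principal, perm in acl:
        if principal in principals and perm == permission:
            return action == ALLOW
    return False

acl = [(ALLOW, "admin", "post_topic"), (DENY, "Everyone", "post_topic")]
assert permits(acl, {"admin", "Everyone"}, "post_topic") is True
assert permits(acl, {"Everyone"}, "post_topic") is False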

6
tildes/tildes/models/group/group_query.py

@@ -21,14 +21,14 @@ class GroupQuery(ModelQuery):
"""
super().__init__(Group, request)
def _attach_extra_data(self) -> 'GroupQuery':
def _attach_extra_data(self) -> "GroupQuery":
"""Attach the extra user data to the query."""
if not self.request.user:
return self
return self._attach_subscription_data()
def _attach_subscription_data(self) -> 'GroupQuery':
def _attach_subscription_data(self) -> "GroupQuery":
"""Add a subquery to include whether the user is subscribed."""
subscription_subquery = (
self.request.query(GroupSubscription)
@@ -37,7 +37,7 @@ class GroupQuery(ModelQuery):
GroupSubscription.user == self.request.user,
)
.exists()
.label('user_subscribed')
.label("user_subscribed")
)
return self.add_columns(subscription_subquery)

20
tildes/tildes/models/group/group_subscription.py

@@ -21,33 +21,27 @@ class GroupSubscription(DatabaseModel):
num_subscriptions column for the relevant group.
"""
__tablename__ = 'group_subscriptions'
__tablename__ = "group_subscriptions"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
group_id: int = Column(
Integer,
ForeignKey('groups.group_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("groups.group_id"), nullable=False, primary_key=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
user: User = relationship('User', innerjoin=True, backref='subscriptions')
group: Group = relationship('Group', innerjoin=True, lazy=False)
user: User = relationship("User", innerjoin=True, backref="subscriptions")
group: Group = relationship("Group", innerjoin=True, lazy=False)
def __init__(self, user: User, group: Group) -> None:
"""Create a new subscription to a group."""
self.user = user
self.group = group
incr_counter('subscriptions')
incr_counter("subscriptions")

135
tildes/tildes/models/log/log.py

@@ -3,15 +3,7 @@
from typing import Any, Dict, Optional
from pyramid.request import Request
from sqlalchemy import (
BigInteger,
Column,
event,
ForeignKey,
Integer,
Table,
TIMESTAMP,
)
from sqlalchemy import BigInteger, Column, event, ForeignKey, Integer, Table, TIMESTAMP
from sqlalchemy.dialects.postgresql import ENUM, INET, JSONB
from sqlalchemy.engine import Connection
from sqlalchemy.ext.declarative import declared_attr
@@ -23,7 +15,7 @@ from tildes.models import DatabaseModel
from tildes.models.topic import Topic
class BaseLog():
class BaseLog:
"""Mixin class with the shared columns/relationships for log classes."""
@declared_attr
@@ -34,7 +26,7 @@ class BaseLog():
@declared_attr
def user_id(self) -> Column:
"""Return the user_id column."""
return Column(Integer, ForeignKey('users.user_id'), index=True)
return Column(Integer, ForeignKey("users.user_id"), index=True)
@declared_attr
def event_type(self) -> Column:
@@ -53,7 +45,7 @@ class BaseLog():
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
@declared_attr
@@ -64,21 +56,21 @@ class BaseLog():
@declared_attr
def user(self) -> Any:
"""Return the user relationship."""
return relationship('User', lazy=False)
return relationship("User", lazy=False)
class Log(DatabaseModel, BaseLog):
"""Model for a basic log entry."""
__tablename__ = 'log'
__tablename__ = "log"
INHERITED_TABLES = ['log_topics']
INHERITED_TABLES = ["log_topics"]
def __init__(
self,
event_type: LogEventType,
request: Request,
info: Optional[Dict[str, Any]] = None,
self,
event_type: LogEventType,
request: Request,
info: Optional[Dict[str, Any]] = None,
) -> None:
"""Create a new log entry.
@@ -97,19 +89,20 @@ class Log(DatabaseModel, BaseLog):
class LogTopic(DatabaseModel, BaseLog):
"""Model for a log entry related to a specific topic."""
__tablename__ = 'log_topics'
__tablename__ = "log_topics"
topic_id: int = Column(
Integer, ForeignKey('topics.topic_id'), index=True, nullable=False)
Integer, ForeignKey("topics.topic_id"), index=True, nullable=False
)
topic: Topic = relationship('Topic')
topic: Topic = relationship("Topic")
def __init__(
self,
event_type: LogEventType,
request: Request,
topic: Topic,
info: Optional[Dict[str, Any]] = None,
self,
event_type: LogEventType,
request: Request,
topic: Topic,
info: Optional[Dict[str, Any]] = None,
) -> None:
"""Create a new log entry related to a specific topic."""
# pylint: disable=non-parent-init-called
@@ -122,57 +115,55 @@ class LogTopic(DatabaseModel, BaseLog):
if self.event_type == LogEventType.TOPIC_TAG:
return self._tag_event_description()
elif self.event_type == LogEventType.TOPIC_MOVE:
old_group = self.info['old'] # noqa
new_group = self.info['new'] # noqa
return f'moved from ~{old_group} to ~{new_group}'
old_group = self.info["old"] # noqa
new_group = self.info["new"] # noqa
return f"moved from ~{old_group} to ~{new_group}"
elif self.event_type == LogEventType.TOPIC_LOCK:
return 'locked comments'
return "locked comments"
elif self.event_type == LogEventType.TOPIC_UNLOCK:
return 'unlocked comments'
return "unlocked comments"
elif self.event_type == LogEventType.TOPIC_TITLE_EDIT:
old_title = self.info['old'] # noqa
new_title = self.info['new'] # noqa
old_title = self.info["old"] # noqa
new_title = self.info["new"] # noqa
return f'changed title from "{old_title}" to "{new_title}"'
return f'performed action {self.event_type.name}' # noqa
return f"performed action {self.event_type.name}" # noqa
def _tag_event_description(self) -> str:
"""Return a description of a TOPIC_TAG event as a string."""
if self.event_type != LogEventType.TOPIC_TAG:
raise TypeError
old_tags = set(self.info['old']) # noqa
new_tags = set(self.info['new']) # noqa
old_tags = set(self.info["old"]) # noqa
new_tags = set(self.info["new"]) # noqa
added_tags = new_tags - old_tags
removed_tags = old_tags - new_tags
description = ''
description = ""
if added_tags:
tag_str = ', '.join([f"'{tag}'" for tag in added_tags])
tag_str = ", ".join([f"'{tag}'" for tag in added_tags])
if len(added_tags) == 1:
description += f'added tag {tag_str}'
description += f"added tag {tag_str}"
else:
description += f'added tags {tag_str}'
description += f"added tags {tag_str}"
if removed_tags:
description += ' and '
description += " and "
if removed_tags:
tag_str = ', '.join([f"'{tag}'" for tag in removed_tags])
tag_str = ", ".join([f"'{tag}'" for tag in removed_tags])
if len(removed_tags) == 1:
description += f'removed tag {tag_str}'
description += f"removed tag {tag_str}"
else:
description += f'removed tags {tag_str}'
description += f"removed tags {tag_str}"
return description
@event.listens_for(Log.__table__, 'after_create')
@event.listens_for(Log.__table__, "after_create")
def create_inherited_tables(
target: Table,
connection: Connection,
**kwargs: Any,
target: Table, connection: Connection, **kwargs: Any
) -> None:
"""Create all the tables that inherit from the base "log" one."""
# pylint: disable=unused-argument
@@ -180,43 +171,39 @@ def create_inherited_tables(
# log_topics
connection.execute(
'CREATE TABLE log_topics (topic_id integer not null) INHERITS (log)')
"CREATE TABLE log_topics (topic_id integer not null) INHERITS (log)"
)
fk_name = naming['fk'] % {
'table_name': 'log_topics',
'column_0_name': 'topic_id',
'referred_table_name': 'topics',
fk_name = naming["fk"] % {
"table_name": "log_topics",
"column_0_name": "topic_id",
"referred_table_name": "topics",
}
connection.execute(
f'ALTER TABLE log_topics ADD CONSTRAINT {fk_name} '
'FOREIGN KEY (topic_id) REFERENCES topics (topic_id)'
f"ALTER TABLE log_topics ADD CONSTRAINT {fk_name} "
"FOREIGN KEY (topic_id) REFERENCES topics (topic_id)"
)
ix_name = naming['ix'] % {
'table_name': 'log_topics',
'column_0_name': 'topic_id',
}
connection.execute(f'CREATE INDEX {ix_name} ON log_topics (topic_id)')
ix_name = naming["ix"] % {"table_name": "log_topics", "column_0_name": "topic_id"}
connection.execute(f"CREATE INDEX {ix_name} ON log_topics (topic_id)")
# duplicate all the indexes/constraints from the base log table
for table in Log.INHERITED_TABLES:
pk_name = naming['pk'] % {'table_name': table}
pk_name = naming["pk"] % {"table_name": table}
connection.execute(
f'ALTER TABLE {table} '
f'ADD CONSTRAINT {pk_name} PRIMARY KEY (log_id)'
f"ALTER TABLE {table} ADD CONSTRAINT {pk_name} PRIMARY KEY (log_id)"
)
for col in ('event_time', 'event_type', 'ip_address', 'user_id'):
ix_name = naming['ix'] % {
'table_name': table, 'column_0_name': col}
connection.execute(f'CREATE INDEX {ix_name} ON {table} ({col})')
for col in ("event_time", "event_type", "ip_address", "user_id"):
ix_name = naming["ix"] % {"table_name": table, "column_0_name": col}
connection.execute(f"CREATE INDEX {ix_name} ON {table} ({col})")
fk_name = naming['fk'] % {
'table_name': table,
'column_0_name': 'user_id',
'referred_table_name': 'users',
fk_name = naming["fk"] % {
"table_name": table,
"column_0_name": "user_id",
"referred_table_name": "users",
}
connection.execute(
f'ALTER TABLE {table} ADD CONSTRAINT {fk_name} '
'FOREIGN KEY (user_id) REFERENCES users (user_id)'
f"ALTER TABLE {table} ADD CONSTRAINT {fk_name} "
"FOREIGN KEY (user_id) REFERENCES users (user_id)"
)
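create_inherited_tables hooks the "after_create" DDL event because PostgreSQL table inheritance does not copy the parent table's indexes or constraints, so they have to be recreated on log_topics by hand. A throwaway sketch of the same hook on an SQLite table, not part of the commit, assuming SQLAlchemy 1.x (where connection.execute accepts a raw SQL string, as the listener above does):

from sqlalchemy import Column, Integer, MetaData, Table, create_engine, event

metadata = MetaData()
parent = Table("parent", metadata, Column("id", Integer, primary_key=True))

@event.listens_for(parent, "after_create")
def create_extras(target, connection, **kwargs):
    # runs immediately after CREATE TABLE parent; the real listener uses this
    # spot to create log_topics and re-create its indexes and constraints
    connection.execute("CREATE INDEX ix_parent_id_extra ON parent (id)")

engine = create_engine("sqlite://")
metadata.create_all(engine)  # emits CREATE TABLE, then the listener's DDL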

85
tildes/tildes/models/message/message.py

@@ -55,78 +55,70 @@ class MessageConversation(DatabaseModel):
schema_class = MessageConversationSchema
__tablename__ = 'message_conversations'
__tablename__ = "message_conversations"
conversation_id: int = Column(Integer, primary_key=True)
sender_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
recipient_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
subject: str = Column(
Text,
CheckConstraint(
f'LENGTH(subject) <= {SUBJECT_MAX_LENGTH}',
name='subject_length',
f"LENGTH(subject) <= {SUBJECT_MAX_LENGTH}", name="subject_length"
),
nullable=False,
)
markdown: str = deferred(Column(Text, nullable=False))
rendered_html: str = Column(Text, nullable=False)
num_replies: int = Column(Integer, nullable=False, server_default='0')
last_reply_time: Optional[datetime] = Column(
TIMESTAMP(timezone=True), index=True)
num_replies: int = Column(Integer, nullable=False, server_default="0")
last_reply_time: Optional[datetime] = Column(TIMESTAMP(timezone=True), index=True)
unread_user_ids: List[int] = Column(
ARRAY(Integer), nullable=False, server_default='{}')
ARRAY(Integer), nullable=False, server_default="{}"
)
sender: User = relationship(
'User', lazy=False, innerjoin=True, foreign_keys=[sender_id])
"User", lazy=False, innerjoin=True, foreign_keys=[sender_id]
)
recipient: User = relationship(
'User', lazy=False, innerjoin=True, foreign_keys=[recipient_id])
replies: Sequence['MessageReply'] = relationship(
'MessageReply', order_by='MessageReply.created_time')
"User", lazy=False, innerjoin=True, foreign_keys=[recipient_id]
)
replies: Sequence["MessageReply"] = relationship(
"MessageReply", order_by="MessageReply.created_time"
)
# Create a GIN index on the unread_user_ids column using the gin__int_ops
# operator class supplied by the intarray module. This should be the best
# index for "array contains" queries.
__table_args__ = (
Index(
'ix_message_conversations_unread_user_ids_gin',
"ix_message_conversations_unread_user_ids_gin",
unread_user_ids,
postgresql_using='gin',
postgresql_ops={'unread_user_ids': 'gin__int_ops'},
postgresql_using="gin",
postgresql_ops={"unread_user_ids": "gin__int_ops"},
),
)
def __init__(
self,
sender: User,
recipient: User,
subject: str,
markdown: str,
self, sender: User, recipient: User, subject: str, markdown: str
) -> None:
"""Create a new message conversation between two users."""
self.sender_id = sender.user_id
self.recipient_id = recipient.user_id
self.unread_user_ids = ([self.recipient_id])
self.unread_user_ids = [self.recipient_id]
self.subject = subject
self.markdown = markdown
self.rendered_html = convert_markdown_to_safe_html(markdown)
incr_counter('messages', type='conversation')
incr_counter("messages", type="conversation")
def __acl__(self) -> Sequence[Tuple[str, Any, str]]:
"""Pyramid security ACL."""
@@ -163,7 +155,7 @@ class MessageConversation(DatabaseModel):
vice versa.
"""
if not self.is_participant(viewer):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
if viewer == self.sender:
return self.recipient
@@ -173,7 +165,7 @@ class MessageConversation(DatabaseModel):
def is_unread_by_user(self, user: User) -> bool:
"""Return whether the conversation is unread by the specified user."""
if not self.is_participant(user):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
return user.user_id in self.unread_user_ids
@@ -184,9 +176,9 @@ class MessageConversation(DatabaseModel):
worry about duplicate values, race conditions, etc.
"""
if not self.is_participant(user):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
union = MessageConversation.unread_user_ids.op('|') # type: ignore
union = MessageConversation.unread_user_ids.op("|") # type: ignore
self.unread_user_ids = union(user.user_id)
def mark_read_for_user(self, user: User) -> None:
@@ -197,11 +189,12 @@ class MessageConversation(DatabaseModel):
race conditions, etc.
"""
if not self.is_participant(user):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
user_id = user.user_id
self.unread_user_ids = ( # type: ignore
MessageConversation.unread_user_ids - user_id) # type: ignore
MessageConversation.unread_user_ids - user_id # type: ignore
)
class MessageReply(DatabaseModel):
@@ -217,37 +210,31 @@ class MessageReply(DatabaseModel):
schema_class = MessageReplySchema
__tablename__ = 'message_replies'
__tablename__ = "message_replies"
reply_id: int = Column(Integer, primary_key=True)
conversation_id: int = Column(
Integer,
ForeignKey('message_conversations.conversation_id'),
ForeignKey("message_conversations.conversation_id"),
nullable=False,
index=True,
)
sender_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
markdown: str = deferred(Column(Text, nullable=False))
rendered_html: str = Column(Text, nullable=False)
sender: User = relationship('User', lazy=False, innerjoin=True)
sender: User = relationship("User", lazy=False, innerjoin=True)
def __init__(
self,
conversation: MessageConversation,
sender: User,
markdown: str,
self, conversation: MessageConversation, sender: User, markdown: str
) -> None:
"""Add a new reply to a message conversation."""
self.conversation_id = conversation.conversation_id
@@ -255,7 +242,7 @@ class MessageReply(DatabaseModel):
self.markdown = markdown
self.rendered_html = convert_markdown_to_safe_html(markdown)
incr_counter('messages', type='reply')
incr_counter("messages", type="reply")
@property
def reply_id36(self) -> str:
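mark_unread_for_user and mark_read_for_user above push the set arithmetic into PostgreSQL's intarray operators ("|" to add a user id, subtraction to remove one) so the update happens server-side without a read-modify-write race. The effect on the column value, sketched with plain Python, not part of the commit (sample ids are made up):

unread_user_ids = [3]                                  # recipient starts unread
unread_user_ids = sorted(set(unread_user_ids) | {7})   # sender marks it unread
assert unread_user_ids == [3, 7]
unread_user_ids = [uid for uid in unread_user_ids if uid != 3]  # recipient reads it
assert unread_user_ids == [7]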

37
tildes/tildes/models/model_query.py

@@ -8,7 +8,7 @@ from sqlalchemy.orm import Load, undefer
from sqlalchemy.orm.query import Query
ModelType = TypeVar('ModelType') # pylint: disable=invalid-name
ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
class ModelQuery(Query):
@@ -22,10 +22,10 @@ class ModelQuery(Query):
self.request = request
# can only filter deleted items if the table has an 'is_deleted' column
self.filter_deleted = bool('is_deleted' in model_cls.__table__.columns)
self.filter_deleted = bool("is_deleted" in model_cls.__table__.columns)
# can only filter removed items if the table has an 'is_removed' column
self.filter_removed = bool('is_removed' in model_cls.__table__.columns)
self.filter_removed = bool("is_removed" in model_cls.__table__.columns)
def __iter__(self) -> Iterator[ModelType]:
"""Iterate over the (processed) results of the query.
@@ -36,11 +36,11 @@ class ModelQuery(Query):
results = super().__iter__()
return iter([self._process_result(result) for result in results])
def _attach_extra_data(self) -> 'ModelQuery':
def _attach_extra_data(self) -> "ModelQuery":
"""Override to attach extra data to query before execution."""
return self
def _finalize(self) -> 'ModelQuery':
def _finalize(self) -> "ModelQuery":
"""Finalize the query before it's executed."""
# pylint: disable=protected-access
@@ -49,14 +49,13 @@ class ModelQuery(Query):
# is potentially dangerous, but should be fine with the existing
# straightforward usage patterns.
return (
self
.enable_assertions(False)
self.enable_assertions(False)
._attach_extra_data()
._filter_deleted_if_necessary()
._filter_removed_if_necessary()
)
def _before_compile_listener(self) -> 'ModelQuery':
def _before_compile_listener(self) -> "ModelQuery":
"""Do any final adjustments to the query before it's compiled.
Note that this method cannot be overridden by subclasses because of
@@ -65,21 +64,21 @@ class ModelQuery(Query):
"""
return self._finalize()
def _filter_deleted_if_necessary(self) -> 'ModelQuery':
def _filter_deleted_if_necessary(self) -> "ModelQuery":
"""Filter out deleted rows unless they were explicitly included."""
if not self.filter_deleted:
return self
return self.filter(self.model_cls.is_deleted == False) # noqa
def _filter_removed_if_necessary(self) -> 'ModelQuery':
def _filter_removed_if_necessary(self) -> "ModelQuery":
"""Filter out removed rows unless they were explicitly included."""
if not self.filter_removed:
return self
return self.filter(self.model_cls.is_removed == False) # noqa
def lock_based_on_request_method(self) -> 'ModelQuery':
def lock_based_on_request_method(self) -> "ModelQuery":
"""Lock the rows if request method implies it's needed (generative).
Applying this function to a query will cause the database to acquire
@@ -90,37 +89,37 @@ class ModelQuery(Query):
Note that POST is specifically not included, because the item being
POSTed to is not usually modified in a "dangerous" way as a result.
"""
if self.request.method in {'DELETE', 'PATCH', 'PUT'}:
if self.request.method in {"DELETE", "PATCH", "PUT"}:
return self.with_for_update(of=self.model_cls)
return self
def include_deleted(self) -> 'ModelQuery':
def include_deleted(self) -> "ModelQuery":
"""Specify that deleted rows should be included (generative)."""
self.filter_deleted = False
return self
def include_removed(self) -> 'ModelQuery':
def include_removed(self) -> "ModelQuery":
"""Specify that removed rows should be included (generative)."""
self.filter_removed = False
return self
def join_all_relationships(self) -> 'ModelQuery':
def join_all_relationships(self) -> "ModelQuery":
"""Eagerly join all lazy relationships (generative).
This is useful for being able to load an item "fully" in a single
query and avoid needing to make additional queries for related items.
"""
# pylint: disable=no-member
self = self.options(Load(self.model_cls).joinedload('*'))
self = self.options(Load(self.model_cls).joinedload("*"))
return self
def undefer_all_columns(self) -> 'ModelQuery':
def undefer_all_columns(self) -> "ModelQuery":
"""Undefer all columns (generative)."""
self = self.options(undefer('*'))
self = self.options(undefer("*"))
return self
@@ -134,7 +133,7 @@ class ModelQuery(Query):
# before the query executes
event.listen(
ModelQuery,
'before_compile',
"before_compile",
ModelQuery._before_compile_listener, # pylint: disable=protected-access
retval=True,
)

14
tildes/tildes/models/pagination.py

@@ -9,7 +9,7 @@ from tildes.lib.id import id_to_id36, id36_to_id
from .model_query import ModelQuery
ModelType = TypeVar('ModelType') # pylint: disable=invalid-name
ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
class PaginatedQuery(ModelQuery):
@@ -18,7 +18,7 @@ class PaginatedQuery(ModelQuery):
def __init__(self, model_cls: Any, request: Request) -> None:
"""Initialize a PaginatedQuery for the specified model and request."""
if len(model_cls.__table__.primary_key) > 1:
raise TypeError('Only single-col primary key tables are supported')
raise TypeError("Only single-col primary key tables are supported")
super().__init__(model_cls, request)
@@ -75,7 +75,7 @@ class PaginatedQuery(ModelQuery):
"""
return bool(self.before_id)
def after_id36(self, id36: str) -> 'PaginatedQuery':
def after_id36(self, id36: str) -> "PaginatedQuery":
"""Restrict the query to results after an id36 (generative)."""
if self.before_id:
raise ValueError("Can't set both before and after restrictions")
@@ -84,7 +84,7 @@ class PaginatedQuery(ModelQuery):
return self
def before_id36(self, id36: str) -> 'PaginatedQuery':
def before_id36(self, id36: str) -> "PaginatedQuery":
"""Restrict the query to results before an id36 (generative)."""
if self.after_id:
raise ValueError("Can't set both before and after restrictions")
@@ -93,7 +93,7 @@ class PaginatedQuery(ModelQuery):
return self
def _apply_before_or_after(self) -> 'PaginatedQuery':
def _apply_before_or_after(self) -> "PaginatedQuery":
"""Apply the "before" or "after" restrictions if necessary."""
if not (self.after_id or self.before_id):
return self
@@ -132,7 +132,7 @@ class PaginatedQuery(ModelQuery):
return query
def _finalize(self) -> 'PaginatedQuery':
def _finalize(self) -> "PaginatedQuery":
"""Finalize the query before execution."""
query = super()._finalize()
@@ -152,7 +152,7 @@ class PaginatedQuery(ModelQuery):
return query
def get_page(self, per_page: int) -> 'PaginatedResults':
def get_page(self, per_page: int) -> "PaginatedResults":
"""Get a page worth of results from the query (`per page` items)."""
return PaginatedResults(self, per_page)
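The after_id36/before_id36 cursors used by PaginatedQuery are base-36 encodings of the primary key, converted by id_to_id36 and id36_to_id from tildes.lib.id. A minimal stand-in for those helpers, not part of the commit (the real functions may differ in edge-case handling):

import string

ALPHABET = string.digits + string.ascii_lowercase

def id_to_id36(value: int) -> str:
    """Encode a positive integer id as a base-36 string."""
    if value < 1:
        raise ValueError("only positive ids are supported in this sketch")
    digits = []
    while value:
        value, remainder = divmod(value, 36)
        digits.append(ALPHABET[remainder])
    return "".join(reversed(digits))

def id36_to_id(id36: str) -> int:
    """Decode a base-36 string back to an integer id."""
    return int(id36, 36)

assert id_to_id36(123456) == "2n9c"
assert id36_to_id("2n9c") == 123456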

174
tildes/tildes/models/topic/topic.py

@@ -31,17 +31,14 @@ from tildes.metrics import incr_counter
from tildes.models import DatabaseModel
from tildes.models.group import Group
from tildes.models.user import User
from tildes.schemas.topic import (
TITLE_MAX_LENGTH,
TopicSchema,
)
from tildes.schemas.topic import TITLE_MAX_LENGTH, TopicSchema
# edits inside this period after creation will not mark the topic as edited
EDIT_GRACE_PERIOD = timedelta(minutes=5)
# special tags to put at the front of the tag list
SPECIAL_TAGS = ['nsfw', 'spoiler']
SPECIAL_TAGS = ["nsfw", "spoiler"]
class Topic(DatabaseModel):
@@ -64,76 +61,67 @@ class Topic(DatabaseModel):
schema_class = TopicSchema
__tablename__ = 'topics'
__tablename__ = "topics"
topic_id: int = Column(Integer, primary_key=True)
group_id: int = Column(
Integer,
ForeignKey('groups.group_id'),
nullable=False,
index=True,
Integer, ForeignKey("groups.group_id"), nullable=False, index=True
)
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
last_edited_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
last_activity_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
is_deleted: bool = Column(
Boolean, nullable=False, server_default='false', index=True)
Boolean, nullable=False, server_default="false", index=True
)
deleted_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
is_removed: bool = Column(
Boolean, nullable=False, server_default='false', index=True)
Boolean, nullable=False, server_default="false", index=True
)
removed_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
title: str = Column(
Text,
CheckConstraint(
f'LENGTH(title) <= {TITLE_MAX_LENGTH}',
name='title_length',
),
CheckConstraint(f"LENGTH(title) <= {TITLE_MAX_LENGTH}", name="title_length"),
nullable=False,
)
topic_type: TopicType = Column(
ENUM(TopicType), nullable=False, server_default='TEXT')
_markdown: Optional[str] = deferred(Column('markdown', Text))
ENUM(TopicType), nullable=False, server_default="TEXT"
)
_markdown: Optional[str] = deferred(Column("markdown", Text))
rendered_html: Optional[str] = Column(Text)
link: Optional[str] = Column(Text)
content_metadata: Dict[str, Any] = Column(JSONB)
num_comments: int = Column(
Integer, nullable=False, server_default='0', index=True)
num_votes: int = Column(
Integer, nullable=False, server_default='0', index=True)
num_comments: int = Column(Integer, nullable=False, server_default="0", index=True)
num_votes: int = Column(Integer, nullable=False, server_default="0", index=True)
_tags: List[Ltree] = Column(
'tags', ArrayOfLtree, nullable=False, server_default='{}')
is_official: bool = Column(Boolean, nullable=False, server_default='false')
is_locked: bool = Column(Boolean, nullable=False, server_default='false')
"tags", ArrayOfLtree, nullable=False, server_default="{}"
)
is_official: bool = Column(Boolean, nullable=False, server_default="false")
is_locked: bool = Column(Boolean, nullable=False, server_default="false")
user: User = relationship('User', lazy=False, innerjoin=True)
group: Group = relationship('Group', innerjoin=True)
user: User = relationship("User", lazy=False, innerjoin=True)
group: Group = relationship("Group", innerjoin=True)
# Create a GiST index on the tags column
__table_args__ = (
Index('ix_topics_tags_gist', _tags, postgresql_using='gist'),
)
__table_args__ = (Index("ix_topics_tags_gist", _tags, postgresql_using="gist"),)
@hybrid_property
def markdown(self) -> Optional[str]:
"""Return the topic's markdown."""
if not self.is_text_type:
raise AttributeError('Only text topics have markdown')
raise AttributeError("Only text topics have markdown")
return self._markdown
@@ -141,7 +129,7 @@ class Topic(DatabaseModel):
def markdown(self, new_markdown: str) -> None:
"""Set the topic's markdown and render its HTML."""
if not self.is_text_type:
raise AttributeError('Can only set markdown for text topics')
raise AttributeError("Can only set markdown for text topics")
if new_markdown == self.markdown:
return
@@ -149,21 +137,19 @@ class Topic(DatabaseModel):
self._markdown = new_markdown
self.rendered_html = convert_markdown_to_safe_html(new_markdown)
if (self.created_time and
utc_now() - self.created_time > EDIT_GRACE_PERIOD):
if self.created_time and utc_now() - self.created_time > EDIT_GRACE_PERIOD:
self.last_edited_time = utc_now()
@hybrid_property
def tags(self) -> List[str]:
"""Return the topic's tags."""
sorted_tags = [str(tag).replace('_', ' ') for tag in self._tags]
sorted_tags = [str(tag).replace("_", " ") for tag in self._tags]
# move special tags in front
# reverse so that tags at the start of the list appear first
for tag in reversed(SPECIAL_TAGS):
if tag in sorted_tags:
sorted_tags.insert(
0, sorted_tags.pop(sorted_tags.index(tag)))
sorted_tags.insert(0, sorted_tags.pop(sorted_tags.index(tag)))
return sorted_tags
@@ -176,12 +162,7 @@ class Topic(DatabaseModel):
return f'<Topic "{self.title}" ({self.topic_id})>'
@classmethod
def _create_base_topic(
cls,
group: Group,
author: User,
title: str,
) -> 'Topic':
def _create_base_topic(cls, group: Group, author: User, title: str) -> "Topic":
"""Create the "base" for a new topic."""
new_topic = cls()
new_topic.group_id = group.group_id
@@ -192,35 +173,27 @@ class Topic(DatabaseModel):
@classmethod
def create_text_topic(
cls,
group: Group,
author: User,
title: str,
markdown: str = '',
) -> 'Topic':
cls, group: Group, author: User, title: str, markdown: str = ""
) -> "Topic":
"""Create a new text topic."""
new_topic = cls._create_base_topic(group, author, title)
new_topic.topic_type = TopicType.TEXT
new_topic.markdown = markdown
incr_counter('topics', type='text')
incr_counter("topics", type="text")
return new_topic
@classmethod
def create_link_topic(
cls,
group: Group,
author: User,
title: str,
link: str,
) -> 'Topic':
cls, group: Group, author: User, title: str, link: str
) -> "Topic":
"""Create a new link topic."""
new_topic = cls._create_base_topic(group, author, title)
new_topic.topic_type = TopicType.LINK
new_topic.link = link
incr_counter('topics', type='link')
incr_counter("topics", type="link")
return new_topic
@@ -230,74 +203,74 @@ class Topic(DatabaseModel):
# deleted topics allow "general" viewing, but nothing else
if self.is_deleted:
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
acl.append(DENY_ALL)
# view:
# - everyone gets "general" viewing permission for all topics
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# view_author:
# - removed topics' author is only visible to the author and admins
# - otherwise, everyone can view the author
if self.is_removed:
acl.append((Allow, 'admin', 'view_author'))
acl.append((Allow, self.user_id, 'view_author'))
acl.append((Deny, Everyone, 'view_author'))
acl.append((Allow, "admin", "view_author"))
acl.append((Allow, self.user_id, "view_author"))
acl.append((Deny, Everyone, "view_author"))
acl.append((Allow, Everyone, 'view_author'))
acl.append((Allow, Everyone, "view_author"))
# view_content:
# - removed topics' content is only visible to the author and admins
# - otherwise, everyone can view the content
if self.is_removed:
acl.append((Allow, 'admin', 'view_content'))
acl.append((Allow, self.user_id, 'view_content'))
acl.append((Deny, Everyone, 'view_content'))
acl.append((Allow, "admin", "view_content"))
acl.append((Allow, self.user_id, "view_content"))
acl.append((Deny, Everyone, "view_content"))
acl.append((Allow, Everyone, 'view_content'))
acl.append((Allow, Everyone, "view_content"))
# vote:
# - removed topics can't be voted on by anyone
# - otherwise, logged-in users except the author can vote
if self.is_removed:
acl.append((Deny, Everyone, 'vote'))
acl.append((Deny, Everyone, "vote"))
acl.append((Deny, self.user_id, 'vote'))
acl.append((Allow, Authenticated, 'vote'))
acl.append((Deny, self.user_id, "vote"))
acl.append((Allow, Authenticated, "vote"))
# comment:
# - removed topics can only be commented on by admins
# - locked topics can only be commented on by admins
# - otherwise, logged-in users can comment
if self.is_removed:
acl.append((Allow, 'admin', 'comment'))
acl.append((Deny, Everyone, 'comment'))
acl.append((Allow, "admin", "comment"))
acl.append((Deny, Everyone, "comment"))
if self.is_locked:
acl.append((Allow, 'admin', 'comment'))
acl.append((Deny, Everyone, 'comment'))
acl.append((Allow, "admin", "comment"))
acl.append((Deny, Everyone, "comment"))
acl.append((Allow, Authenticated, 'comment'))
acl.append((Allow, Authenticated, "comment"))
# edit:
# - only text topics can be edited, only by the author
if self.is_text_type:
acl.append((Allow, self.user_id, 'edit'))
acl.append((Allow, self.user_id, "edit"))
# delete:
# - only the author can delete
acl.append((Allow, self.user_id, 'delete'))
acl.append((Allow, self.user_id, "delete"))
# tag:
# - only the author and admins can tag topics
acl.append((Allow, self.user_id, 'tag'))
acl.append((Allow, 'admin', 'tag'))
acl.append((Allow, self.user_id, "tag"))
acl.append((Allow, "admin", "tag"))
# admin tools
acl.append((Allow, 'admin', 'lock'))
acl.append((Allow, 'admin', 'move'))
acl.append((Allow, 'admin', 'edit_title'))
acl.append((Allow, "admin", "lock"))
acl.append((Allow, "admin", "move"))
acl.append((Allow, "admin", "edit_title"))
acl.append(DENY_ALL)
@@ -316,7 +289,7 @@ class Topic(DatabaseModel):
@property
def permalink(self) -> str:
"""Return the permalink for this topic."""
return f'/~{self.group.path}/{self.topic_id36}/{self.url_slug}'
return f"/~{self.group.path}/{self.topic_id36}/{self.url_slug}"
@property
def is_text_type(self) -> bool:
@@ -332,27 +305,26 @@ class Topic(DatabaseModel):
def type_for_display(self) -> str:
"""Return a string of the topic's type, suitable for display."""
if self.is_text_type:
return 'Text'
return "Text"
elif self.is_link_type:
return 'Link'
return "Link"
return 'Topic'
return "Topic"
@property
def link_domain(self) -> str:
"""Return the link's domain (for link topics only)."""
if not self.is_link_type or not self.link:
raise ValueError('Non-link topics do not have a domain')
raise ValueError("Non-link topics do not have a domain")
# get the domain from the content metadata if possible, but fall back
# to just parsing it from the link if it's not present
return (self.get_content_metadata('domain')
or get_domain_from_url(self.link))
return self.get_content_metadata("domain") or get_domain_from_url(self.link)
@property
def is_spoiler(self) -> bool:
"""Return whether the topic is marked as a spoiler."""
return 'spoiler' in self.tags
return "spoiler" in self.tags
def get_content_metadata(self, key: str) -> Any:
"""Get a piece of content metadata "safely".
@@ -371,13 +343,13 @@ class Topic(DatabaseModel):
metadata_strings = []
if self.is_text_type:
word_count = self.get_content_metadata('word_count')
word_count = self.get_content_metadata("word_count")
if word_count is not None:
if word_count == 1:
metadata_strings.append('1 word')
metadata_strings.append("1 word")
else:
metadata_strings.append(f'{word_count} words')
metadata_strings.append(f"{word_count} words")
elif self.is_link_type:
metadata_strings.append(f'{self.link_domain}')
metadata_strings.append(f"{self.link_domain}")
return ', '.join(metadata_strings)
return ", ".join(metadata_strings)

45
tildes/tildes/models/topic/topic_query.py

@@ -28,7 +28,7 @@ class TopicQuery(PaginatedQuery):
"""
super().__init__(Topic, request)
def _attach_extra_data(self) -> 'TopicQuery':
def _attach_extra_data(self) -> "TopicQuery":
"""Attach the extra user data to the query."""
if not self.request.user:
return self
@@ -36,7 +36,7 @@ class TopicQuery(PaginatedQuery):
# pylint: disable=protected-access
return self._attach_vote_data()._attach_visit_data()
def _attach_vote_data(self) -> 'TopicQuery':
def _attach_vote_data(self) -> "TopicQuery":
"""Add a subquery to include whether the user has voted."""
vote_subquery = (
self.request.query(TopicVote)
@@ -45,24 +45,25 @@ class TopicQuery(PaginatedQuery):
TopicVote.user == self.request.user,
)
.exists()
.label('user_voted')
.label("user_voted")
)
return self.add_columns(vote_subquery)
def _attach_visit_data(self) -> 'TopicQuery':
def _attach_visit_data(self) -> "TopicQuery":
"""Join the data related to the user's last visit to the topic(s)."""
if self.request.user.track_comment_visits:
query = self.outerjoin(TopicVisit, and_(
TopicVisit.topic_id == Topic.topic_id,
TopicVisit.user == self.request.user,
))
query = query.add_columns(
TopicVisit.visit_time, TopicVisit.num_comments)
query = self.outerjoin(
TopicVisit,
and_(
TopicVisit.topic_id == Topic.topic_id,
TopicVisit.user == self.request.user,
),
)
query = query.add_columns(TopicVisit.visit_time, TopicVisit.num_comments)
else:
# if the user has the feature disabled, just add literal NULLs
query = self.add_columns(
null().label('visit_time'),
null().label('num_comments'),
null().label("visit_time"), null().label("num_comments")
)
return query
@@ -90,10 +91,8 @@ class TopicQuery(PaginatedQuery):
return topic
def apply_sort_option(
self,
sort: TopicSortOption,
desc: bool = True,
) -> 'TopicQuery':
self, sort: TopicSortOption, desc: bool = True
) -> "TopicQuery":
"""Apply a TopicSortOption sorting method (generative)."""
if sort == TopicSortOption.VOTES:
self._sort_column = Topic.num_votes
@@ -108,18 +107,16 @@ class TopicQuery(PaginatedQuery):
return self
def inside_groups(self, groups: Sequence[Group]) -> 'TopicQuery':
def inside_groups(self, groups: Sequence[Group]) -> "TopicQuery":
"""Restrict the topics to inside specific groups (generative)."""
query_paths = [group.path for group in groups]
subgroup_subquery = (
self.request.db_session.query(Group.group_id)
.filter(Group.path.descendant_of(query_paths))
subgroup_subquery = self.request.db_session.query(Group.group_id).filter(
Group.path.descendant_of(query_paths)
)
return self.filter(
Topic.group_id.in_(subgroup_subquery)) # type: ignore
return self.filter(Topic.group_id.in_(subgroup_subquery)) # type: ignore
def inside_time_period(self, period: SimpleHoursPeriod) -> 'TopicQuery':
def inside_time_period(self, period: SimpleHoursPeriod) -> "TopicQuery":
"""Restrict the topics to inside a time period (generative)."""
# if the time period is too long, this will crash by creating a
# datetime outside the valid range - catch that and just don't filter
@@ -131,7 +128,7 @@ class TopicQuery(PaginatedQuery):
return self.filter(Topic.created_time > start_time)
def has_tag(self, tag: Ltree) -> 'TopicQuery':
def has_tag(self, tag: Ltree) -> "TopicQuery":
"""Restrict the topics to ones with a specific tag (generative)."""
# casting tag to string really shouldn't be necessary, but some kind of
# strange interaction seems to be happening with the ArrayOfLtree

26
tildes/tildes/models/topic/topic_visit.py

@@ -28,28 +28,19 @@ class TopicVisit(DatabaseModel):
visits to the topic that were after it was posted.
"""
__tablename__ = 'topic_visits'
__tablename__ = "topic_visits"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
topic_id: int = Column(
Integer,
ForeignKey('topics.topic_id'),
nullable=False,
primary_key=True,
)
visit_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
Integer, ForeignKey("topics.topic_id"), nullable=False, primary_key=True
)
visit_time: datetime = Column(TIMESTAMP(timezone=True), nullable=False)
num_comments: int = Column(Integer, nullable=False)
user: User = relationship('User', innerjoin=True)
topic: Topic = relationship('Topic', innerjoin=True)
user: User = relationship("User", innerjoin=True)
topic: Topic = relationship("Topic", innerjoin=True)
@classmethod
def generate_insert_statement(cls, user: User, topic: Topic) -> Insert:
@@ -65,9 +56,6 @@ class TopicVisit(DatabaseModel):
)
.on_conflict_do_update(
constraint=cls.__table__.primary_key,
set_={
'visit_time': visit_time,
'num_comments': topic.num_comments,
},
set_={"visit_time": visit_time, "num_comments": topic.num_comments},
)
)
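generate_insert_statement builds a PostgreSQL upsert: insert the visit, and on a primary-key conflict update visit_time and num_comments instead. A stand-alone sketch that compiles the same kind of statement without a database, not part of the commit; the table definition and values are simplified stand-ins, and it names the conflicting columns directly where the real method targets the named primary-key constraint:

from sqlalchemy import Column, Integer, MetaData, Table, TIMESTAMP
from sqlalchemy.dialects import postgresql

metadata = MetaData()
topic_visits = Table(
    "topic_visits",
    metadata,
    Column("user_id", Integer, primary_key=True),
    Column("topic_id", Integer, primary_key=True),
    Column("visit_time", TIMESTAMP(timezone=True), nullable=False),
    Column("num_comments", Integer, nullable=False),
)

statement = (
    postgresql.insert(topic_visits)
    .values(user_id=1, topic_id=2, visit_time="2018-07-17", num_comments=5)
    .on_conflict_do_update(
        index_elements=["user_id", "topic_id"],
        set_={"visit_time": "2018-07-17", "num_comments": 5},
    )
)
# emits INSERT ... ON CONFLICT (user_id, topic_id) DO UPDATE SET ...
print(statement.compile(dialect=postgresql.dialect()))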

20
tildes/tildes/models/topic/topic_vote.py

@@ -21,33 +21,27 @@ class TopicVote(DatabaseModel):
column for the relevant topic.
"""
__tablename__ = 'topic_votes'
__tablename__ = "topic_votes"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
topic_id: int = Column(
Integer,
ForeignKey('topics.topic_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("topics.topic_id"), nullable=False, primary_key=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
user: User = relationship('User', innerjoin=True)
topic: Topic = relationship('Topic', innerjoin=True)
user: User = relationship("User", innerjoin=True)
topic: Topic = relationship("Topic", innerjoin=True)
def __init__(self, user: User, topic: Topic) -> None:
"""Create a new vote on a topic."""
self.user = user
self.topic = topic
incr_counter('votes', target_type='topic')
incr_counter("votes", target_type="topic")

70
tildes/tildes/models/user/user.py

@@ -49,7 +49,7 @@ class User(DatabaseModel):
schema_class = UserSchema
__tablename__ = 'users'
__tablename__ = "users"
user_id: int = Column(Integer, primary_key=True)
username: str = Column(CIText, nullable=False, unique=True)
@@ -59,9 +59,8 @@ class User(DatabaseModel):
Column(
Text,
CheckConstraint(
'LENGTH(email_address_note) <= '
f'{EMAIL_ADDRESS_NOTE_MAX_LENGTH}',
name='email_address_note_length',
f"LENGTH(email_address_note) <= {EMAIL_ADDRESS_NOTE_MAX_LENGTH}",
name="email_address_note_length",
),
)
)
@@ -69,44 +68,35 @@ class User(DatabaseModel):
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
num_unread_messages: int = Column(
Integer, nullable=False, server_default='0')
num_unread_notifications: int = Column(
Integer, nullable=False, server_default='0')
inviter_id: int = Column(Integer, ForeignKey('users.user_id'))
invite_codes_remaining: int = Column(
Integer, nullable=False, server_default='0')
track_comment_visits: bool = Column(
Boolean, nullable=False, server_default='false')
num_unread_messages: int = Column(Integer, nullable=False, server_default="0")
num_unread_notifications: int = Column(Integer, nullable=False, server_default="0")
inviter_id: int = Column(Integer, ForeignKey("users.user_id"))
invite_codes_remaining: int = Column(Integer, nullable=False, server_default="0")
track_comment_visits: bool = Column(Boolean, nullable=False, server_default="false")
auto_mark_notifications_read: bool = Column(
Boolean, nullable=False, server_default='false')
Boolean, nullable=False, server_default="false"
)
open_new_tab_external: bool = Column(
Boolean, nullable=False, server_default='false')
Boolean, nullable=False, server_default="false"
)
open_new_tab_internal: bool = Column(
Boolean, nullable=False, server_default='false')
open_new_tab_text: bool = Column(
Boolean, nullable=False, server_default='false')
is_banned: bool = Column(Boolean, nullable=False, server_default='false')
is_admin: bool = Column(Boolean, nullable=False, server_default='false')
home_default_order: Optional[TopicSortOption] = Column(
ENUM(TopicSortOption))
Boolean, nullable=False, server_default="false"
)
open_new_tab_text: bool = Column(Boolean, nullable=False, server_default="false")
is_banned: bool = Column(Boolean, nullable=False, server_default="false")
is_admin: bool = Column(Boolean, nullable=False, server_default="false")
home_default_order: Optional[TopicSortOption] = Column(ENUM(TopicSortOption))
home_default_period: Optional[str] = Column(Text)
_filtered_topic_tags: List[Ltree] = Column(
'filtered_topic_tags',
ArrayOfLtree,
nullable=False,
server_default='{}',
"filtered_topic_tags", ArrayOfLtree, nullable=False, server_default="{}"
)
@hybrid_property
def filtered_topic_tags(self) -> List[str]:
"""Return the user's list of filtered topic tags."""
return [
str(tag).replace('_', ' ')
for tag in self._filtered_topic_tags
]
return [str(tag).replace("_", " ") for tag in self._filtered_topic_tags]
@filtered_topic_tags.setter # type: ignore
def filtered_topic_tags(self, new_tags: List[str]) -> None:
@@ -114,7 +104,7 @@ class User(DatabaseModel):
def __repr__(self) -> str:
"""Display the user's username and ID as its repr format."""
return f'<User {self.username} ({self.user_id})>'
return f"<User {self.username} ({self.user_id})>"
def __str__(self) -> str:
"""Use the username for the string representation."""
@@ -131,12 +121,12 @@ class User(DatabaseModel):
# view:
# - everyone can view all users
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# message:
# - anyone can message a user except themself
acl.append((Deny, self.user_id, 'message'))
acl.append((Allow, Authenticated, 'message'))
acl.append((Deny, self.user_id, "message"))
acl.append((Allow, Authenticated, "message"))
# grant the user all other permissions on themself
acl.append((Allow, self.user_id, ALL_PERMISSIONS))
@@ -148,13 +138,13 @@ class User(DatabaseModel):
@property
def password(self) -> NoReturn:
"""Return an error since reading the password isn't possible."""
raise AttributeError('Password is write-only')
raise AttributeError("Password is write-only")
@password.setter
def password(self, value: str) -> None:
# need to do manual validation since some password checks depend on
# checking the username at the same time (for similarity)
self.schema.validate({'username': self.username, 'password': value})
self.schema.validate({"username": self.username, "password": value})
self.password_hash = hash_string(value)
@ -165,10 +155,10 @@ class User(DatabaseModel):
def change_password(self, old_password: str, new_password: str) -> None:
"""Change the user's password from the old one to a new one."""
if not self.is_correct_password(old_password):
raise ValueError('Old password was not correct')
raise ValueError("Old password was not correct")
if new_password == old_password:
raise ValueError('New password is the same as old password')
raise ValueError("New password is the same as old password")
# disable mypy on this line because it doesn't handle setters correctly
self.password = new_password # type: ignore
@ -176,7 +166,7 @@ class User(DatabaseModel):
@property
def email_address(self) -> NoReturn:
"""Return an error since reading the email address isn't possible."""
raise AttributeError('Email address is write-only')
raise AttributeError("Email address is write-only")
@email_address.setter
def email_address(self, value: Optional[str]) -> None:

16
tildes/tildes/models/user/user_group_settings.py

@ -15,22 +15,16 @@ from tildes.models.user import User
class UserGroupSettings(DatabaseModel):
"""Model for a user's settings related to a specific group."""
__tablename__ = 'user_group_settings'
__tablename__ = "user_group_settings"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
group_id: int = Column(
Integer,
ForeignKey('groups.group_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("groups.group_id"), nullable=False, primary_key=True
)
default_order: Optional[TopicSortOption] = Column(ENUM(TopicSortOption))
default_period: Optional[str] = Column(Text)
user: User = relationship('User', innerjoin=True)
group: Group = relationship('Group', innerjoin=True)
user: User = relationship("User", innerjoin=True)
group: Group = relationship("Group", innerjoin=True)

37
tildes/tildes/models/user/user_invite_code.py

@ -4,14 +4,7 @@ from datetime import datetime
import random
import string
from sqlalchemy import (
CheckConstraint,
Column,
ForeignKey,
Integer,
Text,
TIMESTAMP,
)
from sqlalchemy import CheckConstraint, Column, ForeignKey, Integer, Text, TIMESTAMP
from sqlalchemy.sql.expression import text
from tildes.models import DatabaseModel
@ -21,7 +14,7 @@ from .user import User
class UserInviteCode(DatabaseModel):
"""Model for invite codes that allow new users to register."""
__tablename__ = 'user_invite_codes'
__tablename__ = "user_invite_codes"
# the character set to generate codes using
ALPHABET = string.ascii_uppercase + string.digits
@ -30,33 +23,25 @@ class UserInviteCode(DatabaseModel):
code: str = Column(
Text,
CheckConstraint(
f'LENGTH(code) <= {LENGTH}',
name='code_length',
),
CheckConstraint(f"LENGTH(code) <= {LENGTH}", name="code_length"),
primary_key=True,
)
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=text('NOW()'),
TIMESTAMP(timezone=True), nullable=False, server_default=text("NOW()")
)
invitee_id: int = Column(Integer, ForeignKey('users.user_id'))
invitee_id: int = Column(Integer, ForeignKey("users.user_id"))
def __str__(self) -> str:
"""Format the code into a more easily readable version."""
formatted = ''
formatted = ""
for count, char in enumerate(self.code):
# add a dash every 5 chars
if count > 0 and count % 5 == 0:
formatted += '-'
formatted += "-"
formatted += char.upper()
@ -71,7 +56,7 @@ class UserInviteCode(DatabaseModel):
self.user_id = user.user_id
code_chars = random.choices(self.ALPHABET, k=self.LENGTH)
self.code = ''.join(code_chars)
self.code = "".join(code_chars)
@classmethod
def prepare_code_for_lookup(cls, code: str) -> str:
@ -81,9 +66,9 @@ class UserInviteCode(DatabaseModel):
# remove any characters that aren't in the code alphabet (allows
# dashes, spaces, etc. to be used to make the codes more readable)
code = ''.join(letter for letter in code if letter in cls.ALPHABET)
code = "".join(letter for letter in code if letter in cls.ALPHABET)
if len(code) > cls.LENGTH:
raise ValueError('Code is longer than the maximum length')
raise ValueError("Code is longer than the maximum length")
return code
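Together, __str__ and prepare_code_for_lookup above make invite codes forgiving to transcribe: the display form gets a dash every five characters, and lookup strips anything outside the uppercase-alphanumeric alphabet. A self-contained sketch of that round trip; LENGTH is assumed to be 15 here, and the upper-casing of the lookup input happens in a part of the method not shown in this hunk:

import string

ALPHABET = string.ascii_uppercase + string.digits
LENGTH = 15  # assumed value for this sketch; the real constant is defined elsewhere

def format_code(code: str) -> str:
    """Insert a dash every 5 characters, as __str__ does."""
    formatted = ""
    for count, char in enumerate(code):
        if count > 0 and count % 5 == 0:
            formatted += "-"
        formatted += char.upper()
    return formatted

def prepare_code_for_lookup(code: str) -> str:
    """Strip dashes/spaces/lowercase noise, as the classmethod does (simplified)."""
    code = code.upper()  # assumed; this step falls outside the hunk shown above
    code = "".join(letter for letter in code if letter in ALPHABET)
    if len(code) > LENGTH:
        raise ValueError("Code is longer than the maximum length")
    return code

assert format_code("ABC12DEF34GHJ56") == "ABC12-DEF34-GHJ56"
assert prepare_code_for_lookup("abc12-def34 ghj56") == "ABC12DEF34GHJ56"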

6
tildes/tildes/resources/__init__.py

@ -14,11 +14,7 @@ def get_resource(request: Request, base_query: ModelQuery) -> DatabaseModel:
if not request.user:
raise HTTPForbidden
query = (
base_query
.lock_based_on_request_method()
.join_all_relationships()
)
query = base_query.lock_based_on_request_method().join_all_relationships()
if not request.is_safe_method:
query = query.undefer_all_columns()

16
tildes/tildes/resources/comment.py

@ -10,10 +10,7 @@ from tildes.resources import get_resource
from tildes.schemas.comment import CommentSchema
@use_kwargs(
CommentSchema(only=('comment_id36',)),
locations=('matchdict',),
)
@use_kwargs(CommentSchema(only=("comment_id36",)), locations=("matchdict",))
def comment_by_id36(request: Request, comment_id36: str) -> Comment:
"""Get a comment specified by {comment_id36} in the route (or 404)."""
query = (
@ -25,13 +22,9 @@ def comment_by_id36(request: Request, comment_id36: str) -> Comment:
return get_resource(request, query)
@use_kwargs(
CommentSchema(only=('comment_id36',)),
locations=('matchdict',),
)
@use_kwargs(CommentSchema(only=("comment_id36",)), locations=("matchdict",))
def notification_by_comment_id36(
request: Request,
comment_id36: str,
request: Request, comment_id36: str
) -> CommentNotification:
"""Get a comment notification specified by {comment_id36} in the route.
@ -43,8 +36,7 @@ def notification_by_comment_id36(
comment_id = id36_to_id(comment_id36)
query = request.query(CommentNotification).filter_by(
user=request.user,
comment_id=comment_id,
user=request.user, comment_id=comment_id
)
return get_resource(request, query)
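Both views above pull {comment_id36} out of the route's matchdict and convert it with id36_to_id before querying. The conversion helpers live in tildes.lib.id and are not part of this diff; roughly speaking, an ID36 is the integer ID rendered in base 36, so a minimal stand-in version looks like this:

def id36_to_id(id36: str) -> int:
    """Convert a base-36 string to an integer ID (simplified stand-in)."""
    return int(id36, 36)

def id_to_id36(id_val: int) -> str:
    """Convert an integer ID to its base-36 string (simplified stand-in)."""
    digits = "0123456789abcdefghijklmnopqrstuvwxyz"
    result = ""
    while id_val:
        id_val, remainder = divmod(id_val, 36)
        result = digits[remainder] + result
    return result or "0"

assert id36_to_id("3f2") == 4430
assert id_to_id36(4430) == "3f2"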

11
tildes/tildes/resources/group.py

@ -11,8 +11,8 @@ from tildes.schemas.group import GroupSchema
@use_kwargs(
GroupSchema(only=('path',), context={'fix_path_capitalization': True}),
locations=('matchdict',),
GroupSchema(only=("path",), context={"fix_path_capitalization": True}),
locations=("matchdict",),
)
def group_by_path(request: Request, path: str) -> Group:
"""Get a group specified by {group_path} in the route (or 404)."""
@ -20,10 +20,9 @@ def group_by_path(request: Request, path: str) -> Group:
# 301 redirect to the resulting group path. This will happen in cases like
# the original url including capital letters in the group path, where we
# want to redirect to the proper all-lowercase path instead.
if path != request.matchdict['group_path']:
request.matchdict['group_path'] = path
proper_url = request.route_url(
request.matched_route.name, **request.matchdict)
if path != request.matchdict["group_path"]:
request.matchdict["group_path"] = path
proper_url = request.route_url(request.matched_route.name, **request.matchdict)
raise HTTPMovedPermanently(location=proper_url)

11
tildes/tildes/resources/message.py

@ -10,17 +10,14 @@ from tildes.schemas.message import MessageConversationSchema
@use_kwargs(
MessageConversationSchema(only=('conversation_id36',)),
locations=('matchdict',),
MessageConversationSchema(only=("conversation_id36",)), locations=("matchdict",)
)
def message_conversation_by_id36(
request: Request,
conversation_id36: str,
request: Request, conversation_id36: str
) -> MessageConversation:
"""Get a conversation specified by {conversation_id36} in the route."""
query = (
request.query(MessageConversation)
.filter_by(conversation_id=id36_to_id(conversation_id36))
query = request.query(MessageConversation).filter_by(
conversation_id=id36_to_id(conversation_id36)
)
return get_resource(request, query)

9
tildes/tildes/resources/topic.py

@ -10,10 +10,7 @@ from tildes.resources import get_resource
from tildes.schemas.topic import TopicSchema
@use_kwargs(
TopicSchema(only=('topic_id36',)),
locations=('matchdict',),
)
@use_kwargs(TopicSchema(only=("topic_id36",)), locations=("matchdict",))
def topic_by_id36(request: Request, topic_id36: str) -> Topic:
"""Get a topic specified by {topic_id36} in the route (or 404)."""
query = (
@ -27,8 +24,8 @@ def topic_by_id36(request: Request, topic_id36: str) -> Topic:
# if there's also a group specified in the route, check that it's the same
# group as the topic was posted in, otherwise redirect to correct group
if 'group_path' in request.matchdict:
path_from_route = request.matchdict['group_path'].lower()
if "group_path" in request.matchdict:
path_from_route = request.matchdict["group_path"].lower()
if path_from_route != topic.group.path:
raise HTTPFound(topic.permalink)

5
tildes/tildes/resources/user.py

@ -8,10 +8,7 @@ from tildes.resources import get_resource
from tildes.schemas.user import UserSchema
@use_kwargs(
UserSchema(only=('username',)),
locations=('matchdict',),
)
@use_kwargs(UserSchema(only=("username",)), locations=("matchdict",))
def user_by_username(request: Request, username: str) -> User:
"""Get a user specified by {username} in the route or 404 if not found."""
query = request.query(User).filter(User.username == username)

160
tildes/tildes/routes.py

@ -6,10 +6,7 @@ from pyramid.config import Configurator
from pyramid.request import Request
from pyramid.security import Allow, Authenticated
from tildes.resources.comment import (
comment_by_id36,
notification_by_comment_id36,
)
from tildes.resources.comment import comment_by_id36, notification_by_comment_id36
from tildes.resources.group import group_by_path
from tildes.resources.message import message_conversation_by_id36
from tildes.resources.topic import topic_by_id36
@ -18,172 +15,133 @@ from tildes.resources.user import user_by_username
def includeme(config: Configurator) -> None:
"""Set up application routes."""
config.add_route('home', '/')
config.add_route("home", "/")
config.add_route('groups', '/groups')
config.add_route("groups", "/groups")
config.add_route('login', '/login')
config.add_route('logout', '/logout', factory=LoggedInFactory)
config.add_route("login", "/login")
config.add_route("logout", "/logout", factory=LoggedInFactory)
config.add_route('register', '/register')
config.add_route("register", "/register")
config.add_route('group', '/~{group_path}', factory=group_by_path)
config.add_route(
'new_topic', '/~{group_path}/new_topic', factory=group_by_path)
config.add_route("group", "/~{group_path}", factory=group_by_path)
config.add_route("new_topic", "/~{group_path}/new_topic", factory=group_by_path)
config.add_route(
'group_topics', '/~{group_path}/topics', factory=group_by_path)
config.add_route("group_topics", "/~{group_path}/topics", factory=group_by_path)
config.add_route(
'topic', '/~{group_path}/{topic_id36}*title', factory=topic_by_id36)
"topic", "/~{group_path}/{topic_id36}*title", factory=topic_by_id36
)
config.add_route('user', '/user/{username}', factory=user_by_username)
config.add_route("user", "/user/{username}", factory=user_by_username)
config.add_route("notifications", "/notifications", factory=LoggedInFactory)
config.add_route(
'notifications', '/notifications', factory=LoggedInFactory)
config.add_route(
'notifications_unread',
'/notifications/unread',
factory=LoggedInFactory,
"notifications_unread", "/notifications/unread", factory=LoggedInFactory
)
config.add_route('messages', '/messages', factory=LoggedInFactory)
config.add_route(
'messages_sent', '/messages/sent', factory=LoggedInFactory)
config.add_route("messages", "/messages", factory=LoggedInFactory)
config.add_route("messages_sent", "/messages/sent", factory=LoggedInFactory)
config.add_route("messages_unread", "/messages/unread", factory=LoggedInFactory)
config.add_route(
'messages_unread', '/messages/unread', factory=LoggedInFactory)
config.add_route(
'message_conversation',
'/messages/conversations/{conversation_id36}',
"message_conversation",
"/messages/conversations/{conversation_id36}",
factory=message_conversation_by_id36,
)
config.add_route(
'new_message',
'/user/{username}/new_message',
factory=user_by_username,
"new_message", "/user/{username}/new_message", factory=user_by_username
)
config.add_route(
'user_messages',
'/user/{username}/messages',
factory=user_by_username,
"user_messages", "/user/{username}/messages", factory=user_by_username
)
config.add_route('settings', '/settings', factory=LoggedInFactory)
config.add_route("settings", "/settings", factory=LoggedInFactory)
config.add_route(
'settings_account_recovery',
'/settings/account_recovery',
"settings_account_recovery",
"/settings/account_recovery",
factory=LoggedInFactory,
)
config.add_route(
'settings_comment_visits',
'/settings/comment_visits',
factory=LoggedInFactory,
"settings_comment_visits", "/settings/comment_visits", factory=LoggedInFactory
)
config.add_route("settings_filters", "/settings/filters", factory=LoggedInFactory)
config.add_route(
'settings_filters', '/settings/filters', factory=LoggedInFactory)
config.add_route(
'settings_password_change',
'/settings/password_change',
factory=LoggedInFactory,
"settings_password_change", "/settings/password_change", factory=LoggedInFactory
)
config.add_route('invite', '/invite', factory=LoggedInFactory)
config.add_route("invite", "/invite", factory=LoggedInFactory)
# Route to expose metrics to Prometheus
config.add_route('metrics', '/metrics')
config.add_route("metrics", "/metrics")
# Route for Stripe donation processing page (POSTed to from docs site)
config.add_route('donate_stripe', '/donate_stripe')
config.add_route("donate_stripe", "/donate_stripe")
add_intercooler_routes(config)
def add_intercooler_routes(config: Configurator) -> None:
"""Set up all routes for the (internal-use) Intercooler API endpoints."""
def add_ic_route(name: str, path: str, **kwargs: Any) -> None:
"""Add route with intercooler name prefix, base path, header check."""
name = 'ic_' + name
path = '/api/web' + path
config.add_route(
name,
path,
header='X-IC-Request:true',
**kwargs)
name = "ic_" + name
path = "/api/web" + path
config.add_route(name, path, header="X-IC-Request:true", **kwargs)
add_ic_route(
'group_subscribe',
'/group/{group_path}/subscribe',
factory=group_by_path,
"group_subscribe", "/group/{group_path}/subscribe", factory=group_by_path
)
add_ic_route(
'group_user_settings',
'/group/{group_path}/user_settings',
"group_user_settings",
"/group/{group_path}/user_settings",
factory=group_by_path,
)
add_ic_route('topic', '/topics/{topic_id36}', factory=topic_by_id36)
add_ic_route("topic", "/topics/{topic_id36}", factory=topic_by_id36)
add_ic_route(
'topic_comments',
'/topics/{topic_id36}/comments',
factory=topic_by_id36,
)
add_ic_route(
'topic_group', '/topics/{topic_id36}/group', factory=topic_by_id36)
add_ic_route(
'topic_lock', '/topics/{topic_id36}/lock', factory=topic_by_id36)
add_ic_route(
'topic_title', '/topics/{topic_id36}/title', factory=topic_by_id36)
add_ic_route(
'topic_vote', '/topics/{topic_id36}/vote', factory=topic_by_id36)
add_ic_route(
'topic_tags',
'/topics/{topic_id36}/tags',
factory=topic_by_id36,
"topic_comments", "/topics/{topic_id36}/comments", factory=topic_by_id36
)
add_ic_route("topic_group", "/topics/{topic_id36}/group", factory=topic_by_id36)
add_ic_route("topic_lock", "/topics/{topic_id36}/lock", factory=topic_by_id36)
add_ic_route("topic_title", "/topics/{topic_id36}/title", factory=topic_by_id36)
add_ic_route("topic_vote", "/topics/{topic_id36}/vote", factory=topic_by_id36)
add_ic_route("topic_tags", "/topics/{topic_id36}/tags", factory=topic_by_id36)
add_ic_route("comment", "/comments/{comment_id36}", factory=comment_by_id36)
add_ic_route(
'comment', '/comments/{comment_id36}', factory=comment_by_id36)
add_ic_route(
'comment_replies',
'/comments/{comment_id36}/replies',
factory=comment_by_id36,
"comment_replies", "/comments/{comment_id36}/replies", factory=comment_by_id36
)
add_ic_route(
'comment_vote',
'/comments/{comment_id36}/vote',
factory=comment_by_id36,
"comment_vote", "/comments/{comment_id36}/vote", factory=comment_by_id36
)
add_ic_route(
'comment_tag',
'/comments/{comment_id36}/tags/{name}',
factory=comment_by_id36,
"comment_tag", "/comments/{comment_id36}/tags/{name}", factory=comment_by_id36
)
add_ic_route(
'comment_mark_read',
'/comments/{comment_id36}/mark_read',
"comment_mark_read",
"/comments/{comment_id36}/mark_read",
factory=notification_by_comment_id36,
)
add_ic_route(
'message_conversation_replies',
'/messages/conversations/{conversation_id36}/replies',
"message_conversation_replies",
"/messages/conversations/{conversation_id36}/replies",
factory=message_conversation_by_id36,
)
add_ic_route('user', '/user/{username}', factory=user_by_username)
add_ic_route("user", "/user/{username}", factory=user_by_username)
add_ic_route(
'user_filtered_topic_tags',
'/user/{username}/filtered_topic_tags',
"user_filtered_topic_tags",
"/user/{username}/filtered_topic_tags",
factory=user_by_username,
)
add_ic_route(
'user_invite_code',
'/user/{username}/invite_code',
factory=user_by_username,
"user_invite_code", "/user/{username}/invite_code", factory=user_by_username
)
add_ic_route(
'user_default_listing_options',
'/user/{username}/default_listing_options',
"user_default_listing_options",
"/user/{username}/default_listing_options",
factory=user_by_username,
)
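All of the add_ic_route calls above go through the nested helper defined at the top of this function: it prefixes the route name with "ic_", prefixes the path with "/api/web", and adds a header predicate so the route only matches Intercooler requests. As a sketch of what a single call expands to (using a stub factory so the snippet stands alone):

from pyramid.config import Configurator

def topic_by_id36_stub(request):
    """Stand-in for the real topic_by_id36 resource factory."""
    return None

config = Configurator()
# add_ic_route("topic", "/topics/{topic_id36}", factory=topic_by_id36)
# expands, per the helper above, to:
config.add_route(
    "ic_topic",                      # "ic_" prefix on the route name
    "/api/web/topics/{topic_id36}",  # "/api/web" prefix on the path
    header="X-IC-Request:true",      # only match Intercooler AJAX requests
    factory=topic_by_id36_stub,
)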
@ -196,7 +154,7 @@ class LoggedInFactory:
checking access to a specific resource (such as a topic or message).
"""
__acl__ = ((Allow, Authenticated, 'view'),)
__acl__ = ((Allow, Authenticated, "view"),)
def __init__(self, request: Request) -> None:
"""Initialize - no-op, but needs to take the request as an arg."""

45
tildes/tildes/schemas/fields.py

@ -16,12 +16,7 @@ from tildes.lib.string import simplify_string
class Enum(Field):
"""Field for a native Python Enum (or subclasses)."""
def __init__(
self,
enum_class: Type = None,
*args: Any,
**kwargs: Any,
) -> None:
def __init__(self, enum_class: Type = None, *args: Any, **kwargs: Any) -> None:
"""Initialize the field with an optional enum class."""
super().__init__(*args, **kwargs)
self._enum_class = enum_class
@ -33,12 +28,12 @@ class Enum(Field):
def _deserialize(self, value: str, attr: str, data: dict) -> enum.Enum:
"""Deserialize a string to the enum member with that name."""
if not self._enum_class:
raise ValidationError('Cannot deserialize with no enum class.')
raise ValidationError("Cannot deserialize with no enum class.")
try:
return self._enum_class[value.upper()]
except KeyError:
raise ValidationError('Invalid enum member')
raise ValidationError("Invalid enum member")
class ID36(String):
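The Enum field's _deserialize above (shown just before the ID36 class) simply upper-cases the submitted name and indexes into the configured enum class, converting a failed lookup into a validation error. A standalone sketch of that lookup with a stand-in enum and a plain exception in place of marshmallow's ValidationError:

import enum

class TopicSortOption(enum.Enum):
    """Minimal stand-in; the real enum lives in tildes.enums."""
    VOTES = enum.auto()
    ACTIVITY = enum.auto()
    NEW = enum.auto()

def deserialize_enum(value: str, enum_class: type) -> enum.Enum:
    # look up the member by (case-insensitive) name, as the field does
    try:
        return enum_class[value.upper()]
    except KeyError:
        raise ValueError("Invalid enum member")

assert deserialize_enum("activity", TopicSortOption) is TopicSortOption.ACTIVITY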
@ -56,25 +51,19 @@ class ShortTimePeriod(Field):
"""
def _deserialize(
self,
value: str,
attr: str,
data: dict,
self, value: str, attr: str, data: dict
) -> Optional[SimpleHoursPeriod]:
"""Deserialize to a SimpleHoursPeriod object."""
if value == 'all':
if value == "all":
return None
try:
return SimpleHoursPeriod.from_short_form(value)
except ValueError:
raise ValidationError('Invalid time period')
raise ValidationError("Invalid time period")
def _serialize(
self,
value: Optional[SimpleHoursPeriod],
attr: str,
obj: object,
self, value: Optional[SimpleHoursPeriod], attr: str, obj: object
) -> Optional[str]:
"""Serialize the value to the "short form" string."""
if not value:
@ -100,11 +89,11 @@ class Markdown(Field):
super()._validate(value)
if value.isspace():
raise ValidationError('Cannot be entirely whitespace.')
raise ValidationError("Cannot be entirely whitespace.")
def _deserialize(self, value: str, attr: str, data: dict) -> str:
"""Deserialize the string, removing carriage returns in the process."""
value = value.replace('\r', '')
value = value.replace("\r", "")
return value
@ -145,23 +134,13 @@ class SimpleString(Field):
class Ltree(Field):
"""Field for postgresql ltree type."""
def _serialize(
self,
value: sqlalchemy_utils.Ltree,
attr: str,
obj: object,
) -> str:
def _serialize(self, value: sqlalchemy_utils.Ltree, attr: str, obj: object) -> str:
"""Serialize the Ltree value - use the (string) path."""
return value.path
def _deserialize(
self,
value: str,
attr: str,
data: dict,
) -> sqlalchemy_utils.Ltree:
def _deserialize(self, value: str, attr: str, data: dict) -> sqlalchemy_utils.Ltree:
"""Deserialize a string path to an Ltree object."""
try:
return sqlalchemy_utils.Ltree(value)
except (TypeError, ValueError):
raise ValidationError('Invalid path')
raise ValidationError("Invalid path")

27
tildes/tildes/schemas/group.py

@ -15,11 +15,13 @@ from tildes.schemas.fields import Ltree, SimpleString
# - must end with a number or lowercase letter
# - the middle can contain numbers, lowercase letters, and underscores
# Note: this regex does not contain any length checks, must be done separately
# fmt: off
GROUP_PATH_ELEMENT_VALID_REGEX = re.compile(
'^[a-z0-9]' # start
'([a-z0-9_]*' # middle
'[a-z0-9])?$', # end
"^[a-z0-9]" # start
"([a-z0-9_]*" # middle
"[a-z0-9])?$" # end
)
# fmt: on
SHORT_DESCRIPTION_MAX_LENGTH = 200
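The # fmt: off / # fmt: on pair above keeps Black from collapsing the per-part comments on this regex. The pattern validates a single path element: it must start and end with a lowercase letter or digit, with only lowercase letters, digits, and underscores in between (length is checked separately in validate_path). A quick check against a few sample elements:

import re

GROUP_PATH_ELEMENT_VALID_REGEX = re.compile(
    "^[a-z0-9]"    # start
    "([a-z0-9_]*"  # middle
    "[a-z0-9])?$"  # end
)

assert GROUP_PATH_ELEMENT_VALID_REGEX.match("games")
assert GROUP_PATH_ELEMENT_VALID_REGEX.match("board_games")
assert not GROUP_PATH_ELEMENT_VALID_REGEX.match("Games")    # no uppercase
assert not GROUP_PATH_ELEMENT_VALID_REGEX.match("_games")   # can't start with "_"
assert not GROUP_PATH_ELEMENT_VALID_REGEX.match("games_")   # can't end with "_"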
@ -27,21 +29,20 @@ SHORT_DESCRIPTION_MAX_LENGTH = 200
class GroupSchema(Schema):
"""Marshmallow schema for groups."""
path = Ltree(required=True, load_from='group_path')
path = Ltree(required=True, load_from="group_path")
created_time = DateTime(dump_only=True)
short_description = SimpleString(
max_length=SHORT_DESCRIPTION_MAX_LENGTH,
allow_none=True,
max_length=SHORT_DESCRIPTION_MAX_LENGTH, allow_none=True
)
@pre_load
def prepare_path(self, data: dict) -> dict:
"""Prepare the path value before it's validated."""
if not self.context.get('fix_path_capitalization'):
if not self.context.get("fix_path_capitalization"):
return data
# path can also be loaded from group_path, so we need to check both
keys = ('path', 'group_path')
keys = ("path", "group_path")
for key in keys:
if key in data and isinstance(data[key], str):
@ -49,17 +50,17 @@ class GroupSchema(Schema):
return data
@validates('path')
@validates("path")
def validate_path(self, value: sqlalchemy_utils.Ltree) -> None:
"""Validate the path field, raising an error if an issue exists."""
# check each element for length and against validity regex
path_elements = value.path.split('.')
path_elements = value.path.split(".")
for element in path_elements:
if len(element) > 256:
raise ValidationError('Path element %s is too long' % element)
raise ValidationError("Path element %s is too long" % element)
if not GROUP_PATH_ELEMENT_VALID_REGEX.match(element):
raise ValidationError('Path element %s is invalid' % element)
raise ValidationError("Path element %s is invalid" % element)
class Meta:
"""Always use strict checking so error handlers are invoked."""
@ -71,7 +72,7 @@ def is_valid_group_path(path: str) -> bool:
"""Return whether the group path is valid or not."""
schema = GroupSchema(partial=True)
try:
schema.validate({'path': path})
schema.validate({"path": path})
except ValidationError:
return False

58
tildes/tildes/schemas/topic.py

@ -4,13 +4,7 @@ import re
import typing
from urllib.parse import urlparse
from marshmallow import (
pre_load,
Schema,
validates,
validates_schema,
ValidationError,
)
from marshmallow import pre_load, Schema, validates, validates_schema, ValidationError
from marshmallow.fields import DateTime, List, Nested, String, URL
import sqlalchemy_utils
@ -29,7 +23,7 @@ class TopicSchema(Schema):
topic_type = Enum(dump_only=True)
markdown = Markdown(allow_none=True)
rendered_html = String(dump_only=True)
link = URL(schemes={'http', 'https'}, allow_none=True)
link = URL(schemes={"http", "https"}, allow_none=True)
created_time = DateTime(dump_only=True)
tags = List(Ltree())
@ -39,22 +33,22 @@ class TopicSchema(Schema):
@pre_load
def prepare_tags(self, data: dict) -> dict:
"""Prepare the tags before they're validated."""
if 'tags' not in data:
if "tags" not in data:
return data
tags: typing.List[str] = []
for tag in data['tags']:
for tag in data["tags"]:
tag = tag.lower()
# replace spaces with underscores
tag = tag.replace(' ', '_')
tag = tag.replace(" ", "_")
# remove any consecutive underscores
tag = re.sub('_{2,}', '_', tag)
tag = re.sub("_{2,}", "_", tag)
# remove any leading/trailing underscores
tag = tag.strip('_')
tag = tag.strip("_")
# drop any empty tags
if not tag or tag.isspace():
@ -66,15 +60,12 @@ class TopicSchema(Schema):
tags.append(tag)
data['tags'] = tags
data["tags"] = tags
return data
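prepare_tags above normalizes each submitted tag before validation: lowercase, spaces to underscores, runs of underscores collapsed, leading/trailing underscores stripped, and empty results dropped. A standalone sketch of the per-tag cleanup (only the steps visible in this hunk):

import re

def normalize_tag(tag: str) -> str:
    """Clean up a single topic tag the way prepare_tags does (simplified)."""
    tag = tag.lower()
    tag = tag.replace(" ", "_")      # spaces become underscores
    tag = re.sub("_{2,}", "_", tag)  # collapse consecutive underscores
    return tag.strip("_")            # strip leading/trailing underscores

assert normalize_tag("Science  Fiction") == "science_fiction"
assert normalize_tag("_anime_") == "anime"
assert normalize_tag("  ") == ""  # empty results are dropped by the caller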
@validates('tags')
def validate_tags(
self,
value: typing.List[sqlalchemy_utils.Ltree],
) -> None:
@validates("tags")
def validate_tags(self, value: typing.List[sqlalchemy_utils.Ltree]) -> None:
"""Validate the tags field, raising an error if an issue exists.
Note that tags are validated by ensuring that each tag would be a valid
@ -86,52 +77,51 @@ class TopicSchema(Schema):
group_schema = GroupSchema(partial=True)
for tag in value:
try:
group_schema.validate({'path': tag})
group_schema.validate({"path": tag})
except ValidationError:
raise ValidationError('Tag %s is invalid' % tag)
raise ValidationError("Tag %s is invalid" % tag)
@pre_load
def prepare_markdown(self, data: dict) -> dict:
"""Prepare the markdown value before it's validated."""
if 'markdown' not in data:
if "markdown" not in data:
return data
# if the value is empty, convert it to None
if not data['markdown'] or data['markdown'].isspace():
data['markdown'] = None
if not data["markdown"] or data["markdown"].isspace():
data["markdown"] = None
return data
@pre_load
def prepare_link(self, data: dict) -> dict:
"""Prepare the link value before it's validated."""
if 'link' not in data:
if "link" not in data:
return data
# if the value is empty, convert it to None
if not data['link'] or data['link'].isspace():
data['link'] = None
if not data["link"] or data["link"].isspace():
data["link"] = None
return data
# prepend http:// to the link if it doesn't have a scheme
parsed = urlparse(data['link'])
parsed = urlparse(data["link"])
if not parsed.scheme:
data['link'] = 'http://' + data['link']
data["link"] = "http://" + data["link"]
return data
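prepare_link above converts a blank link to None and prepends "http://" when the submitted URL has no scheme, so a bare domain can still pass the URL field's schemes={"http", "https"} check. The scheme handling on its own:

from typing import Optional
from urllib.parse import urlparse

def prepare_link(link: Optional[str]) -> Optional[str]:
    """Normalize a submitted link the way the schema hook does (simplified)."""
    if not link or link.isspace():
        return None
    if not urlparse(link).scheme:
        link = "http://" + link
    return link

assert prepare_link("example.com/page") == "http://example.com/page"
assert prepare_link("https://example.com") == "https://example.com"
assert prepare_link("  ") is None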
@validates_schema
def link_or_markdown(self, data: dict) -> None:
"""Fail validation unless at least one of link or markdown were set."""
if 'link' not in data and 'markdown' not in data:
if "link" not in data and "markdown" not in data:
return
link = data.get('link')
markdown = data.get('markdown')
link = data.get("link")
markdown = data.get("markdown")
if not (markdown or link):
raise ValidationError(
'Topics must have either markdown or a link.')
raise ValidationError("Topics must have either markdown or a link.")
class Meta:
"""Always use strict checking so error handlers are invoked."""

13
tildes/tildes/schemas/topic_listing.py

@ -17,25 +17,22 @@ class TopicListingSchema(Schema):
period = ShortTimePeriod(allow_none=True)
after = ID36()
before = ID36()
per_page = Integer(
validate=Range(min=1, max=100),
missing=DEFAULT_TOPICS_PER_PAGE,
)
rank_start = Integer(load_from='n', validate=Range(min=1), missing=None)
per_page = Integer(validate=Range(min=1, max=100), missing=DEFAULT_TOPICS_PER_PAGE)
rank_start = Integer(load_from="n", validate=Range(min=1), missing=None)
tag = Ltree(missing=None)
unfiltered = Boolean(missing=False)
@validates_schema
def either_after_or_before(self, data: dict) -> None:
"""Fail validation if both after and before were specified."""
if data.get('after') and data.get('before'):
if data.get("after") and data.get("before"):
raise ValidationError("Can't specify both after and before.")
@pre_load
def reset_rank_start_on_first_page(self, data: dict) -> dict:
"""Reset rank_start to 1 if this is a first page (no before/after)."""
if not (data.get('before') or data.get('after')):
data['rank_start'] = 1
if not (data.get("before") or data.get("after")):
data["rank_start"] = 1
return data

46
tildes/tildes/schemas/user.py

@ -2,13 +2,7 @@
import re
from marshmallow import (
post_dump,
pre_load,
Schema,
validates,
validates_schema,
)
from marshmallow import post_dump, pre_load, Schema, validates, validates_schema
from marshmallow.exceptions import ValidationError
from marshmallow.fields import Boolean, DateTime, Email, String
from marshmallow.validate import Length, Regexp
@ -26,11 +20,14 @@ USERNAME_MAX_LENGTH = 20
# more than one underscore/dash consecutively (this includes both "_-" and
# "-_" sequences being invalid)
# Note: this regex does not contain any length checks, must be done separately
# fmt: off
USERNAME_VALID_REGEX = re.compile(
"^[a-z0-9]" # start
"([a-z0-9]|[_-](?![_-]))*" # middle
"[a-z0-9]$", # end
re.IGNORECASE)
re.IGNORECASE,
)
# fmt: on
PASSWORD_MIN_LENGTH = 8
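Like the group-path regex earlier, the # fmt: off block here preserves the commented, multi-part layout of the username regex through Black. Per the comments above it, a username must start and end with a letter or digit and may contain underscores or dashes in between as long as they never appear consecutively; length limits are enforced separately. A few sample checks:

import re

USERNAME_VALID_REGEX = re.compile(
    "^[a-z0-9]"                 # start
    "([a-z0-9]|[_-](?![_-]))*"  # middle
    "[a-z0-9]$",                # end
    re.IGNORECASE,
)

assert USERNAME_VALID_REGEX.match("Deimos")
assert USERNAME_VALID_REGEX.match("some_user-42")
assert not USERNAME_VALID_REGEX.match("user__name")  # consecutive separators
assert not USERNAME_VALID_REGEX.match("_user")       # can't start with a separator
assert not USERNAME_VALID_REGEX.match("user-")       # can't end with a separator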
@ -48,29 +45,26 @@ class UserSchema(Schema):
required=True,
)
password = String(
validate=Length(min=PASSWORD_MIN_LENGTH),
required=True,
load_only=True,
validate=Length(min=PASSWORD_MIN_LENGTH), required=True, load_only=True
)
email_address = Email(allow_none=True, load_only=True)
email_address_note = String(
validate=Length(max=EMAIL_ADDRESS_NOTE_MAX_LENGTH))
email_address_note = String(validate=Length(max=EMAIL_ADDRESS_NOTE_MAX_LENGTH))
created_time = DateTime(dump_only=True)
track_comment_visits = Boolean()
@post_dump
def anonymize_username(self, data: dict) -> dict:
"""Hide the username if the dumping context specifies to do so."""
if 'username' in data and self.context.get('hide_username'):
data['username'] = '<unknown>'
if "username" in data and self.context.get("hide_username"):
data["username"] = "<unknown>"
return data
@validates_schema
def username_pass_not_substrings(self, data: dict) -> None:
"""Ensure the username isn't in the password and vice versa."""
username = data.get('username')
password = data.get('password')
username = data.get("username")
password = data.get("password")
if not (username and password):
return
@ -78,32 +72,32 @@ class UserSchema(Schema):
password = password.lower()
if username in password:
raise ValidationError('Password cannot contain username')
raise ValidationError("Password cannot contain username")
if password in username:
raise ValidationError('Username cannot contain password')
raise ValidationError("Username cannot contain password")
@validates('password')
@validates("password")
def password_not_breached(self, value: str) -> None:
"""Validate that the password is not in the breached-passwords list.
Requires check_breached_passwords be True in the schema's context.
"""
if not self.context.get('check_breached_passwords'):
if not self.context.get("check_breached_passwords"):
return
if is_breached_password(value):
raise ValidationError('That password exists in a data breach')
raise ValidationError("That password exists in a data breach")
@pre_load
def prepare_email_address(self, data: dict) -> dict:
"""Prepare the email address value before it's validated."""
if 'email_address' not in data:
if "email_address" not in data:
return data
# if the value is empty, convert it to None
if not data['email_address'] or data['email_address'].isspace():
data['email_address'] = None
if not data["email_address"] or data["email_address"].isspace():
data["email_address"] = None
return data
@ -122,7 +116,7 @@ def is_valid_username(username: str) -> bool:
"""
schema = UserSchema(partial=True)
try:
schema.validate({'username': username})
schema.validate({"username": username})
except ValidationError:
return False

2
tildes/tildes/views/__init__.py

@ -10,4 +10,4 @@ IC_NOOP_404 = Response(status_int=404)
# Because of the above, in order to deliberately cause Intercooler to replace
# an element with whitespace, the response needs to contain at least two spaces
IC_EMPTY = Response(' ')
IC_EMPTY = Response(" ")

2
tildes/tildes/views/api/v0/group.py

@ -6,7 +6,7 @@ from tildes.api import APIv0
from tildes.resources.group import group_by_path
ONE = APIv0(name='group', path='/groups/{group_path}', factory=group_by_path)
ONE = APIv0(name="group", path="/groups/{group_path}", factory=group_by_path)
@ONE.get()

4
tildes/tildes/views/api/v0/topic.py

@ -7,9 +7,7 @@ from tildes.resources.topic import topic_by_id36
ONE = APIv0(
name='topic',
path='/groups/{group_path}/topics/{topic_id36}',
factory=topic_by_id36,
name="topic", path="/groups/{group_path}/topics/{topic_id36}", factory=topic_by_id36
)

2
tildes/tildes/views/api/v0/user.py

@ -6,7 +6,7 @@ from tildes.api import APIv0
from tildes.resources.user import user_by_username
ONE = APIv0(name='user', path='/users/{username}', factory=user_by_username)
ONE = APIv0(name="user", path="/users/{username}", factory=user_by_username)
@ONE.get()

167
tildes/tildes/views/api/web/comment.py

@ -12,22 +12,14 @@ from zope.sqlalchemy import mark_changed
from tildes.enums import CommentNotificationType, CommentTagOption
from tildes.lib.datetime import utc_now
from tildes.models.comment import (
Comment,
CommentNotification,
CommentTag,
CommentVote,
)
from tildes.models.comment import Comment, CommentNotification, CommentTag, CommentVote
from tildes.models.topic import TopicVisit
from tildes.schemas.comment import CommentSchema, CommentTagSchema
from tildes.views import IC_NOOP
from tildes.views.decorators import ic_view_config
def _increment_topic_comments_seen(
request: Request,
comment: Comment,
) -> None:
def _increment_topic_comments_seen(request: Request, comment: Comment) -> None:
"""Increment the number of comments in a topic the user has viewed.
If the user has the "track comment visits" feature enabled, we want to
@ -50,7 +42,7 @@ def _increment_topic_comments_seen(
)
.on_conflict_do_update(
constraint=TopicVisit.__table__.primary_key,
set_={'num_comments': TopicVisit.num_comments + 1},
set_={"num_comments": TopicVisit.num_comments + 1},
where=TopicVisit.visit_time < comment.created_time,
)
)
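The statement above is a PostgreSQL upsert: it inserts a fresh TopicVisit row, and if the user already has a visit row for that topic (a conflict on the primary key) it bumps num_comments instead, with the where clause ensuring only visits older than the new comment get updated. A generic sketch of the same INSERT ... ON CONFLICT DO UPDATE pattern with SQLAlchemy's PostgreSQL dialect, using a stand-in table (the real table has more columns plus the visit_time guard):

from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.dialects.postgresql import insert

metadata = MetaData()

# stand-in table; the real model is tildes.models.topic.TopicVisit
topic_visits = Table(
    "topic_visits",
    metadata,
    Column("user_id", Integer, primary_key=True),
    Column("topic_id", Integer, primary_key=True),
    Column("num_comments", Integer, nullable=False, server_default="0"),
)

statement = (
    insert(topic_visits)
    .values(user_id=1, topic_id=42, num_comments=1)
    .on_conflict_do_update(
        # on a primary-key conflict, increment the existing counter instead
        constraint=topic_visits.primary_key,
        set_={"num_comments": topic_visits.c.num_comments + 1},
    )
)
# request.db_session.execute(statement) would run it against a real PostgreSQL session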
@ -60,28 +52,22 @@ def _increment_topic_comments_seen(
@ic_view_config(
route_name='topic_comments',
request_method='POST',
renderer='single_comment.jinja2',
permission='comment',
route_name="topic_comments",
request_method="POST",
renderer="single_comment.jinja2",
permission="comment",
)
@use_kwargs(CommentSchema(only=('markdown',)))
@use_kwargs(CommentSchema(only=("markdown",)))
def post_toplevel_comment(request: Request, markdown: str) -> dict:
"""Post a new top-level comment on a topic with Intercooler."""
topic = request.context
new_comment = Comment(
topic=topic,
author=request.user,
markdown=markdown,
)
new_comment = Comment(topic=topic, author=request.user, markdown=markdown)
request.db_session.add(new_comment)
if topic.user != request.user and not topic.is_deleted:
notification = CommentNotification(
topic.user,
new_comment,
CommentNotificationType.TOPIC_REPLY,
topic.user, new_comment, CommentNotificationType.TOPIC_REPLY
)
request.db_session.add(notification)
@ -95,16 +81,16 @@ def post_toplevel_comment(request: Request, markdown: str) -> dict:
.one()
)
return {'comment': new_comment, 'topic': topic}
return {"comment": new_comment, "topic": topic}
@ic_view_config(
route_name='comment_replies',
request_method='POST',
renderer='single_comment.jinja2',
permission='reply',
route_name="comment_replies",
request_method="POST",
renderer="single_comment.jinja2",
permission="reply",
)
@use_kwargs(CommentSchema(only=('markdown',)))
@use_kwargs(CommentSchema(only=("markdown",)))
def post_comment_reply(request: Request, markdown: str) -> dict:
"""Post a reply to a comment with Intercooler."""
parent_comment = request.context
@ -118,9 +104,7 @@ def post_comment_reply(request: Request, markdown: str) -> dict:
if parent_comment.user != request.user:
notification = CommentNotification(
parent_comment.user,
new_comment,
CommentNotificationType.COMMENT_REPLY,
parent_comment.user, new_comment, CommentNotificationType.COMMENT_REPLY
)
request.db_session.add(notification)
@ -134,67 +118,67 @@ def post_comment_reply(request: Request, markdown: str) -> dict:
.one()
)
return {'comment': new_comment}
return {"comment": new_comment}
@ic_view_config(
route_name='comment',
request_method='GET',
renderer='comment_contents.jinja2',
permission='view',
route_name="comment",
request_method="GET",
renderer="comment_contents.jinja2",
permission="view",
)
def get_comment_contents(request: Request) -> dict:
"""Get a comment's body with Intercooler."""
return {'comment': request.context}
return {"comment": request.context}
@ic_view_config(
route_name='comment',
request_method='GET',
request_param='ic-trigger-name=edit',
renderer='comment_edit.jinja2',
permission='edit',
route_name="comment",
request_method="GET",
request_param="ic-trigger-name=edit",
renderer="comment_edit.jinja2",
permission="edit",
)
def get_comment_edit(request: Request) -> dict:
"""Get the edit form for a comment with Intercooler."""
return {'comment': request.context}
return {"comment": request.context}
@ic_view_config(
route_name='comment',
request_method='PATCH',
renderer='comment_contents.jinja2',
permission='edit',
route_name="comment",
request_method="PATCH",
renderer="comment_contents.jinja2",
permission="edit",
)
@use_kwargs(CommentSchema(only=('markdown',)))
@use_kwargs(CommentSchema(only=("markdown",)))
def patch_comment(request: Request, markdown: str) -> dict:
"""Update a comment with Intercooler."""
comment = request.context
comment.markdown = markdown
return {'comment': comment}
return {"comment": comment}
@ic_view_config(
route_name='comment',
request_method='DELETE',
renderer='comment_contents.jinja2',
permission='delete',
route_name="comment",
request_method="DELETE",
renderer="comment_contents.jinja2",
permission="delete",
)
def delete_comment(request: Request) -> dict:
"""Delete a comment with Intercooler."""
comment = request.context
comment.is_deleted = True
return {'comment': comment}
return {"comment": comment}
@ic_view_config(
route_name='comment_vote',
request_method='PUT',
permission='vote',
renderer='comment_contents.jinja2',
route_name="comment_vote",
request_method="PUT",
permission="vote",
renderer="comment_contents.jinja2",
)
def put_vote_comment(request: Request) -> dict:
"""Vote on a comment with Intercooler."""
@ -222,22 +206,21 @@ def put_vote_comment(request: Request) -> dict:
.one()
)
return {'comment': comment}
return {"comment": comment}
@ic_view_config(
route_name='comment_vote',
request_method='DELETE',
permission='vote',
renderer='comment_contents.jinja2',
route_name="comment_vote",
request_method="DELETE",
permission="vote",
renderer="comment_contents.jinja2",
)
def delete_vote_comment(request: Request) -> dict:
"""Remove the user's vote from a comment with Intercooler."""
comment = request.context
request.query(CommentVote).filter(
CommentVote.comment == comment,
CommentVote.user == request.user,
CommentVote.comment == comment, CommentVote.user == request.user
).delete(synchronize_session=False)
# manually commit the transaction so triggers will execute
@ -251,16 +234,16 @@ def delete_vote_comment(request: Request) -> dict:
.one()
)
return {'comment': comment}
return {"comment": comment}
@ic_view_config(
route_name='comment_tag',
request_method='PUT',
permission='tag',
renderer='comment_contents.jinja2',
route_name="comment_tag",
request_method="PUT",
permission="tag",
renderer="comment_contents.jinja2",
)
@use_kwargs(CommentTagSchema(only=('name',)), locations=('matchdict',))
@use_kwargs(CommentTagSchema(only=("name",)), locations=("matchdict",))
def put_tag_comment(request: Request, name: CommentTagOption) -> Response:
"""Add a tag to a comment."""
comment = request.context
@ -286,16 +269,16 @@ def put_tag_comment(request: Request, name: CommentTagOption) -> Response:
.one()
)
return {'comment': comment}
return {"comment": comment}
@ic_view_config(
route_name='comment_tag',
request_method='DELETE',
permission='tag',
renderer='comment_contents.jinja2',
route_name="comment_tag",
request_method="DELETE",
permission="tag",
renderer="comment_contents.jinja2",
)
@use_kwargs(CommentTagSchema(only=('name',)), locations=('matchdict',))
@use_kwargs(CommentTagSchema(only=("name",)), locations=("matchdict",))
def delete_tag_comment(request: Request, name: CommentTagOption) -> Response:
"""Remove a tag (that the user previously added) from a comment."""
comment = request.context
@ -316,19 +299,14 @@ def delete_tag_comment(request: Request, name: CommentTagOption) -> Response:
.one()
)
return {'comment': comment}
return {"comment": comment}
@ic_view_config(
route_name='comment_mark_read',
request_method='PUT',
permission='mark_read',
route_name="comment_mark_read", request_method="PUT", permission="mark_read"
)
@use_kwargs({'mark_all_previous': Boolean(missing=False)})
def put_mark_comments_read(
request: Request,
mark_all_previous: bool,
) -> Response:
@use_kwargs({"mark_all_previous": Boolean(missing=False)})
def put_mark_comments_read(request: Request, mark_all_previous: bool) -> Response:
"""Mark comment(s) read, clearing notifications.
The "main" notification (request.context) will always be marked read, and
@ -339,7 +317,8 @@ def put_mark_comments_read(
if mark_all_previous:
prev_notifications = (
request.query(CommentNotification).filter(
request.query(CommentNotification)
.filter(
CommentNotification.user == request.user,
CommentNotification.is_unread == True, # noqa
CommentNotification.created_time <= notification.created_time,
@ -351,16 +330,14 @@ def put_mark_comments_read(
# sort the notifications by created_time of their comment so that the
# INSERT ... ON CONFLICT DO UPDATE statements work as expected
prev_notifications = sorted(
prev_notifications, key=lambda c: c.comment.created_time)
prev_notifications, key=lambda c: c.comment.created_time
)
for comment_notification in prev_notifications:
comment_notification.is_unread = False
_increment_topic_comments_seen(
request,
comment_notification.comment
)
_increment_topic_comments_seen(request, comment_notification.comment)
return Response('Your comment notifications have been cleared.')
return Response("Your comment notifications have been cleared.")
notification.is_unread = False
_increment_topic_comments_seen(request, notification.comment)

21
tildes/tildes/views/api/web/exceptions.py

@ -18,7 +18,7 @@ from tildes.views.decorators import ic_view_config
def _422_response_with_errors(errors: Sequence[str]) -> Response:
response = Response('\n'.join(errors))
response = Response("\n".join(errors))
response.status_int = 422
return response
@ -44,9 +44,9 @@ def unprocessable_entity(request: Request) -> Response:
error_strings = []
for field, errors in errors_by_field.items():
joined_errors = ' '.join(errors)
if field != '_schema':
error_strings.append(f'{field}: {joined_errors}')
joined_errors = " ".join(errors)
if field != "_schema":
error_strings.append(f"{field}: {joined_errors}")
else:
error_strings.append(joined_errors)
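unprocessable_entity above flattens marshmallow's errors-by-field mapping into one line per field, leaving schema-level errors (keyed "_schema") unprefixed, and _422_response_with_errors then joins those lines into the 422 response body. A standalone sketch of that flattening:

def flatten_errors(errors_by_field: dict) -> str:
    """Build the 422 body the way the exception view does (simplified)."""
    error_strings = []
    for field, errors in errors_by_field.items():
        joined_errors = " ".join(errors)
        if field != "_schema":
            error_strings.append(f"{field}: {joined_errors}")
        else:
            error_strings.append(joined_errors)
    return "\n".join(error_strings)

example = {
    "markdown": ["Cannot be entirely whitespace."],
    "_schema": ["Topics must have either markdown or a link."],
}
print(flatten_errors(example))
# markdown: Cannot be entirely whitespace.
# Topics must have either markdown or a link.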
@ -65,11 +65,11 @@ def httpnotfound(request: Request) -> Response:
response = request.exception
if request.matched_route.factory == comment_by_id36:
response.text = 'Comment not found (or it was deleted)'
response.text = "Comment not found (or it was deleted)"
elif request.matched_route.factory == topic_by_id36:
response.text = 'Topic not found (or it was deleted)'
response.text = "Topic not found (or it was deleted)"
else:
response.text = 'Not found'
response.text = "Not found"
return response
@ -79,10 +79,9 @@ def httptoomanyrequests(request: Request) -> Response:
"""Update a 429 error to show wait time info in the response text."""
response = request.exception
retry_seconds = request.exception.headers['Retry-After']
retry_seconds = request.exception.headers["Retry-After"]
response.text = (
'Rate limit exceeded. '
f'Please wait {retry_seconds} seconds before retrying.'
f"Rate limit exceeded. Please wait {retry_seconds} seconds before retrying."
)
return response
@ -99,4 +98,4 @@ def httpfound(request: Request) -> Response:
exception view will convert a 302 into a 200 with that header so it works
as a redirect for both standard requests as well as Intercooler ones.
"""
return Response(headers={'X-IC-Redirect': request.exception.location})
return Response(headers={"X-IC-Redirect": request.exception.location})

41
tildes/tildes/views/api/web/group.py

@ -17,10 +17,10 @@ from tildes.views.decorators import ic_view_config
@ic_view_config(
route_name='group_subscribe',
request_method='PUT',
permission='subscribe',
renderer='group_subscription_box.jinja2',
route_name="group_subscribe",
request_method="PUT",
permission="subscribe",
renderer="group_subscription_box.jinja2",
)
def put_subscribe_group(request: Request) -> dict:
"""Subscribe to a group with Intercooler."""
@ -48,22 +48,21 @@ def put_subscribe_group(request: Request) -> dict:
.one()
)
return {'group': group}
return {"group": group}
@ic_view_config(
route_name='group_subscribe',
request_method='DELETE',
permission='subscribe',
renderer='group_subscription_box.jinja2',
route_name="group_subscribe",
request_method="DELETE",
permission="subscribe",
renderer="group_subscription_box.jinja2",
)
def delete_subscribe_group(request: Request) -> dict:
"""Remove the user's subscription from a group with Intercooler."""
group = request.context
request.query(GroupSubscription).filter(
GroupSubscription.group == group,
GroupSubscription.user == request.user,
GroupSubscription.group == group, GroupSubscription.user == request.user
).delete(synchronize_session=False)
# manually commit the transaction so triggers will execute
@ -77,27 +76,21 @@ def delete_subscribe_group(request: Request) -> dict:
.one()
)
return {'group': group}
return {"group": group}
@ic_view_config(
route_name='group_user_settings',
request_method='PATCH',
@ic_view_config(route_name="group_user_settings", request_method="PATCH")
@use_kwargs(
{"order": Enum(TopicSortOption), "period": ShortTimePeriod(allow_none=True)}
)
@use_kwargs({
'order': Enum(TopicSortOption),
'period': ShortTimePeriod(allow_none=True),
})
def patch_group_user_settings(
request: Request,
order: TopicSortOption,
period: Optional[ShortTimePeriod],
request: Request, order: TopicSortOption, period: Optional[ShortTimePeriod]
) -> dict:
"""Set the user's default listing options."""
if period:
default_period = period.as_short_form()
else:
default_period = 'all'
default_period = "all"
statement = (
insert(UserGroupSettings.__table__)
@ -109,7 +102,7 @@ def patch_group_user_settings(
)
.on_conflict_do_update(
constraint=UserGroupSettings.__table__.primary_key,
set_={'default_order': order, 'default_period': default_period},
set_={"default_order": order, "default_period": default_period},
)
)
request.db_session.execute(statement)

16
tildes/tildes/views/api/web/message.py

@ -9,19 +9,17 @@ from tildes.views.decorators import ic_view_config
@ic_view_config(
route_name='message_conversation_replies',
request_method='POST',
renderer='single_message.jinja2',
permission='reply',
route_name="message_conversation_replies",
request_method="POST",
renderer="single_message.jinja2",
permission="reply",
)
@use_kwargs(MessageReplySchema(only=('markdown',)))
@use_kwargs(MessageReplySchema(only=("markdown",)))
def post_message_reply(request: Request, markdown: str) -> dict:
"""Post a reply to a message conversation with Intercooler."""
conversation = request.context
new_reply = MessageReply(
conversation=conversation,
sender=request.user,
markdown=markdown,
conversation=conversation, sender=request.user, markdown=markdown
)
request.db_session.add(new_reply)
@ -35,4 +33,4 @@ def post_message_reply(request: Request, markdown: str) -> dict:
.one()
)
return {'message': new_reply}
return {"message": new_reply}

172
tildes/tildes/views/api/web/topic.py

@ -19,66 +19,63 @@ from tildes.views.decorators import ic_view_config
@ic_view_config(
route_name='topic',
request_method='GET',
request_param='ic-trigger-name=edit',
renderer='topic_edit.jinja2',
permission='edit',
route_name="topic",
request_method="GET",
request_param="ic-trigger-name=edit",
renderer="topic_edit.jinja2",
permission="edit",
)
def get_topic_edit(request: Request) -> dict:
"""Get the edit form for a topic with Intercooler."""
return {'topic': request.context}
return {"topic": request.context}
@ic_view_config(
route_name='topic',
request_method='GET',
renderer='topic_contents.jinja2',
permission='view',
route_name="topic",
request_method="GET",
renderer="topic_contents.jinja2",
permission="view",
)
def get_topic_contents(request: Request) -> dict:
"""Get a topic's body with Intercooler."""
return {'topic': request.context}
return {"topic": request.context}
@ic_view_config(
route_name='topic',
request_method='PATCH',
renderer='topic_contents.jinja2',
permission='edit',
route_name="topic",
request_method="PATCH",
renderer="topic_contents.jinja2",
permission="edit",
)
@use_kwargs(TopicSchema(only=('markdown',)))
@use_kwargs(TopicSchema(only=("markdown",)))
def patch_topic(request: Request, markdown: str) -> dict:
"""Update a topic with Intercooler."""
topic = request.context
topic.markdown = markdown
return {'topic': topic}
return {"topic": topic}
@ic_view_config(
route_name='topic',
request_method='DELETE',
permission='delete',
)
@ic_view_config(route_name="topic", request_method="DELETE", permission="delete")
def delete_topic(request: Request) -> Response:
"""Delete a topic with Intercooler and redirect to its group."""
topic = request.context
topic.is_deleted = True
response = Response()
response.headers['X-IC-Redirect'] = request.route_url(
'group', group_path=topic.group.path)
response.headers["X-IC-Redirect"] = request.route_url(
"group", group_path=topic.group.path
)
return response
@ic_view_config(
route_name='topic_vote',
request_method='PUT',
renderer='topic_voting.jinja2',
permission='vote',
route_name="topic_vote",
request_method="PUT",
renderer="topic_voting.jinja2",
permission="vote",
)
def put_topic_vote(request: Request) -> Response:
"""Vote on a topic with Intercooler."""
@ -106,22 +103,21 @@ def put_topic_vote(request: Request) -> Response:
.one()
)
return {'topic': topic}
return {"topic": topic}
@ic_view_config(
route_name='topic_vote',
request_method='DELETE',
renderer='topic_voting.jinja2',
permission='vote',
route_name="topic_vote",
request_method="DELETE",
renderer="topic_voting.jinja2",
permission="vote",
)
def delete_topic_vote(request: Request) -> Response:
"""Remove the user's vote from a topic with Intercooler."""
topic = request.context
request.query(TopicVote).filter(
TopicVote.topic == topic,
TopicVote.user == request.user,
TopicVote.topic == topic, TopicVote.user == request.user
).delete(synchronize_session=False)
# manually commit the transaction so triggers will execute
@ -135,34 +131,34 @@ def delete_topic_vote(request: Request) -> Response:
.one()
)
return {'topic': topic}
return {"topic": topic}
@ic_view_config(
route_name='topic_tags',
request_method='GET',
renderer='topic_tags_edit.jinja2',
permission='tag',
route_name="topic_tags",
request_method="GET",
renderer="topic_tags_edit.jinja2",
permission="tag",
)
def get_topic_tags(request: Request) -> dict:
"""Get the tagging form for a topic with Intercooler."""
return {'topic': request.context}
return {"topic": request.context}
@ic_view_config(
route_name='topic_tags',
request_method='PUT',
renderer='topic_tags.jinja2',
permission='tag',
route_name="topic_tags",
request_method="PUT",
renderer="topic_tags.jinja2",
permission="tag",
)
@use_kwargs({'tags': String()})
@use_kwargs({"tags": String()})
def put_tag_topic(request: Request, tags: str) -> dict:
"""Apply tags to a topic with Intercooler."""
topic = request.context
if tags:
# split the tag string on commas
new_tags = tags.split(',')
new_tags = tags.split(",")
else:
new_tags = []
@ -171,7 +167,7 @@ def put_tag_topic(request: Request, tags: str) -> dict:
try:
topic.tags = new_tags
except ValidationError:
raise ValidationError({'tags': ['Invalid tags']})
raise ValidationError({"tags": ["Invalid tags"]})
# if tags weren't changed, don't add a log entry or update page
if set(topic.tags) == set(old_tags):
@ -182,42 +178,38 @@ def put_tag_topic(request: Request, tags: str) -> dict:
LogEventType.TOPIC_TAG,
request,
topic,
info={'old': old_tags, 'new': topic.tags},
),
info={"old": old_tags, "new": topic.tags},
)
)
return {'topic': topic}
return {"topic": topic}
@ic_view_config(
route_name='topic_group',
request_method='GET',
renderer='topic_group_edit.jinja2',
permission='move',
route_name="topic_group",
request_method="GET",
renderer="topic_group_edit.jinja2",
permission="move",
)
def get_topic_group(request: Request) -> dict:
"""Get the form for moving a topic with Intercooler."""
return {'topic': request.context}
return {"topic": request.context}
@ic_view_config(
route_name='topic',
request_param='ic-trigger-name=topic-move',
request_method='PATCH',
permission='move',
route_name="topic",
request_param="ic-trigger-name=topic-move",
request_method="PATCH",
permission="move",
)
@use_kwargs(GroupSchema(only=('path',)))
@use_kwargs(GroupSchema(only=("path",)))
def patch_move_topic(request: Request, path: str) -> dict:
"""Move a topic to a different group with Intercooler."""
topic = request.context
new_group = (
request.query(Group)
.filter(Group.path == path)
.one_or_none()
)
new_group = request.query(Group).filter(Group.path == path).one_or_none()
if not new_group:
raise HTTPNotFound('Group not found')
raise HTTPNotFound("Group not found")
old_group = topic.group
@ -231,18 +223,14 @@ def patch_move_topic(request: Request, path: str) -> dict:
LogEventType.TOPIC_MOVE,
request,
topic,
info={'old': str(old_group.path), 'new': str(topic.group.path)}
),
info={"old": str(old_group.path), "new": str(topic.group.path)},
)
)
return Response('Moved')
return Response("Moved")
@ic_view_config(
route_name='topic_lock',
request_method='PUT',
permission='lock',
)
@ic_view_config(route_name="topic_lock", request_method="PUT", permission="lock")
def put_topic_lock(request: Request) -> Response:
"""Lock a topic with Intercooler."""
topic = request.context
@ -250,14 +238,10 @@ def put_topic_lock(request: Request) -> Response:
topic.is_locked = True
request.db_session.add(LogTopic(LogEventType.TOPIC_LOCK, request, topic))
return Response('Locked')
return Response("Locked")
@ic_view_config(
route_name='topic_lock',
request_method='DELETE',
permission='lock',
)
@ic_view_config(route_name="topic_lock", request_method="DELETE", permission="lock")
def delete_topic_lock(request: Request) -> Response:
"""Unlock a topic with Intercooler."""
topic = request.context
@ -265,27 +249,27 @@ def delete_topic_lock(request: Request) -> Response:
topic.is_locked = False
request.db_session.add(LogTopic(LogEventType.TOPIC_UNLOCK, request, topic))
return Response('Unlocked')
return Response("Unlocked")
@ic_view_config(
route_name='topic_title',
request_method='GET',
renderer='topic_title_edit.jinja2',
permission='edit_title',
route_name="topic_title",
request_method="GET",
renderer="topic_title_edit.jinja2",
permission="edit_title",
)
def get_topic_title(request: Request) -> dict:
"""Get the form for editing a topic's title with Intercooler."""
return {'topic': request.context}
return {"topic": request.context}
@ic_view_config(
route_name='topic',
request_param='ic-trigger-name=topic-title-edit',
request_method='PATCH',
permission='edit_title',
route_name="topic",
request_param="ic-trigger-name=topic-title-edit",
request_method="PATCH",
permission="edit_title",
)
@use_kwargs(TopicSchema(only=('title',)))
@use_kwargs(TopicSchema(only=("title",)))
def patch_topic_title(request: Request, title: str) -> dict:
"""Edit a topic's title with Intercooler."""
topic = request.context
@ -298,8 +282,8 @@ def patch_topic_title(request: Request, title: str) -> dict:
LogEventType.TOPIC_TITLE_EDIT,
request,
topic,
info={'old': topic.title, 'new': title}
),
info={"old": topic.title, "new": title},
)
)
topic.title = title

139
tildes/tildes/views/api/web/user.py

@ -20,52 +20,48 @@ from tildes.views import IC_NOOP
from tildes.views.decorators import ic_view_config
PASSWORD_FIELD = UserSchema(only=('password',)).fields['password']
PASSWORD_FIELD = UserSchema(only=("password",)).fields["password"]
@ic_view_config(
route_name='user',
request_method='PATCH',
request_param='ic-trigger-name=password-change',
permission='change_password',
route_name="user",
request_method="PATCH",
request_param="ic-trigger-name=password-change",
permission="change_password",
)
@use_kwargs(
{
"old_password": PASSWORD_FIELD,
"new_password": PASSWORD_FIELD,
"new_password_confirm": PASSWORD_FIELD,
}
)
@use_kwargs({
'old_password': PASSWORD_FIELD,
'new_password': PASSWORD_FIELD,
'new_password_confirm': PASSWORD_FIELD,
})
def patch_change_password(
request: Request,
old_password: str,
new_password: str,
new_password_confirm: str,
request: Request, old_password: str, new_password: str, new_password_confirm: str
) -> Response:
"""Change the logged-in user's password."""
user = request.context
# enable checking the new password against the breached-passwords list
user.schema.context['check_breached_passwords'] = True
user.schema.context["check_breached_passwords"] = True
if new_password != new_password_confirm:
raise HTTPUnprocessableEntity(
'New password and confirmation do not match.')
raise HTTPUnprocessableEntity("New password and confirmation do not match.")
user.change_password(old_password, new_password)
return Response('Your password has been updated')
return Response("Your password has been updated")
@ic_view_config(
route_name='user',
request_method='PATCH',
request_param='ic-trigger-name=account-recovery-email',
permission='change_email_address',
route_name="user",
request_method="PATCH",
request_param="ic-trigger-name=account-recovery-email",
permission="change_email_address",
)
@use_kwargs(UserSchema(only=('email_address', 'email_address_note')))
@use_kwargs(UserSchema(only=("email_address", "email_address_note")))
def patch_change_email_address(
request: Request,
email_address: str,
email_address_note: str
request: Request, email_address: str, email_address_note: str
) -> Response:
"""Change the user's email address (and descriptive note)."""
user = request.context
@@ -77,46 +73,46 @@ def patch_change_email_address(
log_info = None
if user.email_address_hash:
log_info = {
'old_hash': user.email_address_hash,
'old_note': user.email_address_note,
"old_hash": user.email_address_hash,
"old_note": user.email_address_note,
}
request.db_session.add(Log(LogEventType.USER_EMAIL_SET, request, log_info))
user.email_address = email_address
user.email_address_note = email_address_note
return Response('Your email address has been updated')
return Response("Your email address has been updated")
@ic_view_config(
route_name='user',
request_method='PATCH',
request_param='ic-trigger-name=auto-mark-notifications-read',
permission='change_auto_mark_notifications_read_setting',
route_name="user",
request_method="PATCH",
request_param="ic-trigger-name=auto-mark-notifications-read",
permission="change_auto_mark_notifications_read_setting",
)
def patch_change_auto_mark_notifications(request: Request) -> Response:
"""Change the user's "automatically mark notifications read" setting."""
user = request.context
auto_mark = bool(request.params.get('auto_mark_notifications_read'))
auto_mark = bool(request.params.get("auto_mark_notifications_read"))
user.auto_mark_notifications_read = auto_mark
return IC_NOOP
@ic_view_config(
route_name='user',
request_method='PATCH',
request_param='ic-trigger-name=open-links-new-tab',
permission='change_open_links_new_tab_setting',
route_name="user",
request_method="PATCH",
request_param="ic-trigger-name=open-links-new-tab",
permission="change_open_links_new_tab_setting",
)
def patch_change_open_links_new_tab(request: Request) -> Response:
"""Change the user's "open links in new tabs" setting."""
user = request.context
external = bool(request.params.get('open_new_tab_external'))
internal = bool(request.params.get('open_new_tab_internal'))
text = bool(request.params.get('open_new_tab_text'))
external = bool(request.params.get("open_new_tab_external"))
internal = bool(request.params.get("open_new_tab_internal"))
text = bool(request.params.get("open_new_tab_text"))
user.open_new_tab_external = external
user.open_new_tab_internal = internal
user.open_new_tab_text = text
@@ -125,16 +121,16 @@ def patch_change_open_links_new_tab(request: Request) -> Response:
@ic_view_config(
route_name='user',
request_method='PATCH',
request_param='ic-trigger-name=comment-visits',
permission='change_comment_visits_setting',
route_name="user",
request_method="PATCH",
request_param="ic-trigger-name=comment-visits",
permission="change_comment_visits_setting",
)
def patch_change_track_comment_visits(request: Request) -> Response:
"""Change the user's "track comment visits" setting."""
user = request.context
track_comment_visits = bool(request.params.get('track_comment_visits'))
track_comment_visits = bool(request.params.get("track_comment_visits"))
user.track_comment_visits = track_comment_visits
if track_comment_visits:
@@ -144,20 +140,20 @@ def patch_change_track_comment_visits(request: Request) -> Response:
@ic_view_config(
route_name='user_invite_code',
request_method='GET',
permission='view_invite_code',
renderer='invite_code.jinja2',
route_name="user_invite_code",
request_method="GET",
permission="view_invite_code",
renderer="invite_code.jinja2",
)
def get_invite_code(request: Request) -> dict:
"""Generate a new invite code owned by the user."""
user = request.context
if request.user.invite_codes_remaining < 1:
raise HTTPForbidden('No invite codes remaining')
raise HTTPForbidden("No invite codes remaining")
# obtain a lock to prevent concurrent requests generating multiple codes
request.obtain_lock('generate_invite_code', user.user_id)
request.obtain_lock("generate_invite_code", user.user_id)
# it's possible to randomly generate an existing code, so we'll retry
# until we create a new one (will practically always be the first try)
@@ -179,22 +175,19 @@ def get_invite_code(request: Request) -> dict:
num_remaining = request.user.invite_codes_remaining - 1
request.user.invite_codes_remaining = User.invite_codes_remaining - 1
return {'code': code, 'num_remaining': num_remaining}
return {"code": code, "num_remaining": num_remaining}
@ic_view_config(
route_name='user_default_listing_options',
request_method='PUT',
permission='edit_default_listing_options',
route_name="user_default_listing_options",
request_method="PUT",
permission="edit_default_listing_options",
)
@use_kwargs(
{"order": Enum(TopicSortOption), "period": ShortTimePeriod(allow_none=True)}
)
@use_kwargs({
'order': Enum(TopicSortOption),
'period': ShortTimePeriod(allow_none=True),
})
def put_default_listing_options(
request: Request,
order: TopicSortOption,
period: Optional[ShortTimePeriod],
request: Request, order: TopicSortOption, period: Optional[ShortTimePeriod]
) -> dict:
"""Set the user's default listing options."""
user = request.context
@@ -203,31 +196,31 @@ def put_default_listing_options(
if period:
user.home_default_period = period.as_short_form()
else:
user.home_default_period = 'all'
user.home_default_period = "all"
return IC_NOOP
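
The two use_kwargs hunks in this file end up with different shapes for the same reason: Black's default 88-character line limit. When the whole dict fits on one wrapped line it stays collapsed; when it does not, each entry gets its own line plus a trailing comma. Both outcomes are visible above, copied here from the formatted lines for comparison (line width alone drives the difference):

# short enough to fit within the 88-character limit -> one wrapped line
@use_kwargs(
    {"order": Enum(TopicSortOption), "period": ShortTimePeriod(allow_none=True)}
)

# too long for a single line -> one entry per line, with a trailing comma
@use_kwargs(
    {
        "old_password": PASSWORD_FIELD,
        "new_password": PASSWORD_FIELD,
        "new_password_confirm": PASSWORD_FIELD,
    }
)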
@ic_view_config(
route_name='user_filtered_topic_tags',
request_method='PUT',
permission='edit_filtered_topic_tags',
route_name="user_filtered_topic_tags",
request_method="PUT",
permission="edit_filtered_topic_tags",
)
@use_kwargs({'tags': String()})
@use_kwargs({"tags": String()})
def put_filtered_topic_tags(request: Request, tags: str) -> dict:
"""Update a user's filtered topic tags list."""
if not tags:
request.user.filtered_topic_tags = []
return IC_NOOP
split_tags = tags.split(',')
split_tags = tags.split(",")
try:
schema = TopicSchema(only=('tags',))
result = schema.load({'tags': split_tags})
schema = TopicSchema(only=("tags",))
result = schema.load({"tags": split_tags})
except ValidationError:
raise ValidationError({'tags': ['Invalid tags']})
raise ValidationError({"tags": ["Invalid tags"]})
request.user.filtered_topic_tags = result.data['tags']
request.user.filtered_topic_tags = result.data["tags"]
return IC_NOOP
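
Because the old and new lines of this last view are interleaved above, here is how put_filtered_topic_tags presumably reads once the Black changes are applied (a reconstruction assembled from the hunk; blank-line placement is a guess):

@ic_view_config(
    route_name="user_filtered_topic_tags",
    request_method="PUT",
    permission="edit_filtered_topic_tags",
)
@use_kwargs({"tags": String()})
def put_filtered_topic_tags(request: Request, tags: str) -> dict:
    """Update a user's filtered topic tags list."""
    if not tags:
        request.user.filtered_topic_tags = []
        return IC_NOOP

    split_tags = tags.split(",")

    try:
        schema = TopicSchema(only=("tags",))
        result = schema.load({"tags": split_tags})
    except ValidationError:
        raise ValidationError({"tags": ["Invalid tags"]})

    request.user.filtered_topic_tags = result.data["tags"]

    return IC_NOOP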

16 tildes/tildes/views/decorators.py

@@ -9,15 +9,15 @@ from pyramid.view import view_config
def ic_view_config(**kwargs: Any) -> Callable:
"""Wrap the @view_config decorator for Intercooler views."""
if 'route_name' in kwargs:
kwargs['route_name'] = 'ic_' + kwargs['route_name']
if "route_name" in kwargs:
kwargs["route_name"] = "ic_" + kwargs["route_name"]
if 'renderer' in kwargs:
kwargs['renderer'] = 'intercooler/' + kwargs['renderer']
if "renderer" in kwargs:
kwargs["renderer"] = "intercooler/" + kwargs["renderer"]
if 'header' in kwargs:
if "header" in kwargs:
raise ValueError("Can't add a header check to Intercooler view.")
kwargs['header'] = 'X-IC-Request:true'
kwargs["header"] = "X-IC-Request:true"
return view_config(**kwargs)
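
The net effect of this wrapper is easier to see with a concrete expansion. Taking the delete_topic_lock registration from the topic views above (an illustration of the kwargs rewriting, not a line from the diff):

# what the wrapper passes through to Pyramid's view_config:
@ic_view_config(route_name="topic_lock", request_method="DELETE", permission="lock")
def delete_topic_lock(request: Request) -> Response: ...

# is equivalent to registering:
@view_config(
    route_name="ic_topic_lock",      # "ic_" prefix added to the route name
    request_method="DELETE",
    permission="lock",
    header="X-IC-Request:true",      # Intercooler header check always appended
)
def delete_topic_lock(request: Request) -> Response: ...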
@@ -32,6 +32,7 @@ def rate_limit_view(action_name: str) -> Callable:
response with appropriate headers will be raised instead of calling the
decorated view.
"""
def decorator(func: Callable) -> Callable:
def wrapper(*args: Any, **kwargs: Any) -> Any:
request = args[0]
@@ -55,9 +56,10 @@ def not_logged_in(func: Callable) -> Callable:
such as the login page, registration page, etc. which only logged-out users
should be accessing.
"""
def wrapper(request: Request, **kwargs: Any) -> Any:
if request.user:
raise HTTPFound(location=request.route_url('home'))
raise HTTPFound(location=request.route_url("home"))
return func(request, **kwargs)
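
The docstring names the login and registration pages as the intended users of this decorator; a hypothetical application (route name and renderer are illustrative, not taken from this diff) would look like:

@view_config(route_name="login", renderer="login.jinja2")  # hypothetical registration
@not_logged_in
def get_login(request: Request) -> dict:
    """Display the login form; logged-in users are redirected to the home page."""
    return {}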

Some files were not shown because too many files changed in this diff
