Browse Source

Merge branch 'master' into feature-saved-themes

merge-requests/25/head
Celeo 8 years ago
parent
commit
bd2f3f604b
  1. 1
      git_hooks/pre-commit
  2. 3
      git_hooks/pre-push
  3. 4
      salt/salt/boussole.service.jinja2
  4. 20
      salt/salt/boussole.sls
  5. 22
      tildes/alembic/env.py
  6. 26
      tildes/alembic/versions/2512581c91b3_add_setting_to_open_links_in_new_tab.py
  7. 13
      tildes/alembic/versions/347859b0355e_added_account_default_theme_setting.py
  8. 13
      tildes/alembic/versions/de83b8750123_add_setting_to_open_text_links_in_new_.py
  9. 35
      tildes/alembic/versions/f1ecbf24c212_added_user_tag_type_comment_notification.py
  10. 124
      tildes/alembic/versions/fab922a8bb04_update_comment_triggers_for_removals.py
  11. 21
      tildes/consumers/comment_user_mentions_generator.py
  12. 29
      tildes/consumers/topic_metadata_generator.py
  13. 5
      tildes/gunicorn_config.py
  14. 17
      tildes/pylama.ini
  15. 5
      tildes/requirements-to-freeze.txt
  16. 29
      tildes/requirements.txt
  17. 104
      tildes/scripts/breached_passwords.py
  18. 67
      tildes/scripts/clean_private_data.py
  19. 54
      tildes/scripts/initialize_db.py
  20. 6
      tildes/setup.py
  21. 107
      tildes/tests/conftest.py
  22. 8
      tildes/tests/fixtures.py
  23. 64
      tildes/tests/test_comment.py
  24. 76
      tildes/tests/test_comment_user_mentions.py
  25. 14
      tildes/tests/test_datetime.py
  26. 41
      tildes/tests/test_group.py
  27. 10
      tildes/tests/test_id.py
  28. 167
      tildes/tests/test_markdown.py
  29. 22
      tildes/tests/test_markdown_field.py
  30. 36
      tildes/tests/test_messages.py
  31. 2
      tildes/tests/test_metrics.py
  32. 40
      tildes/tests/test_ratelimit.py
  33. 26
      tildes/tests/test_simplestring_field.py
  34. 68
      tildes/tests/test_string.py
  35. 36
      tildes/tests/test_title.py
  36. 44
      tildes/tests/test_topic.py
  37. 37
      tildes/tests/test_topic_permissions.py
  38. 22
      tildes/tests/test_topic_tags.py
  39. 6
      tildes/tests/test_triggers_comments.py
  40. 22
      tildes/tests/test_url.py
  41. 51
      tildes/tests/test_user.py
  42. 16
      tildes/tests/test_username.py
  43. 10
      tildes/tests/test_webassets.py
  44. 12
      tildes/tests/webtests/test_user_page.py
  45. 101
      tildes/tildes/__init__.py
  46. 10
      tildes/tildes/api.py
  47. 70
      tildes/tildes/auth.py
  48. 47
      tildes/tildes/database.py
  49. 26
      tildes/tildes/enums.py
  50. 26
      tildes/tildes/jinja.py
  51. 10
      tildes/tildes/json.py
  52. 9
      tildes/tildes/lib/__init__.py
  53. 33
      tildes/tildes/lib/amqp.py
  54. 12
      tildes/tildes/lib/cmark.py
  55. 36
      tildes/tildes/lib/database.py
  56. 52
      tildes/tildes/lib/datetime.py
  57. 11
      tildes/tildes/lib/hash.py
  58. 19
      tildes/tildes/lib/id.py
  59. 303
      tildes/tildes/lib/markdown.py
  60. 2
      tildes/tildes/lib/message.py
  61. 15
      tildes/tildes/lib/password.py
  62. 144
      tildes/tildes/lib/ratelimit.py
  63. 115
      tildes/tildes/lib/string.py
  64. 4
      tildes/tildes/lib/url.py
  65. 64
      tildes/tildes/metrics.py
  66. 125
      tildes/tildes/models/comment/comment.py
  67. 94
      tildes/tildes/models/comment/comment_notification.py
  68. 10
      tildes/tildes/models/comment/comment_notification_query.py
  69. 11
      tildes/tildes/models/comment/comment_query.py
  70. 29
      tildes/tildes/models/comment/comment_tag.py
  71. 46
      tildes/tildes/models/comment/comment_tree.py
  72. 24
      tildes/tildes/models/comment/comment_vote.py
  73. 71
      tildes/tildes/models/database_model.py
  74. 54
      tildes/tildes/models/group/group.py
  75. 11
      tildes/tildes/models/group/group_query.py
  76. 24
      tildes/tildes/models/group/group_subscription.py
  77. 140
      tildes/tildes/models/log/log.py
  78. 136
      tildes/tildes/models/message/message.py
  79. 75
      tildes/tildes/models/model_query.py
  80. 80
      tildes/tildes/models/pagination.py
  81. 196
      tildes/tildes/models/topic/topic.py
  82. 65
      tildes/tildes/models/topic/topic_query.py
  83. 40
      tildes/tildes/models/topic/topic_visit.py
  84. 24
      tildes/tildes/models/topic/topic_vote.py
  85. 87
      tildes/tildes/models/user/user.py
  86. 16
      tildes/tildes/models/user/user_group_settings.py
  87. 45
      tildes/tildes/models/user/user_invite_code.py
  88. 12
      tildes/tildes/resources/__init__.py
  89. 20
      tildes/tildes/resources/comment.py
  90. 19
      tildes/tildes/resources/group.py
  91. 11
      tildes/tildes/resources/message.py
  92. 13
      tildes/tildes/resources/topic.py
  93. 5
      tildes/tildes/resources/user.py
  94. 166
      tildes/tildes/routes.py
  95. 16
      tildes/tildes/schemas/__init__.py
  96. 49
      tildes/tildes/schemas/fields.py
  97. 27
      tildes/tildes/schemas/group.py
  98. 68
      tildes/tildes/schemas/topic.py
  99. 13
      tildes/tildes/schemas/topic_listing.py
  100. 58
      tildes/tildes/schemas/user.py

1
git_hooks/pre-commit

@ -4,4 +4,5 @@
vagrant ssh -c ". activate \
&& echo 'Checking mypy type annotations...' && mypy . \
&& echo 'Checking if Black would reformat any code...' && black --check . \
&& echo -n 'Running tests: ' && pytest -q"

3
git_hooks/pre-push

@ -4,5 +4,6 @@
vagrant ssh -c ". activate \
&& echo 'Checking mypy type annotations...' && mypy . \
&& echo 'Checking if Black would reformat any code...' && black --check . \
&& echo -n 'Running tests: ' && pytest -q \
&& echo 'Checking code style (takes a while)...' && pylama"
&& echo 'Checking code style fully (takes a while)...' && pylama"

4
salt/salt/boussole.service.jinja2

@ -1,11 +1,11 @@
{% from 'common.jinja2' import app_dir, bin_dir -%}
{% from 'common.jinja2' import app_dir -%}
[Unit]
Description=Boussole - auto-compile SCSS files on change
[Service]
WorkingDirectory={{ app_dir }}
Environment="LC_ALL=C.UTF-8" "LANG=C.UTF-8"
ExecStart={{ bin_dir }}/boussole watch --backend=yaml --config=boussole.yaml --poll
ExecStart=/opt/venvs/boussole/bin/boussole watch --backend=yaml --config=boussole.yaml --poll
Restart=always
RestartSec=5

20
salt/salt/boussole.sls

@ -1,4 +1,20 @@
{% from 'common.jinja2' import app_dir, bin_dir %}
{% from 'common.jinja2' import app_dir, python_version %}
{% set boussole_venv_dir = '/opt/venvs/boussole' %}
# Salt seems to use the deprecated pyvenv script, manual for now
boussole-venv-setup:
cmd.run:
- name: /usr/local/pyenv/versions/{{ python_version }}/bin/python -m venv {{ boussole_venv_dir }}
- creates: {{ boussole_venv_dir }}
- require:
- pkg: python3-venv
- pyenv: {{ python_version }}
boussole-pip-installs:
cmd.run:
- name: {{ boussole_venv_dir }}/bin/pip install boussole
- unless: ls {{ boussole_venv_dir }}/lib/python3.6/site-packages/boussole
/etc/systemd/system/boussole.service:
file.managed:
@ -22,7 +38,7 @@ create-css-directory:
initial-boussole-run:
cmd.run:
- name: {{ bin_dir }}/boussole compile --backend=yaml --config=boussole.yaml
- name: {{ boussole_venv_dir }}/bin/boussole compile --backend=yaml --config=boussole.yaml
- cwd: {{ app_dir }}
- env:
- LC_ALL: C.UTF-8

22
tildes/alembic/env.py

@ -12,12 +12,7 @@ config = context.config
fileConfig(config.config_file_name)
# import all DatabaseModel subclasses here for autogenerate support
from tildes.models.comment import (
Comment,
CommentNotification,
CommentTag,
CommentVote,
)
from tildes.models.comment import Comment, CommentNotification, CommentTag, CommentVote
from tildes.models.group import Group, GroupSubscription
from tildes.models.log import Log
from tildes.models.message import MessageConversation, MessageReply
@ -25,6 +20,7 @@ from tildes.models.topic import Topic, TopicVisit, TopicVote
from tildes.models.user import User, UserGroupSettings, UserInviteCode
from tildes.models import DatabaseModel
target_metadata = DatabaseModel.metadata
# other values from the config, defined by the needs of env.py,
@ -46,8 +42,7 @@ def run_migrations_offline():
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url, target_metadata=target_metadata, literal_binds=True)
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
with context.begin_transaction():
context.run_migrations()
@ -62,18 +57,17 @@ def run_migrations_online():
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix='sqlalchemy.',
poolclass=pool.NullPool)
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata
)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:

26
tildes/alembic/versions/2512581c91b3_add_setting_to_open_links_in_new_tab.py

@ -10,17 +10,33 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2512581c91b3'
revision = "2512581c91b3"
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
op.add_column('users', sa.Column('open_new_tab_external', sa.Boolean(), server_default='false', nullable=False))
op.add_column('users', sa.Column('open_new_tab_internal', sa.Boolean(), server_default='false', nullable=False))
op.add_column(
"users",
sa.Column(
"open_new_tab_external",
sa.Boolean(),
server_default="false",
nullable=False,
),
)
op.add_column(
"users",
sa.Column(
"open_new_tab_internal",
sa.Boolean(),
server_default="false",
nullable=False,
),
)
def downgrade():
op.drop_column('users', 'open_new_tab_internal')
op.drop_column('users', 'open_new_tab_external')
op.drop_column("users", "open_new_tab_internal")
op.drop_column("users", "open_new_tab_external")

13
tildes/alembic/versions/347859b0355e_added_account_default_theme_setting.py

@ -10,15 +10,20 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '347859b0355e'
down_revision = 'fab922a8bb04'
revision = "347859b0355e"
down_revision = "fab922a8bb04"
branch_labels = None
depends_on = None
def upgrade():
op.add_column('users', sa.Column('theme_account_default', sa.Text(), server_default='', nullable=False))
op.add_column(
"users",
sa.Column(
"theme_account_default", sa.Text(), server_default="", nullable=False
),
)
def downgrade():
op.drop_column('users', 'theme_account_default')
op.drop_column("users", "theme_account_default")

13
tildes/alembic/versions/de83b8750123_add_setting_to_open_text_links_in_new_.py

@ -10,15 +10,20 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'de83b8750123'
down_revision = '2512581c91b3'
revision = "de83b8750123"
down_revision = "2512581c91b3"
branch_labels = None
depends_on = None
def upgrade():
op.add_column('users', sa.Column('open_new_tab_text', sa.Boolean(), server_default='false', nullable=False))
op.add_column(
"users",
sa.Column(
"open_new_tab_text", sa.Boolean(), server_default="false", nullable=False
),
)
def downgrade():
op.drop_column('users', 'open_new_tab_text')
op.drop_column("users", "open_new_tab_text")

35
tildes/alembic/versions/f1ecbf24c212_added_user_tag_type_comment_notification.py

@ -9,8 +9,8 @@ from alembic import op
# revision identifiers, used by Alembic.
revision = 'f1ecbf24c212'
down_revision = 'de83b8750123'
revision = "f1ecbf24c212"
down_revision = "de83b8750123"
branch_labels = None
depends_on = None
@ -20,18 +20,18 @@ def upgrade():
connection = None
if not op.get_context().as_sql:
connection = op.get_bind()
connection.execution_options(isolation_level='AUTOCOMMIT')
connection.execution_options(isolation_level="AUTOCOMMIT")
op.execute(
"ALTER TYPE commentnotificationtype "
"ADD VALUE IF NOT EXISTS 'USER_MENTION'"
"ALTER TYPE commentnotificationtype ADD VALUE IF NOT EXISTS 'USER_MENTION'"
)
# re-activate the transaction for any future migrations
if connection is not None:
connection.execution_options(isolation_level='READ_COMMITTED')
connection.execution_options(isolation_level="READ_COMMITTED")
op.execute('''
op.execute(
"""
CREATE OR REPLACE FUNCTION send_rabbitmq_message_for_comment() RETURNS TRIGGER AS $$
DECLARE
affected_row RECORD;
@ -50,23 +50,28 @@ def upgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
''')
op.execute('''
"""
)
op.execute(
"""
CREATE TRIGGER send_rabbitmq_message_for_comment_insert
AFTER INSERT ON comments
FOR EACH ROW
EXECUTE PROCEDURE send_rabbitmq_message_for_comment('created');
''')
op.execute('''
"""
)
op.execute(
"""
CREATE TRIGGER send_rabbitmq_message_for_comment_edit
AFTER UPDATE ON comments
FOR EACH ROW
WHEN (OLD.markdown IS DISTINCT FROM NEW.markdown)
EXECUTE PROCEDURE send_rabbitmq_message_for_comment('edited');
''')
"""
)
def downgrade():
op.execute('DROP TRIGGER send_rabbitmq_message_for_comment_insert ON comments')
op.execute('DROP TRIGGER send_rabbitmq_message_for_comment_edit ON comments')
op.execute('DROP FUNCTION send_rabbitmq_message_for_comment')
op.execute("DROP TRIGGER send_rabbitmq_message_for_comment_insert ON comments")
op.execute("DROP TRIGGER send_rabbitmq_message_for_comment_edit ON comments")
op.execute("DROP FUNCTION send_rabbitmq_message_for_comment")

124
tildes/alembic/versions/fab922a8bb04_update_comment_triggers_for_removals.py

@ -10,8 +10,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fab922a8bb04'
down_revision = 'f1ecbf24c212'
revision = "fab922a8bb04"
down_revision = "f1ecbf24c212"
branch_labels = None
depends_on = None
@ -19,17 +19,20 @@ depends_on = None
def upgrade():
# comment_notifications
op.execute("DROP TRIGGER delete_comment_notifications_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_notifications_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN ((OLD.is_deleted = false AND NEW.is_deleted = true)
OR (OLD.is_removed = false AND NEW.is_removed = true))
EXECUTE PROCEDURE delete_comment_notifications();
""")
"""
)
# comments
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION set_comment_deleted_time() RETURNS TRIGGER AS $$
BEGIN
IF (NEW.is_deleted = TRUE) THEN
@ -41,17 +44,21 @@ def upgrade():
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER delete_comment_set_deleted_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_set_deleted_time_update
BEFORE UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
EXECUTE PROCEDURE set_comment_deleted_time();
""")
"""
)
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION set_comment_removed_time() RETURNS TRIGGER AS $$
BEGIN
IF (NEW.is_removed = TRUE) THEN
@ -63,19 +70,23 @@ def upgrade():
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""")
op.execute("""
"""
)
op.execute(
"""
CREATE TRIGGER remove_comment_set_removed_time_update
BEFORE UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_removed IS DISTINCT FROM NEW.is_removed)
EXECUTE PROCEDURE set_comment_removed_time();
""")
"""
)
# topic_visits
op.execute("DROP TRIGGER update_topic_visits_num_comments_update ON comments")
op.execute("DROP FUNCTION decrement_all_topic_visit_num_comments()")
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_all_topic_visit_num_comments() RETURNS TRIGGER AS $$
DECLARE
old_visible BOOLEAN := NOT (OLD.is_deleted OR OLD.is_removed);
@ -96,18 +107,22 @@ def upgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
op.execute("""
"""
)
op.execute(
"""
CREATE TRIGGER update_topic_visits_num_comments_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed))
EXECUTE PROCEDURE update_all_topic_visit_num_comments();
""")
"""
)
# topics
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_num_comments() RETURNS TRIGGER AS $$
BEGIN
IF (TG_OP = 'INSERT') THEN
@ -140,18 +155,22 @@ def upgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_num_comments_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_num_comments_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed))
EXECUTE PROCEDURE update_topics_num_comments();
""")
"""
)
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_last_activity_time() RETURNS TRIGGER AS $$
DECLARE
most_recent_comment RECORD;
@ -182,31 +201,37 @@ def upgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_last_activity_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_last_activity_time_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN ((OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
OR (OLD.is_removed IS DISTINCT FROM NEW.is_removed))
EXECUTE PROCEDURE update_topics_last_activity_time();
""")
"""
)
def downgrade():
# comment_notifications
op.execute("DROP TRIGGER delete_comment_notifications_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_notifications_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted = false AND NEW.is_deleted = true)
EXECUTE PROCEDURE delete_comment_notifications();
""")
"""
)
# comments
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION set_comment_deleted_time() RETURNS TRIGGER AS $$
BEGIN
NEW.deleted_time := current_timestamp;
@ -214,15 +239,18 @@ def downgrade():
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER delete_comment_set_deleted_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER delete_comment_set_deleted_time_update
BEFORE UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted = false AND NEW.is_deleted = true)
EXECUTE PROCEDURE set_comment_deleted_time();
""")
"""
)
op.execute("DROP TRIGGER remove_comment_set_removed_time_update ON comments")
op.execute("DROP FUNCTION set_comment_removed_time()")
@ -230,7 +258,8 @@ def downgrade():
# topic_visits
op.execute("DROP TRIGGER update_topic_visits_num_comments_update ON comments")
op.execute("DROP FUNCTION update_all_topic_visit_num_comments()")
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION decrement_all_topic_visit_num_comments() RETURNS TRIGGER AS $$
BEGIN
UPDATE topic_visits
@ -241,17 +270,21 @@ def downgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
op.execute("""
"""
)
op.execute(
"""
CREATE TRIGGER update_topic_visits_num_comments_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted = false AND NEW.is_deleted = true)
EXECUTE PROCEDURE decrement_all_topic_visit_num_comments();
""")
"""
)
# topics
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_num_comments() RETURNS TRIGGER AS $$
BEGIN
IF (TG_OP = 'INSERT' AND NEW.is_deleted = FALSE) THEN
@ -277,17 +310,21 @@ def downgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_num_comments_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_num_comments_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
EXECUTE PROCEDURE update_topics_num_comments();
""")
"""
)
op.execute("""
op.execute(
"""
CREATE OR REPLACE FUNCTION update_topics_last_activity_time() RETURNS TRIGGER AS $$
DECLARE
most_recent_comment RECORD;
@ -317,12 +354,15 @@ def downgrade():
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
""")
"""
)
op.execute("DROP TRIGGER update_topics_last_activity_time_update ON comments")
op.execute("""
op.execute(
"""
CREATE TRIGGER update_topics_last_activity_time_update
AFTER UPDATE ON comments
FOR EACH ROW
WHEN (OLD.is_deleted IS DISTINCT FROM NEW.is_deleted)
EXECUTE PROCEDURE update_topics_last_activity_time();
""")
"""
)

21
tildes/consumers/comment_user_mentions_generator.py

@ -13,7 +13,7 @@ class CommentUserMentionGenerator(PgsqlQueueConsumer):
"""Process a delivered message."""
comment = (
self.db_session.query(Comment)
.filter_by(comment_id=msg.body['comment_id'])
.filter_by(comment_id=msg.body["comment_id"])
.one()
)
@ -22,15 +22,16 @@ class CommentUserMentionGenerator(PgsqlQueueConsumer):
return
new_mentions = CommentNotification.get_mentions_for_comment(
self.db_session, comment)
self.db_session, comment
)
if msg.delivery_info['routing_key'] == 'comment.created':
if msg.delivery_info["routing_key"] == "comment.created":
for user_mention in new_mentions:
self.db_session.add(user_mention)
elif msg.delivery_info['routing_key'] == 'comment.edited':
to_delete, to_add = (
CommentNotification.prevent_duplicate_notifications(
self.db_session, comment, new_mentions))
elif msg.delivery_info["routing_key"] == "comment.edited":
to_delete, to_add = CommentNotification.prevent_duplicate_notifications(
self.db_session, comment, new_mentions
)
for user_mention in to_delete:
self.db_session.delete(user_mention)
@ -39,8 +40,8 @@ class CommentUserMentionGenerator(PgsqlQueueConsumer):
self.db_session.add(user_mention)
if __name__ == '__main__':
if __name__ == "__main__":
CommentUserMentionGenerator(
queue_name='comment_user_mentions_generator.q',
routing_keys=['comment.created', 'comment.edited'],
queue_name="comment_user_mentions_generator.q",
routing_keys=["comment.created", "comment.edited"],
).consume_queue()

29
tildes/consumers/topic_metadata_generator.py

@ -26,9 +26,7 @@ class TopicMetadataGenerator(PgsqlQueueConsumer):
def run(self, msg: Message) -> None:
"""Process a delivered message."""
topic = (
self.db_session.query(Topic)
.filter_by(topic_id=msg.body['topic_id'])
.one()
self.db_session.query(Topic).filter_by(topic_id=msg.body["topic_id"]).one()
)
if topic.is_text_type:
@ -42,22 +40,19 @@ class TopicMetadataGenerator(PgsqlQueueConsumer):
html_tree = HTMLParser().parseFragment(topic.rendered_html)
# extract the text from all of the HTML elements
extracted_text = ''.join(
[element_text for element_text in html_tree.itertext()])
extracted_text = "".join(
[element_text for element_text in html_tree.itertext()]
)
# sanitize unicode, remove leading/trailing whitespace, etc.
extracted_text = simplify_string(extracted_text)
# create a short excerpt by truncating the simplified string
excerpt = truncate_string(
extracted_text,
length=200,
truncate_at_chars=' ',
)
excerpt = truncate_string(extracted_text, length=200, truncate_at_chars=" ")
topic.content_metadata = {
'word_count': word_count(extracted_text),
'excerpt': excerpt,
"word_count": word_count(extracted_text),
"excerpt": excerpt,
}
def _generate_link_metadata(self, topic: Topic) -> None:
@ -68,13 +63,11 @@ class TopicMetadataGenerator(PgsqlQueueConsumer):
parsed_domain = get_domain_from_url(topic.link)
domain = self.public_suffix_list.get_public_suffix(parsed_domain)
topic.content_metadata = {
'domain': domain,
}
topic.content_metadata = {"domain": domain}
if __name__ == '__main__':
if __name__ == "__main__":
TopicMetadataGenerator(
queue_name='topic_metadata_generator.q',
routing_keys=['topic.created', 'topic.edited'],
queue_name="topic_metadata_generator.q",
routing_keys=["topic.created", "topic.edited"],
).consume_queue()

5
tildes/gunicorn_config.py

@ -6,9 +6,8 @@ from prometheus_client import multiprocess
def child_exit(server, worker): # type: ignore
"""Mark worker processes as dead for Prometheus when the worker exits.
Note that this uses the child_exit hook instead of worker_exit so that
it's handled by the master process (and will still be called if a worker
crashes).
Note that this uses the child_exit hook instead of worker_exit so that it's handled
by the master process (and will still be called if a worker crashes).
"""
# pylint: disable=unused-argument
multiprocess.mark_process_dead(worker.pid)

17
tildes/pylama.ini

@ -3,6 +3,9 @@ linters = mccabe,pycodestyle,pydocstyle,pyflakes,pylint
skip = alembic/*
# ignored checks:
# - D202 - pydocstyle check for blank lines after a function docstring, but
# Black will add one when the first code in the function is another
# function definition.
# - D203 - pydocstyle has two mutually exclusive checks (D203/D211)
# for whether a class docstring should have a blank line before
# it or not. I don't want a blank line, so D203 is disabled.
@ -10,12 +13,18 @@ skip = alembic/*
# time for whether a multi-line docstring's summary line should be
# on the first or second line. I want it to be on the first line,
# so D213 needs to be disabled.
ignore = D203,D213
# - E203 - checks for whitespace around : in slices, but Black adds it
# in some cases.
ignore = D202,D203,D213,E203
[pylama:pycodestyle]
max_line_length = 88
[pylama:pylint]
enable = all
# disabled pylint checks:
# - bad-continuation (Black will handle wrapping lines properly)
# - missing-docstring (already reported by pydocstyle)
# - too-few-public-methods (more annoying than helpful, especially early on)
# - too-many-instance-attributes (overly-picky when models need many)
@ -23,6 +32,7 @@ enable = all
# - locally-enabled (or when checks are (re-)enabled)
# - suppressed-message (...a different message when I disable one?)
disable =
bad-continuation,
missing-docstring,
too-few-public-methods,
too-many-instance-attributes,
@ -40,6 +50,11 @@ ignored-classes = APIv0, venusian.AttachInfo
# - R0201 - method could be a function (for @pre_load-type methods)
ignore = R0201
[pylama:tildes/views/api/web/*]
# ignored checks for web API specifically:
# - C0103 - invalid function names (endpoints can have very long ones)
ignore = C0103
[pylama:tests/*]
# ignored checks for tests specifically:
# - D100 - missing module-level docstrings

5
tildes/requirements-to-freeze.txt

@ -3,9 +3,9 @@ alembic
amqpy
argon2_cffi
astroid==1.5.3 # pylama has issues with pylint 1.8.1
black
bleach
boussole
click==5.1 # boussole needs < 6.0
click
cornice
freezegun
gunicorn
@ -29,6 +29,7 @@ pyramid-tm
pyramid-webassets
pytest
pytest-mock
PyYAML # needs to be installed separately for webassets
SQLAlchemy
SQLAlchemy-Utils
stripe

29
tildes/requirements.txt

@ -1,21 +1,19 @@
ago==0.0.92
alembic==1.0.0
amqpy==0.13.1
argh==0.26.2
appdirs==1.4.3
argon2-cffi==18.1.0
astroid==1.5.3
atomicwrites==1.1.5
attrs==18.1.0
backcall==0.1.0
beautifulsoup4==4.6.0
beautifulsoup4==4.6.3
black==18.6b4
bleach==2.1.3
boussole==1.2.3
certifi==2018.4.16
cffi==1.11.5
chardet==3.0.4
click==5.1
colorama==0.3.9
colorlog==3.1.4
click==6.7
cornice==3.4.0
decorator==4.3.0
freezegun==0.3.10
@ -23,35 +21,32 @@ gunicorn==19.9.0
html5lib==1.0.1
hupper==1.3
idna==2.7
ipython==6.4.0
ipython==6.5.0
ipython-genutils==0.2.0
isort==4.3.4
jedi==0.12.1
Jinja2==2.10
lazy-object-proxy==1.3.1
libsass==0.14.5
Mako==1.0.7
MarkupSafe==1.0
marshmallow==2.15.3
marshmallow==2.15.4
mccabe==0.6.1
more-itertools==4.2.0
more-itertools==4.3.0
mypy==0.620
mypy-extensions==0.3.0
parso==0.3.1
PasteDeploy==1.5.2
pathtools==0.1.2
pexpect==4.6.0
pickleshare==0.7.4
plaster==1.0
plaster-pastedeploy==0.6
pluggy==0.6.0
prometheus-client==0.3.0
pluggy==0.7.1
prometheus-client==0.3.1
prompt-toolkit==1.0.15
psycopg2==2.7.5
ptyprocess==0.6.0
publicsuffix2==2.20160818
py==1.5.4
pyaml==17.12.1
pycodestyle==2.4.0
pycparser==2.18
pydocstyle==2.1.1
@ -68,7 +63,7 @@ pyramid-mako==1.0.2
pyramid-session-redis==1.4.1
pyramid-tm==2.2
pyramid-webassets==0.9
pytest==3.6.3
pytest==3.7.1
pytest-mock==1.10.0
python-dateutil==2.7.3
python-editor==1.0.3
@ -82,9 +77,10 @@ six==1.11.0
snowballstemmer==1.2.1
SQLAlchemy==1.2.10
SQLAlchemy-Utils==0.33.3
stripe==2.0.1
stripe==2.4.0
testing.common.database==2.0.3
testing.redis==1.1.1
toml==0.9.4
traitlets==4.3.2
transaction==2.2.1
translationstring==1.3
@ -92,7 +88,6 @@ typed-ast==1.1.0
urllib3==1.23
venusian==1.1.0
waitress==1.1.0
watchdog==0.8.3
wcwidth==0.1.7
webargs==4.0.0
webassets==0.12.1

104
tildes/scripts/breached_passwords.py

@ -1,17 +1,15 @@
"""Command-line tools for managing a breached-passwords bloom filter.
This tool will help with creating and updating a bloom filter in Redis (using
ReBloom: https://github.com/RedisLabsModules/rebloom) to hold hashes for
passwords that have been revealed through data breaches (to prevent users from
using these passwords here). The dumps are likely primarily sourced from Troy
Hunt's "Pwned Passwords" files:
This tool will help with creating and updating a bloom filter in Redis (using ReBloom:
https://github.com/RedisLabsModules/rebloom) to hold hashes for passwords that have been
revealed through data breaches (to prevent users from using these passwords here). The
dumps are likely primarily sourced from Troy Hunt's "Pwned Passwords" files:
https://haveibeenpwned.com/Passwords
Specifically, the commands in this tool allow building the bloom filter
somewhere else, then the RDB file can be transferred to the production server.
Note that it is expected that a separate redis server instance is running
solely for holding this bloom filter. Replacing the RDB file will result in all
other keys being lost.
Specifically, the commands in this tool allow building the bloom filter somewhere else,
then the RDB file can be transferred to the production server. Note that it is expected
that a separate redis server instance is running solely for holding this bloom filter.
Replacing the RDB file will result in all other keys being lost.
Expected usage of this tool should look something like:
@ -20,8 +18,8 @@ On the machine building the bloom filter:
python breached_passwords.py addhashes pwned-passwords-1.0.txt
python breached_passwords.py addhashes pwned-passwords-update-1.txt
Then the RDB file can simply be transferred to the production server,
overwriting any previous RDB file.
Then the RDB file can simply be transferred to the production server, overwriting any
previous RDB file.
"""
@ -46,11 +44,11 @@ def generate_redis_protocol(*elements: Any) -> str:
Based on the example Ruby code from
https://redis.io/topics/mass-insert#generating-redis-protocol
"""
command = f'*{len(elements)}\r\n'
command = f"*{len(elements)}\r\n"
for element in elements:
element = str(element)
command += f'${len(element)}\r\n{element}\r\n'
command += f"${len(element)}\r\n{element}\r\n"
return command
@ -65,100 +63,92 @@ def validate_init_error_rate(ctx: Any, param: Any, value: Any) -> float:
"""Validate the --error-rate arg for the init command."""
# pylint: disable=unused-argument
if not 0 < value < 1:
raise click.BadParameter('error rate must be a float between 0 and 1')
raise click.BadParameter("error rate must be a float between 0 and 1")
return value
@cli.command(help='Initialize a new empty bloom filter')
@cli.command(help="Initialize a new empty bloom filter")
@click.option(
'--estimate',
"--estimate",
required=True,
type=int,
help='Expected number of passwords that will be added',
help="Expected number of passwords that will be added",
)
@click.option(
'--error-rate',
"--error-rate",
default=0.01,
show_default=True,
help='Bloom filter desired false positive ratio',
help="Bloom filter desired false positive ratio",
callback=validate_init_error_rate,
)
@click.confirmation_option(
prompt='Are you sure you want to clear any existing bloom filter?',
prompt="Are you sure you want to clear any existing bloom filter?"
)
def init(estimate: int, error_rate: float) -> None:
"""Initialize a new bloom filter (destroying any existing one).
It generally shouldn't be necessary to re-init a new bloom filter very
often with this command, only if the previous one was created with too low
of an estimate for number of passwords, or to change to a different false
positive rate. For choosing an estimate value, according to the ReBloom
documentation: "Performance will begin to degrade after adding more items
than this number. The actual degradation will depend on how far the limit
has been exceeded. Performance will degrade linearly as the number of
entries grow exponentially."
It generally shouldn't be necessary to re-init a new bloom filter very often with
this command, only if the previous one was created with too low of an estimate for
number of passwords, or to change to a different false positive rate. For choosing
an estimate value, according to the ReBloom documentation: "Performance will begin
to degrade after adding more items than this number. The actual degradation will
depend on how far the limit has been exceeded. Performance will degrade linearly as
the number of entries grow exponentially."
"""
REDIS.delete(BREACHED_PASSWORDS_BF_KEY)
# BF.RESERVE {key} {error_rate} {size}
REDIS.execute_command(
'BF.RESERVE',
BREACHED_PASSWORDS_BF_KEY,
error_rate,
estimate,
)
REDIS.execute_command("BF.RESERVE", BREACHED_PASSWORDS_BF_KEY, error_rate, estimate)
click.echo(
'Initialized bloom filter with expected size of {:,} and false '
'positive rate of {}%'
.format(estimate, error_rate * 100)
"Initialized bloom filter with expected size of {:,} and false "
"positive rate of {}%".format(estimate, error_rate * 100)
)
@cli.command(help='Add hashes from a file to the bloom filter')
@click.argument('filename', type=click.Path(exists=True, dir_okay=False))
@cli.command(help="Add hashes from a file to the bloom filter")
@click.argument("filename", type=click.Path(exists=True, dir_okay=False))
def addhashes(filename: str) -> None:
"""Add all hashes from a file to the bloom filter.
This uses the method of generating commands in Redis protocol and feeding
them into an instance of `redis-cli --pipe`, as recommended in
This uses the method of generating commands in Redis protocol and feeding them into
an instance of `redis-cli --pipe`, as recommended in
https://redis.io/topics/mass-insert
"""
# make sure the key exists and is a bloom filter
try:
REDIS.execute_command('BF.DEBUG', BREACHED_PASSWORDS_BF_KEY)
REDIS.execute_command("BF.DEBUG", BREACHED_PASSWORDS_BF_KEY)
except ResponseError:
click.echo('Bloom filter is not set up properly - run init first.')
click.echo("Bloom filter is not set up properly - run init first.")
raise click.Abort
# call wc to count the number of lines in the file for the progress bar
click.echo('Determining hash count...')
result = subprocess.run(['wc', '-l', filename], stdout=subprocess.PIPE)
line_count = int(result.stdout.split(b' ')[0])
click.echo("Determining hash count...")
result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
line_count = int(result.stdout.split(b" ")[0])
progress_bar: Any = click.progressbar(length=line_count)
update_interval = 100_000
click.echo('Adding {:,} hashes to bloom filter...'.format(line_count))
click.echo("Adding {:,} hashes to bloom filter...".format(line_count))
redis_pipe = subprocess.Popen(
['redis-cli', '-s', BREACHED_PASSWORDS_REDIS_SOCKET, '--pipe'],
["redis-cli", "-s", BREACHED_PASSWORDS_REDIS_SOCKET, "--pipe"],
stdin=subprocess.PIPE,
stdout=subprocess.DEVNULL,
encoding='utf-8',
encoding="utf-8",
)
for count, line in enumerate(open(filename), start=1):
hashval = line.strip().lower()
# the Pwned Passwords hash lists now have a frequency count for each
# hash, which is separated from the hash with a colon, so we need to
# handle that if it's present
hashval = hashval.split(':')[0]
# the Pwned Passwords hash lists now have a frequency count for each hash, which
# is separated from the hash with a colon, so we need to handle that if it's
# present
hashval = hashval.split(":")[0]
command = generate_redis_protocol(
'BF.ADD', BREACHED_PASSWORDS_BF_KEY, hashval)
command = generate_redis_protocol("BF.ADD", BREACHED_PASSWORDS_BF_KEY, hashval)
redis_pipe.stdin.write(command)
if count % update_interval == 0:
@ -173,5 +163,5 @@ def addhashes(filename: str) -> None:
progress_bar.render_finish()
if __name__ == '__main__':
if __name__ == "__main__":
cli()

67
tildes/scripts/clean_private_data.py

@ -24,8 +24,8 @@ RETENTION_PERIOD = timedelta(days=30)
def clean_all_data(config_path: str) -> None:
"""Clean all private/deleted data.
This should generally be the only function called in most cases, and will
initiate the full cleanup process.
This should generally be the only function called in most cases, and will initiate
the full cleanup process.
"""
db_session = get_session_from_config(config_path)
@ -33,22 +33,17 @@ def clean_all_data(config_path: str) -> None:
cleaner.clean_all()
class DataCleaner():
class DataCleaner:
"""Container class for all methods related to cleaning up old data."""
def __init__(
self,
db_session: Session,
retention_period: timedelta,
) -> None:
def __init__(self, db_session: Session, retention_period: timedelta) -> None:
"""Create a new DataCleaner."""
self.db_session = db_session
self.retention_cutoff = datetime.now() - retention_period
def clean_all(self) -> None:
"""Call all the cleanup functions."""
logging.info(
f'Cleaning up all data (retention cutoff {self.retention_cutoff})')
logging.info(f"Cleaning up all data (retention cutoff {self.retention_cutoff})")
self.delete_old_log_entries()
self.delete_old_topic_visits()
@ -59,8 +54,8 @@ class DataCleaner():
def delete_old_log_entries(self) -> None:
"""Delete all log entries older than the retention cutoff.
Note that this will also delete all entries from the child tables that
inherit from Log (LogTopics, etc.).
Note that this will also delete all entries from the child tables that inherit
from Log (LogTopics, etc.).
"""
deleted = (
self.db_session.query(Log)
@ -68,7 +63,7 @@ class DataCleaner():
.delete(synchronize_session=False)
)
self.db_session.commit()
logging.info(f'Deleted {deleted} old log entries.')
logging.info(f"Deleted {deleted} old log entries.")
def delete_old_topic_visits(self) -> None:
"""Delete all topic visits older than the retention cutoff."""
@ -78,13 +73,13 @@ class DataCleaner():
.delete(synchronize_session=False)
)
self.db_session.commit()
logging.info(f'Deleted {deleted} old topic visits.')
logging.info(f"Deleted {deleted} old topic visits.")
def clean_old_deleted_comments(self) -> None:
"""Clean the data of old deleted comments.
Change the comment's author to the "unknown user" (id 0), and delete
its contents.
Change the comment's author to the "unknown user" (id 0), and delete its
contents.
"""
updated = (
self.db_session.query(Comment)
@ -92,20 +87,19 @@ class DataCleaner():
Comment.deleted_time <= self.retention_cutoff, # type: ignore
Comment.user_id != 0,
)
.update({
'user_id': 0,
'markdown': '',
'rendered_html': '',
}, synchronize_session=False)
.update(
{"user_id": 0, "markdown": "", "rendered_html": ""},
synchronize_session=False,
)
)
self.db_session.commit()
logging.info(f'Cleaned {updated} old deleted comments.')
logging.info(f"Cleaned {updated} old deleted comments.")
def clean_old_deleted_topics(self) -> None:
"""Clean the data of old deleted topics.
Change the topic's author to the "unknown user" (id 0), and delete its
title, contents, tags, and metadata.
Change the topic's author to the "unknown user" (id 0), and delete its title,
contents, tags, and metadata.
"""
updated = (
self.db_session.query(Topic)
@ -113,16 +107,19 @@ class DataCleaner():
Topic.deleted_time <= self.retention_cutoff, # type: ignore
Topic.user_id != 0,
)
.update({
'user_id': 0,
'title': '',
'topic_type': 'TEXT',
'markdown': None,
'rendered_html': None,
'link': None,
'content_metadata': None,
'_tags': [],
}, synchronize_session=False)
.update(
{
"user_id": 0,
"title": "",
"topic_type": "TEXT",
"markdown": None,
"rendered_html": None,
"link": None,
"content_metadata": None,
"_tags": [],
},
synchronize_session=False,
)
)
self.db_session.commit()
logging.info(f'Cleaned {updated} old deleted topics.')
logging.info(f"Cleaned {updated} old deleted topics.")

54
tildes/scripts/initialize_db.py

@ -15,27 +15,24 @@ from tildes.models.log import Log
from tildes.models.user import User
def initialize_db(
config_path: str,
alembic_config_path: Optional[str] = None,
) -> None:
def initialize_db(config_path: str, alembic_config_path: Optional[str] = None) -> None:
"""Load the app config and create the database tables."""
db_session = get_session_from_config(config_path)
engine = db_session.bind
create_tables(engine)
run_sql_scripts_in_dir('sql/init/', engine)
run_sql_scripts_in_dir("sql/init/", engine)
# if an Alembic config file wasn't specified, assume it's alembic.ini in
# the same directory
# if an Alembic config file wasn't specified, assume it's alembic.ini in the same
# directory
if not alembic_config_path:
path = os.path.split(config_path)[0]
alembic_config_path = os.path.join(path, 'alembic.ini')
alembic_config_path = os.path.join(path, "alembic.ini")
# mark current Alembic revision in db so migrations start from this point
alembic_cfg = Config(alembic_config_path)
command.stamp(alembic_cfg, 'head')
command.stamp(alembic_cfg, "head")
def create_tables(connectable: Connectable) -> None:
@ -44,7 +41,8 @@ def create_tables(connectable: Connectable) -> None:
excluded_tables = Log.INHERITED_TABLES
tables = [
table for table in DatabaseModel.metadata.tables.values()
table
for table in DatabaseModel.metadata.tables.values()
if table.name not in excluded_tables
]
DatabaseModel.metadata.create_all(connectable, tables=tables)
@ -53,29 +51,31 @@ def create_tables(connectable: Connectable) -> None:
def run_sql_scripts_in_dir(path: str, engine: Engine) -> None:
"""Run all sql scripts in a directory."""
for root, _, files in os.walk(path):
sql_files = [
filename for filename in files
if filename.endswith('.sql')
]
sql_files = [filename for filename in files if filename.endswith(".sql")]
for sql_file in sql_files:
subprocess.call([
'psql',
'-U', engine.url.username,
'-f', os.path.join(root, sql_file),
engine.url.database,
])
subprocess.call(
[
"psql",
"-U",
engine.url.username,
"-f",
os.path.join(root, sql_file),
engine.url.database,
]
)
def insert_dev_data(config_path: str) -> None:
"""Load the app config and insert some "starter" data for a dev version."""
session = get_session_from_config(config_path)
session.add_all([
User('TestUser', 'password'),
Group(
'testing',
'An automatically created group to use for testing purposes',
),
])
session.add_all(
[
User("TestUser", "password"),
Group(
"testing", "An automatically created group to use for testing purposes"
),
]
)
session.commit()

6
tildes/setup.py

@ -4,11 +4,11 @@ from setuptools import find_packages, setup
setup(
name='tildes',
version='0.1',
name="tildes",
version="0.1",
packages=find_packages(),
entry_points="""
[paste.app_factory]
main = tildes:main
"""
""",
)

107
tildes/tests/conftest.py

@ -18,7 +18,7 @@ from tildes.models.user import User
# include the fixtures defined in fixtures.py
pytest_plugins = ['tests.fixtures']
pytest_plugins = ["tests.fixtures"]
class NestedSessionWrapper(Session):
@ -40,25 +40,25 @@ class NestedSessionWrapper(Session):
super().rollback()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def pyramid_config():
"""Set up the Pyramid environment."""
settings = get_appsettings('development.ini')
settings = get_appsettings("development.ini")
config = testing.setUp(settings=settings)
config.include('tildes.auth')
config.include("tildes.auth")
yield config
testing.tearDown()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def overall_db_session(pyramid_config):
"""Handle setup and teardown of test database for testing session."""
# read the database url from the pyramid INI file, and replace the db name
sqlalchemy_url = pyramid_config.registry.settings['sqlalchemy.url']
sqlalchemy_url = pyramid_config.registry.settings["sqlalchemy.url"]
parsed_url = make_url(sqlalchemy_url)
parsed_url.database = 'tildes_test'
parsed_url.database = "tildes_test"
engine = create_engine(parsed_url)
session_factory = sessionmaker(bind=engine)
@ -66,21 +66,18 @@ def overall_db_session(pyramid_config):
create_tables(session.connection())
# SQL init scripts need to be executed "manually" instead of using psql
# like the normal database init process does, since the tables only exist
# inside this transaction
init_scripts_dir = 'sql/init/'
# SQL init scripts need to be executed "manually" instead of using psql like the
# normal database init process does, since the tables only exist inside this
# transaction
init_scripts_dir = "sql/init/"
for root, _, files in os.walk(init_scripts_dir):
sql_files = [
filename for filename in files
if filename.endswith('.sql')
]
sql_files = [filename for filename in files if filename.endswith(".sql")]
for sql_file in sql_files:
with open(os.path.join(root, sql_file)) as current_file:
session.execute(current_file.read())
# convert the Session to the wrapper class to enforce staying inside
# nested transactions in the test functions
# convert the Session to the wrapper class to enforce staying inside nested
# transactions in the test functions
session.__class__ = NestedSessionWrapper
yield session
@ -90,7 +87,7 @@ def overall_db_session(pyramid_config):
session.rollback()
@fixture(scope='session')
@fixture(scope="session")
def sdb(overall_db_session):
"""Testing-session-level db session with a nested transaction."""
overall_db_session.begin_nested()
@ -100,7 +97,7 @@ def sdb(overall_db_session):
overall_db_session.rollback_all_nested()
@fixture(scope='function')
@fixture(scope="function")
def db(overall_db_session):
"""Function-level db session with a nested transaction."""
overall_db_session.begin_nested()
@ -110,25 +107,23 @@ def db(overall_db_session):
overall_db_session.rollback_all_nested()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def overall_redis_session():
"""Create a session-level connection to a temporary redis server."""
# list of redis modules that need to be loaded (would be much nicer to do
# this automatically somehow, maybe reading from the real redis.conf?)
redis_modules = [
'/opt/redis-cell/libredis_cell.so',
]
# list of redis modules that need to be loaded (would be much nicer to do this
# automatically somehow, maybe reading from the real redis.conf?)
redis_modules = ["/opt/redis-cell/libredis_cell.so"]
with RedisServer() as temp_redis_server:
redis = StrictRedis(**temp_redis_server.dsn())
for module in redis_modules:
redis.execute_command('MODULE LOAD', module)
redis.execute_command("MODULE LOAD", module)
yield redis
@fixture(scope='function')
@fixture(scope="function")
def redis(overall_redis_session):
"""Create a function-level redis connection that wipes the db after use."""
yield overall_redis_session
@ -136,87 +131,87 @@ def redis(overall_redis_session):
overall_redis_session.flushdb()
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def session_user(sdb):
"""Create a user named 'SessionUser' in the db for test session."""
# note that some tests may depend on this username/password having these
# specific values, so make sure to search for and update those tests if you
# change the username or password for any reason
user = User('SessionUser', 'session user password')
# note that some tests may depend on this username/password having these specific
# values, so make sure to search for and update those tests if you change the
# username or password for any reason
user = User("SessionUser", "session user password")
sdb.add(user)
sdb.commit()
yield user
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def session_user2(sdb):
"""Create a second user named 'OtherUser' in the db for test session.
This is useful for cases where two different users are needed, such as
when testing private messages.
This is useful for cases where two different users are needed, such as when testing
private messages.
"""
user = User('OtherUser', 'other user password')
user = User("OtherUser", "other user password")
sdb.add(user)
sdb.commit()
yield user
@fixture(scope='session', autouse=True)
@fixture(scope="session", autouse=True)
def session_group(sdb):
"""Create a group named 'sessiongroup' in the db for test session."""
group = Group('sessiongroup')
group = Group("sessiongroup")
sdb.add(group)
sdb.commit()
yield group
@fixture(scope='session')
@fixture(scope="session")
def base_app(overall_redis_session, sdb):
"""Configure a base WSGI app that webtest can create TestApps based on."""
testing_app = get_app('development.ini')
testing_app = get_app("development.ini")
# replace the redis connection used by the redis-sessions library with a
# connection to the temporary server for this test session
# replace the redis connection used by the redis-sessions library with a connection
# to the temporary server for this test session
testing_app.app.registry._redis_sessions = overall_redis_session
def redis_factory(request):
# pylint: disable=unused-argument
return overall_redis_session
testing_app.app.registry['redis_connection_factory'] = redis_factory
# replace the session factory function with one that will return the
# testing db session (inside a nested transaction)
testing_app.app.registry["redis_connection_factory"] = redis_factory
# replace the session factory function with one that will return the testing db
# session (inside a nested transaction)
def session_factory():
return sdb
testing_app.app.registry['db_session_factory'] = session_factory
testing_app.app.registry["db_session_factory"] = session_factory
yield testing_app
@fixture(scope='session')
@fixture(scope="session")
def webtest(base_app):
"""Create a webtest TestApp and log in as the SessionUser account in it."""
# create the TestApp - note that specifying wsgi.url_scheme is necessary
# so that the secure cookies from the session library will work
# create the TestApp - note that specifying wsgi.url_scheme is necessary so that the
# secure cookies from the session library will work
app = TestApp(
base_app,
extra_environ={'wsgi.url_scheme': 'https'},
cookiejar=CookieJar(),
base_app, extra_environ={"wsgi.url_scheme": "https"}, cookiejar=CookieJar()
)
# fetch the login page, fill in the form, and submit it (sets the cookie)
login_page = app.get('/login')
login_page.form['username'] = 'SessionUser'
login_page.form['password'] = 'session user password'
login_page = app.get("/login")
login_page.form["username"] = "SessionUser"
login_page.form["password"] = "session user password"
login_page.form.submit()
yield app
@fixture(scope='session')
@fixture(scope="session")
def webtest_loggedout(base_app):
"""Create a logged-out webtest TestApp (no cookies retained)."""
yield TestApp(base_app)

8
tildes/tests/fixtures.py

@ -8,7 +8,8 @@ from tildes.models.topic import Topic
def text_topic(db, session_group, session_user):
"""Create a text topic, delete it as teardown (including comments)."""
new_topic = Topic.create_text_topic(
session_group, session_user, 'A Text Topic', 'the text')
session_group, session_user, "A Text Topic", "the text"
)
db.add(new_topic)
db.commit()
@ -23,7 +24,8 @@ def text_topic(db, session_group, session_user):
def link_topic(db, session_group, session_user):
"""Create a link topic, delete it as teardown (including comments)."""
new_topic = Topic.create_link_topic(
session_group, session_user, 'A Link Topic', 'http://example.com')
session_group, session_user, "A Link Topic", "http://example.com"
)
db.add(new_topic)
db.commit()
@ -43,7 +45,7 @@ def topic(text_topic):
@fixture
def comment(db, session_user, topic):
"""Create a comment in the database, delete it as teardown."""
new_comment = Comment(topic, session_user, 'A comment')
new_comment = Comment(topic, session_user, "A comment")
db.add(new_comment)
db.commit()

64
tildes/tests/test_comment.py

@ -1,81 +1,73 @@
from datetime import timedelta
from freezegun import freeze_time
from pyramid.security import (
Authenticated,
Everyone,
principals_allowed_by_permission,
)
from pyramid.security import Authenticated, Everyone, principals_allowed_by_permission
from tildes.enums import CommentSortOption
from tildes.lib.datetime import utc_now
from tildes.models.comment import (
Comment,
CommentTree,
EDIT_GRACE_PERIOD,
)
from tildes.models.comment import Comment, CommentTree, EDIT_GRACE_PERIOD
from tildes.schemas.comment import CommentSchema
from tildes.schemas.fields import Markdown
def test_comment_creation_validates_schema(mocker, session_user, topic):
"""Ensure that comment creation goes through schema validation."""
mocker.spy(CommentSchema, 'load')
mocker.spy(CommentSchema, "load")
Comment(topic, session_user, 'A test comment')
Comment(topic, session_user, "A test comment")
call_args = CommentSchema.load.call_args[0]
assert {'markdown': 'A test comment'} in call_args
assert {"markdown": "A test comment"} in call_args
def test_comment_creation_uses_markdown_field(mocker, session_user, topic):
"""Ensure the Markdown field class is validating new comments."""
mocker.spy(Markdown, '_validate')
mocker.spy(Markdown, "_validate")
Comment(topic, session_user, 'A test comment')
Comment(topic, session_user, "A test comment")
assert Markdown._validate.called
def test_comment_edit_uses_markdown_field(mocker, comment):
"""Ensure editing a comment is validated by the Markdown field class."""
mocker.spy(Markdown, '_validate')
mocker.spy(Markdown, "_validate")
comment.markdown = 'Some new text after edit'
comment.markdown = "Some new text after edit"
assert Markdown._validate.called
def test_edit_markdown_updates_html(comment):
"""Ensure editing a comment works and the markdown and HTML update."""
comment.markdown = 'Updated comment'
assert 'Updated' in comment.markdown
assert 'Updated' in comment.rendered_html
comment.markdown = "Updated comment"
assert "Updated" in comment.markdown
assert "Updated" in comment.rendered_html
def test_comment_viewing_permission(comment):
"""Ensure that anyone can view a comment by default."""
assert Everyone in principals_allowed_by_permission(comment, 'view')
assert Everyone in principals_allowed_by_permission(comment, "view")
def test_comment_editing_permission(comment):
"""Ensure that only the comment's author can edit it."""
principals = principals_allowed_by_permission(comment, 'edit')
principals = principals_allowed_by_permission(comment, "edit")
assert principals == {comment.user_id}
def test_comment_deleting_permission(comment):
"""Ensure that only the comment's author can delete it."""
principals = principals_allowed_by_permission(comment, 'delete')
principals = principals_allowed_by_permission(comment, "delete")
assert principals == {comment.user_id}
def test_comment_replying_permission(comment):
"""Ensure that any authenticated user can reply to a comment."""
assert Authenticated in principals_allowed_by_permission(comment, 'reply')
assert Authenticated in principals_allowed_by_permission(comment, "reply")
def test_comment_reply_locked_thread_permission(comment):
"""Ensure that only admins can reply in locked threads."""
comment.topic.is_locked = True
assert principals_allowed_by_permission(comment, 'reply') == {'admin'}
assert principals_allowed_by_permission(comment, "reply") == {"admin"}
def test_deleted_comment_permissions_removed(comment):
@ -90,8 +82,8 @@ def test_deleted_comment_permissions_removed(comment):
def test_removed_comment_view_permission(comment):
"""Ensure a removed comment can only be viewed by its author and admins."""
comment.is_removed = True
principals = principals_allowed_by_permission(comment, 'view')
assert principals == {'admin', comment.user_id}
principals = principals_allowed_by_permission(comment, "view")
assert principals == {"admin", comment.user_id}
def test_edit_grace_period(comment):
@ -100,7 +92,7 @@ def test_edit_grace_period(comment):
edit_time = comment.created_time + EDIT_GRACE_PERIOD - one_sec
with freeze_time(edit_time):
comment.markdown = 'some new markdown'
comment.markdown = "some new markdown"
assert not comment.last_edited_time
@ -111,7 +103,7 @@ def test_edit_after_grace_period(comment):
edit_time = comment.created_time + EDIT_GRACE_PERIOD + one_sec
with freeze_time(edit_time):
comment.markdown = 'some new markdown'
comment.markdown = "some new markdown"
assert comment.last_edited_time == utc_now()
@ -123,7 +115,7 @@ def test_multiple_edits_update_time(comment):
for minutes in range(0, 4):
edit_time = initial_time + timedelta(minutes=minutes)
with freeze_time(edit_time):
comment.markdown = f'edit #{minutes}'
comment.markdown = f"edit #{minutes}"
assert comment.last_edited_time == utc_now()
@ -134,8 +126,8 @@ def test_comment_tree(db, topic, session_user):
sort = CommentSortOption.POSTED
# add two root comments
root = Comment(topic, session_user, 'root')
root2 = Comment(topic, session_user, 'root2')
root = Comment(topic, session_user, "root")
root2 = Comment(topic, session_user, "root2")
all_comments.extend([root, root2])
db.add_all(all_comments)
db.commit()
@ -151,8 +143,8 @@ def test_comment_tree(db, topic, session_user):
assert tree == [root]
# add two replies to the remaining root comment
child = Comment(topic, session_user, '1', parent_comment=root)
child2 = Comment(topic, session_user, '2', parent_comment=root)
child = Comment(topic, session_user, "1", parent_comment=root)
child2 = Comment(topic, session_user, "2", parent_comment=root)
all_comments.extend([child, child2])
db.add_all(all_comments)
db.commit()
@ -165,8 +157,8 @@ def test_comment_tree(db, topic, session_user):
assert child2.replies == []
# add two more replies to the second depth-1 comment
subchild = Comment(topic, session_user, '2a', parent_comment=child2)
subchild2 = Comment(topic, session_user, '2b', parent_comment=child2)
subchild = Comment(topic, session_user, "2a", parent_comment=child2)
subchild2 = Comment(topic, session_user, "2b", parent_comment=child2)
all_comments.extend([subchild, subchild2])
db.add_all(all_comments)
db.commit()

76
tildes/tests/test_comment_user_mentions.py

@ -3,10 +3,7 @@ from pytest import fixture
from sqlalchemy import and_
from tildes.enums import CommentNotificationType
from tildes.models.comment import (
Comment,
CommentNotification,
)
from tildes.models.comment import Comment, CommentNotification
from tildes.models.topic import Topic
from tildes.models.user import User
@ -15,8 +12,8 @@ from tildes.models.user import User
def user_list(db):
"""Create several users."""
users = []
for name in ['foo', 'bar', 'baz']:
user = User(name, 'password')
for name in ["foo", "bar", "baz"]:
user = User(name, "password")
users.append(user)
db.add(user)
db.commit()
@ -30,44 +27,40 @@ def user_list(db):
def test_get_mentions_for_comment(db, user_list, comment):
"""Test that notifications are generated and returned."""
comment.markdown = '@foo @bar. @baz!'
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
comment.markdown = "@foo @bar. @baz!"
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 3
for index, user in enumerate(user_list):
assert mentions[index].user == user
def test_mention_filtering_parent_comment(
mocker, db, topic, user_list):
def test_mention_filtering_parent_comment(mocker, db, topic, user_list):
"""Test notification filtering for parent comments."""
parent_comment = Comment(topic, user_list[0], 'Comment content.')
parent_comment = Comment(topic, user_list[0], "Comment content.")
parent_comment.user_id = user_list[0].user_id
comment = mocker.Mock(
user_id=user_list[1].user_id,
markdown=f'@{user_list[0].username}',
markdown=f"@{user_list[0].username}",
parent_comment=parent_comment,
)
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert not mentions
def test_mention_filtering_self_mention(db, user_list, topic):
"""Test notification filtering for self-mentions."""
comment = Comment(topic, user_list[0], f'@{user_list[0]}')
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
comment = Comment(topic, user_list[0], f"@{user_list[0]}")
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert not mentions
def test_mention_filtering_top_level(db, user_list, session_group):
"""Test notification filtering for top-level comments."""
topic = Topic.create_text_topic(
session_group, user_list[0], 'Some title', 'some text')
comment = Comment(topic, user_list[1], f'@{user_list[0].username}')
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
session_group, user_list[0], "Some title", "some text"
)
comment = Comment(topic, user_list[1], f"@{user_list[0].username}")
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert not mentions
@ -75,43 +68,42 @@ def test_prevent_duplicate_notifications(db, user_list, topic):
"""Test that notifications are cleaned up for edits.
Flow:
1. A comment is created by user A that mentions user B. Notifications
are generated, and yield A mentioning B.
1. A comment is created by user A that mentions user B. Notifications are
generated, and yield A mentioning B.
2. The comment is edited to mention C and not B.
3. The comment is edited to mention B and C.
4. The comment is deleted.
"""
# 1
comment = Comment(topic, user_list[0], f'@{user_list[1].username}')
comment = Comment(topic, user_list[0], f"@{user_list[1].username}")
db.add(comment)
db.commit()
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 1
assert mentions[0].user == user_list[1]
db.add_all(mentions)
db.commit()
# 2
comment.markdown = f'@{user_list[2].username}'
comment.markdown = f"@{user_list[2].username}"
db.commit()
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 1
to_delete, to_add = CommentNotification.prevent_duplicate_notifications(
db, comment, mentions)
db, comment, mentions
)
assert len(to_delete) == 1
assert mentions == to_add
assert to_delete[0].user.username == user_list[1].username
# 3
comment.markdown = f'@{user_list[1].username} @{user_list[2].username}'
comment.markdown = f"@{user_list[1].username} @{user_list[2].username}"
db.commit()
mentions = CommentNotification.get_mentions_for_comment(
db, comment)
mentions = CommentNotification.get_mentions_for_comment(db, comment)
assert len(mentions) == 2
to_delete, to_add = CommentNotification.prevent_duplicate_notifications(
db, comment, mentions)
db, comment, mentions
)
assert not to_delete
assert len(to_add) == 1
@ -120,9 +112,13 @@ def test_prevent_duplicate_notifications(db, user_list, topic):
db.commit()
notifications = (
db.query(CommentNotification.user_id)
.filter(and_(
CommentNotification.comment_id == comment.comment_id,
CommentNotification.notification_type ==
CommentNotificationType.USER_MENTION,
)).all())
.filter(
and_(
CommentNotification.comment_id == comment.comment_id,
CommentNotification.notification_type
== CommentNotificationType.USER_MENTION,
)
)
.all()
)
assert not notifications

14
tildes/tests/test_datetime.py

@ -20,40 +20,40 @@ def test_utc_now_accurate():
def test_descriptive_timedelta_basic():
"""Ensure a simple descriptive timedelta works correctly."""
test_time = utc_now() - timedelta(hours=3)
assert descriptive_timedelta(test_time) == '3 hours ago'
assert descriptive_timedelta(test_time) == "3 hours ago"
def test_more_precise_longer_descriptive_timedelta():
"""Ensure a longer time period gets the extra precision level."""
test_time = utc_now() - timedelta(days=2, hours=5)
assert descriptive_timedelta(test_time) == '2 days, 5 hours ago'
assert descriptive_timedelta(test_time) == "2 days, 5 hours ago"
def test_no_small_precision_descriptive_timedelta():
"""Ensure the extra precision doesn't apply to small units."""
test_time = utc_now() - timedelta(days=6, minutes=10)
assert descriptive_timedelta(test_time) == '6 days ago'
assert descriptive_timedelta(test_time) == "6 days ago"
def test_single_precision_below_an_hour():
"""Ensure times under an hour only have one precision level."""
test_time = utc_now() - timedelta(minutes=59, seconds=59)
assert descriptive_timedelta(test_time) == '59 minutes ago'
assert descriptive_timedelta(test_time) == "59 minutes ago"
def test_more_precision_above_an_hour():
"""Ensure the second precision level gets added just above an hour."""
test_time = utc_now() - timedelta(hours=1, minutes=1)
assert descriptive_timedelta(test_time) == '1 hour, 1 minute ago'
assert descriptive_timedelta(test_time) == "1 hour, 1 minute ago"
def test_subsecond_descriptive_timedelta():
"""Ensure time less than a second returns the special phrase."""
test_time = utc_now() - timedelta(microseconds=100)
assert descriptive_timedelta(test_time) == 'a moment ago'
assert descriptive_timedelta(test_time) == "a moment ago"
def test_above_second_descriptive_timedelta():
"""Ensure it starts describing time in seconds above 1 second."""
test_time = utc_now() - timedelta(seconds=1, microseconds=100)
assert descriptive_timedelta(test_time) == '1 second ago'
assert descriptive_timedelta(test_time) == "1 second ago"

41
tildes/tests/test_group.py

@ -3,46 +3,43 @@ from sqlalchemy.exc import IntegrityError
from tildes.models.group import Group
from tildes.schemas.fields import Ltree, SimpleString
from tildes.schemas.group import (
GroupSchema,
is_valid_group_path,
)
from tildes.schemas.group import GroupSchema, is_valid_group_path
def test_empty_path_invalid():
"""Ensure empty group path is invalid."""
assert not is_valid_group_path('')
assert not is_valid_group_path("")
def test_typical_path_valid():
"""Ensure a "normal-looking" group path is valid."""
assert is_valid_group_path('games.video.nintendo_3ds')
assert is_valid_group_path("games.video.nintendo_3ds")
def test_start_with_underscore():
"""Ensure you can't start a path with an underscore."""
assert not is_valid_group_path('_x.y.z')
assert not is_valid_group_path("_x.y.z")
def test_middle_element_start_with_underscore():
"""Ensure a middle path element can't start with an underscore."""
assert not is_valid_group_path('x._y.z')
assert not is_valid_group_path("x._y.z")
def test_end_with_underscore():
"""Ensure you can't end a path with an underscore."""
assert not is_valid_group_path('x.y.z_')
assert not is_valid_group_path("x.y.z_")
def test_middle_element_end_with_underscore():
"""Ensure a middle path element can't end with an underscore."""
assert not is_valid_group_path('x.y_.z')
assert not is_valid_group_path("x.y_.z")
def test_uppercase_letters_invalid():
"""Ensure a group path can't contain uppercase chars."""
assert is_valid_group_path('comp.lang.c')
assert not is_valid_group_path('comp.lang.C')
assert is_valid_group_path("comp.lang.c")
assert not is_valid_group_path("comp.lang.C")
def test_paths_with_invalid_characters():
@ -50,34 +47,34 @@ def test_paths_with_invalid_characters():
invalid_chars = ' ~!@#$%^&*()+={}[]|\\:;"<>,?/'
for char in invalid_chars:
path = f'abc{char}xyz'
path = f"abc{char}xyz"
assert not is_valid_group_path(path)
def test_paths_with_unicode_characters():
"""Ensure that paths can't use unicode chars (not comprehensive)."""
for path in ('games.pokémon', 'ポケモン', 'bites.møøse'):
for path in ("games.pokémon", "ポケモン", "bites.møøse"):
assert not is_valid_group_path(path)
def test_creation_validates_schema(mocker):
"""Ensure that group creation goes through expected validation."""
mocker.spy(GroupSchema, 'load')
mocker.spy(Ltree, '_validate')
mocker.spy(SimpleString, '_validate')
mocker.spy(GroupSchema, "load")
mocker.spy(Ltree, "_validate")
mocker.spy(SimpleString, "_validate")
Group('testing', 'with a short description')
Group("testing", "with a short description")
assert GroupSchema.load.called
assert Ltree._validate.call_args[0][1] == 'testing'
assert SimpleString._validate.call_args[0][1] == 'with a short description'
assert Ltree._validate.call_args[0][1] == "testing"
assert SimpleString._validate.call_args[0][1] == "with a short description"
def test_duplicate_group(db):
"""Ensure groups with duplicate paths can't be created."""
original = Group('twins')
original = Group("twins")
db.add(original)
duplicate = Group('twins')
duplicate = Group("twins")
db.add(duplicate)
with raises(IntegrityError):

10
tildes/tests/test_id.py

@ -7,12 +7,12 @@ from tildes.lib.id import id_to_id36, id36_to_id
def test_id_to_id36():
"""Make sure an ID->ID36 conversion is correct."""
assert id_to_id36(571049189) == '9fzkdh'
assert id_to_id36(571049189) == "9fzkdh"
def test_id36_to_id():
"""Make sure an ID36->ID conversion is correct."""
assert id36_to_id('x48l4z') == 2002502915
assert id36_to_id("x48l4z") == 2002502915
def test_reversed_conversion_from_id():
@ -23,7 +23,7 @@ def test_reversed_conversion_from_id():
def test_reversed_conversion_from_id36():
"""Make sure an ID36->ID->ID36 conversion returns to original value."""
original = 'h2l4pe'
original = "h2l4pe"
assert id_to_id36(id36_to_id(original)) == original
@ -36,7 +36,7 @@ def test_zero_id_conversion_blocked():
def test_zero_id36_conversion_blocked():
"""Ensure the ID36 conversion function doesn't accept zero."""
with raises(ValueError):
id36_to_id('0')
id36_to_id("0")
def test_negative_id_conversion_blocked():
@ -48,4 +48,4 @@ def test_negative_id_conversion_blocked():
def test_negative_id36_conversion_blocked():
"""Ensure the ID36 conversion function doesn't accept negative numbers."""
with raises(ValueError):
id36_to_id('-1')
id36_to_id("-1")

167
tildes/tests/test_markdown.py

@ -3,29 +3,29 @@ from tildes.lib.markdown import convert_markdown_to_safe_html
def test_script_tag_escaped():
"""Ensure that a <script> tag can't get through."""
markdown = '<script>alert()</script>'
markdown = "<script>alert()</script>"
sanitized = convert_markdown_to_safe_html(markdown)
assert '<script>' not in sanitized
assert "<script>" not in sanitized
def test_basic_markdown_unescaped():
"""Test that some common markdown comes through without escaping."""
markdown = (
"# Here's a header.\n\n"
'This chunk of text has **some bold** and *some italics* in it.\n\n'
'A separator will be below this paragraph.\n\n'
'---\n\n'
'* An unordered list item\n'
'* Another list item\n\n'
'> This should be a quote.\n\n'
' And a code block\n\n'
'Also some `inline code` and [a link](http://example.com).\n\n'
'And a manual break \nbetween lines.\n\n'
"This chunk of text has **some bold** and *some italics* in it.\n\n"
"A separator will be below this paragraph.\n\n"
"---\n\n"
"* An unordered list item\n"
"* Another list item\n\n"
"> This should be a quote.\n\n"
" And a code block\n\n"
"Also some `inline code` and [a link](http://example.com).\n\n"
"And a manual break \nbetween lines.\n\n"
)
sanitized = convert_markdown_to_safe_html(markdown)
assert '&lt;' not in sanitized
assert "&lt;" not in sanitized
def test_strikethrough():
@ -33,23 +33,23 @@ def test_strikethrough():
markdown = "This ~should not~ should work"
processed = convert_markdown_to_safe_html(markdown)
assert '<del>' in processed
assert '<a' not in processed
assert "<del>" in processed
assert "<a" not in processed
def test_table():
"""Ensure table markdown works."""
markdown = (
'|Header 1|Header 2|Header 3|\n'
'|--------|-------:|:------:|\n'
'|1 - 1 |1 - 2 |1 - 3 |\n'
'|2 - 1|2 - 2|2 - 3|\n'
"|Header 1|Header 2|Header 3|\n"
"|--------|-------:|:------:|\n"
"|1 - 1 |1 - 2 |1 - 3 |\n"
"|2 - 1|2 - 2|2 - 3|\n"
)
processed = convert_markdown_to_safe_html(markdown)
assert '<table>' in processed
assert processed.count('<tr') == 3
assert processed.count('<td') == 6
assert "<table>" in processed
assert processed.count("<tr") == 3
assert processed.count("<td") == 6
assert 'align="right"' in processed
assert 'align="center"' in processed
@ -57,35 +57,35 @@ def test_table():
def test_deliberate_ordered_list():
"""Ensure a "deliberate" ordered list works."""
markdown = (
'My first line of text.\n\n'
'1. I want\n'
'2. An ordered\n'
'3. List here\n\n'
'A final line.'
"My first line of text.\n\n"
"1. I want\n"
"2. An ordered\n"
"3. List here\n\n"
"A final line."
)
html = convert_markdown_to_safe_html(markdown)
assert '<ol>' in html
assert "<ol>" in html
def test_accidental_ordered_list():
"""Ensure a common "accidental" ordered list gets escaped."""
markdown = (
'What year did this happen?\n\n'
'1975. It was a long time ago.\n\n'
'But I remember it like it was yesterday.'
"What year did this happen?\n\n"
"1975. It was a long time ago.\n\n"
"But I remember it like it was yesterday."
)
html = convert_markdown_to_safe_html(markdown)
assert '<ol' not in html
assert "<ol" not in html
def test_existing_newline_not_doubled():
"""Ensure that the standard markdown line break doesn't result in two."""
markdown = 'A deliberate line \nbreak'
markdown = "A deliberate line \nbreak"
html = convert_markdown_to_safe_html(markdown)
assert html.count('<br') == 1
assert html.count("<br") == 1
def test_newline_creates_br():
@ -93,36 +93,31 @@ def test_newline_creates_br():
markdown = "This wouldn't\nnormally work"
html = convert_markdown_to_safe_html(markdown)
assert '<br>' in html
assert "<br>" in html
def test_multiple_newlines():
"""Ensure markdown with multiple newlines has expected result."""
lines = ["One.", "Two.", "Three.", "Four.", "Five."]
markdown = '\n'.join(lines)
markdown = "\n".join(lines)
html = convert_markdown_to_safe_html(markdown)
assert html.count('<br') == len(lines) - 1
assert html.count("<br") == len(lines) - 1
assert all(line in html for line in lines)
def test_newline_in_code_block():
"""Ensure newlines in code blocks don't add a <br>."""
markdown = (
'```\n'
'def testing_for_newlines():\n'
' pass\n'
'```\n'
)
markdown = "```\ndef testing_for_newlines():\n pass\n```\n"
html = convert_markdown_to_safe_html(markdown)
assert '<br' not in html
assert "<br" not in html
def test_http_link_linkified():
"""Ensure that writing an http url results in a link."""
markdown = 'I like http://example.com as an example.'
markdown = "I like http://example.com as an example."
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com">' in processed
@ -130,7 +125,7 @@ def test_http_link_linkified():
def test_https_link_linkified():
"""Ensure that writing an https url results in a link."""
markdown = 'Also, https://example.com should work.'
markdown = "Also, https://example.com should work."
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="https://example.com">' in processed
@ -138,7 +133,7 @@ def test_https_link_linkified():
def test_bare_domain_linkified():
"""Ensure that a bare domain results in a link."""
markdown = 'I can just write example.com too.'
markdown = "I can just write example.com too."
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com">' in processed
@ -146,7 +141,7 @@ def test_bare_domain_linkified():
def test_link_with_path_linkified():
"""Ensure a link with a path results in a link."""
markdown = 'So http://example.com/a/b_c_d/e too?'
markdown = "So http://example.com/a/b_c_d/e too?"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com/a/b_c_d/e">' in processed
@ -154,7 +149,7 @@ def test_link_with_path_linkified():
def test_link_with_query_string_linkified():
"""Ensure a link with a query string results in a link."""
markdown = 'Also http://example.com?something=true works?'
markdown = "Also http://example.com?something=true works?"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="http://example.com?something=true">' in processed
@ -162,21 +157,21 @@ def test_link_with_query_string_linkified():
def test_email_address_not_linkified():
"""Ensure that an email address does not get linkified."""
markdown = 'Please contact somebody@example.com about that.'
markdown = "Please contact somebody@example.com about that."
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_other_protocol_urls_not_linkified():
"""Ensure some other protocols don't linkify (not comprehensive)."""
protocols = ('data', 'ftp', 'irc', 'mailto', 'news', 'ssh', 'xmpp')
protocols = ("data", "ftp", "irc", "mailto", "news", "ssh", "xmpp")
for protocol in protocols:
markdown = f'Testing {protocol}://example.com for linking'
markdown = f"Testing {protocol}://example.com for linking"
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_html_attr_whitelist_violation():
@ -187,23 +182,23 @@ def test_html_attr_whitelist_violation():
)
processed = convert_markdown_to_safe_html(markdown)
assert processed == '<p>test link</p>\n'
assert processed == "<p>test link</p>\n"
def test_a_href_protocol_violation():
"""Ensure link to other protocols removes the link (not comprehensive)."""
protocols = ('data', 'ftp', 'irc', 'mailto', 'news', 'ssh', 'xmpp')
protocols = ("data", "ftp", "irc", "mailto", "news", "ssh", "xmpp")
for protocol in protocols:
markdown = f'Testing [a link]({protocol}://example.com) for linking'
markdown = f"Testing [a link]({protocol}://example.com) for linking"
processed = convert_markdown_to_safe_html(markdown)
assert 'href' not in processed
assert "href" not in processed
def test_group_reference_linkified():
"""Ensure a simple group reference gets linkified."""
markdown = 'Yeah, I saw that in ~books.fantasy yesterday.'
markdown = "Yeah, I saw that in ~books.fantasy yesterday."
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/~books.fantasy">' in processed
@ -212,14 +207,14 @@ def test_group_reference_linkified():
def test_multiple_group_references_linkified():
"""Ensure multiple group references are all linkified."""
markdown = (
'I like to keep an eye on:\n\n'
'* ~music.metal\n'
'* ~music.metal.progressive\n'
'* ~music.post_rock\n'
"I like to keep an eye on:\n\n"
"* ~music.metal\n"
"* ~music.metal.progressive\n"
"* ~music.post_rock\n"
)
processed = convert_markdown_to_safe_html(markdown)
assert processed.count('<a') == 3
assert processed.count("<a") == 3
def test_invalid_group_reference_not_linkified():
@ -230,20 +225,20 @@ def test_invalid_group_reference_not_linkified():
)
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_approximately_tilde_not_linkified():
"""Ensure a tilde in front of a number doesn't linkify."""
markdown = 'Mix in ~2 cups of flour and ~1.5 tbsp of sugar.'
markdown = "Mix in ~2 cups of flour and ~1.5 tbsp of sugar."
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_uppercase_group_ref_links_correctly():
"""Ensure using uppercase in a group ref works but links correctly."""
markdown = 'That was in ~Music.Metal.Progressive'
markdown = "That was in ~Music.Metal.Progressive"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/~music.metal.progressive' in processed
@ -260,29 +255,29 @@ def test_existing_link_group_ref_not_replaced():
def test_group_ref_inside_link_not_replaced():
"""Ensure a group ref inside a longer link doesn't get re-linked."""
markdown = 'Found [this band from a ~music.punk post](http://whitelung.ca)'
markdown = "Found [this band from a ~music.punk post](http://whitelung.ca)"
processed = convert_markdown_to_safe_html(markdown)
assert processed.count('<a') == 1
assert processed.count("<a") == 1
assert 'href="/~music.punk"' not in processed
def test_group_ref_inside_pre_ignored():
"""Ensure a group ref inside a <pre> tag doesn't get linked."""
markdown = (
'```\n'
'# This is a code block\n'
'# I found this code on ~comp.lang.python\n'
'```\n'
"```\n"
"# This is a code block\n"
"# I found this code on ~comp.lang.python\n"
"```\n"
)
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_group_ref_inside_other_tags_linkified():
"""Ensure a group ref inside non-ignored tags gets linked."""
markdown = '> Here is **a ~group.reference inside** other stuff'
markdown = "> Here is **a ~group.reference inside** other stuff"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/~group.reference">' in processed
@ -290,7 +285,7 @@ def test_group_ref_inside_other_tags_linkified():
def test_username_reference_linkified():
"""Ensure a basic username reference gets linkified."""
markdown = 'Hey @SomeUser, what do you think of this?'
markdown = "Hey @SomeUser, what do you think of this?"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">@SomeUser</a>' in processed
@ -298,7 +293,7 @@ def test_username_reference_linkified():
def test_u_style_username_ref_linked():
"""Ensure a /u/username reference gets linkified."""
markdown = 'Hey /u/SomeUser, what do you think of this?'
markdown = "Hey /u/SomeUser, what do you think of this?"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">/u/SomeUser</a>' in processed
@ -306,7 +301,7 @@ def test_u_style_username_ref_linked():
def test_u_alt_style_username_ref_linked():
"""Ensure a u/username reference gets linkified."""
markdown = 'Hey u/SomeUser, what do you think of this?'
markdown = "Hey u/SomeUser, what do you think of this?"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">u/SomeUser</a>' in processed
@ -314,15 +309,15 @@ def test_u_alt_style_username_ref_linked():
def test_accidental_u_alt_style_not_linked():
"""Ensure an "accidental" u/ usage won't get linked."""
markdown = 'I think those are caribou/reindeer.'
markdown = "I think those are caribou/reindeer."
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_username_and_group_refs_linked():
"""Ensure username and group references together get linkified."""
markdown = '@SomeUser makes the best posts in ~some.group for sure'
markdown = "@SomeUser makes the best posts in ~some.group for sure"
processed = convert_markdown_to_safe_html(markdown)
assert '<a href="/user/SomeUser">@SomeUser</a>' in processed
@ -334,16 +329,12 @@ def test_invalid_username_not_linkified():
markdown = "You can't register a username like @_underscores_"
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed
def test_username_ref_inside_pre_ignored():
"""Ensure a username ref inside a <pre> tag doesn't get linked."""
markdown = (
'```\n'
'# Code blatantly stolen from @HelpfulGuy on StackOverflow\n'
'```\n'
)
markdown = "```\n# Code blatantly stolen from @HelpfulGuy on StackOverflow\n```\n"
processed = convert_markdown_to_safe_html(markdown)
assert '<a' not in processed
assert "<a" not in processed

22
tildes/tests/test_markdown_field.py

@ -12,53 +12,53 @@ class MarkdownFieldTestSchema(Schema):
def validate_string(string):
"""Validate a string against a standard Markdown field."""
MarkdownFieldTestSchema(strict=True).validate({'markdown': string})
MarkdownFieldTestSchema(strict=True).validate({"markdown": string})
def test_normal_text_validates():
"""Ensure some "normal-looking" markdown validates."""
validate_string(
"Here's some markdown.\n\n"
'It has **a bit of bold**, [a link](http://example.com)\n'
'> And `some code` in a blockquote'
"It has **a bit of bold**, [a link](http://example.com)\n"
"> And `some code` in a blockquote"
)
def test_changing_max_length():
"""Ensure changing the max_length argument works."""
test_string = 'Just some text to try'
test_string = "Just some text to try"
# should normally validate
assert Markdown()._validate(test_string) is None
# but fails if you set a too-short max_length
with raises(ValidationError):
Markdown(max_length=len(test_string)-1)._validate(test_string)
Markdown(max_length=len(test_string) - 1)._validate(test_string)
def test_extremely_long_string():
"""Ensure an extremely long string fails validation."""
with raises(ValidationError):
validate_string('A' * 100_000)
validate_string("A" * 100_000)
def test_empty_string():
"""Ensure an empty string fails validation."""
with raises(ValidationError):
validate_string('')
validate_string("")
def test_all_whitespace_string():
"""Ensure a string that's all whitespace chars fails validation."""
with raises(ValidationError):
validate_string(' \n \n\r\n \t ')
validate_string(" \n \n\r\n \t ")
def test_carriage_returns_stripped():
"""Ensure loading a value strips out carriage returns from the string."""
test_string = 'some\r\nreturns\r\nin\nhere'
test_string = "some\r\nreturns\r\nin\nhere"
schema = MarkdownFieldTestSchema(strict=True)
result = schema.load({'markdown': test_string})
result = schema.load({"markdown": test_string})
assert '\r' not in result.data['markdown']
assert "\r" not in result.data["markdown"]

36
tildes/tests/test_messages.py

@ -4,17 +4,15 @@ from pytest import fixture, raises
from tildes.models.message import MessageConversation, MessageReply
from tildes.models.user import User
from tildes.schemas.fields import Markdown, SimpleString
from tildes.schemas.message import (
MessageConversationSchema,
MessageReplySchema,
)
from tildes.schemas.message import MessageConversationSchema, MessageReplySchema
@fixture
def conversation(db, session_user, session_user2):
"""Create a message conversation and delete it as teardown."""
new_conversation = MessageConversation(
session_user, session_user2, 'Subject', 'Message')
session_user, session_user2, "Subject", "Message"
)
db.add(new_conversation)
db.commit()
@ -30,31 +28,31 @@ def conversation(db, session_user, session_user2):
def test_message_conversation_validation(mocker, session_user, session_user2):
"""Ensure a new message conversation goes through expected validation."""
mocker.spy(MessageConversationSchema, 'load')
mocker.spy(SimpleString, '_validate')
mocker.spy(Markdown, '_validate')
mocker.spy(MessageConversationSchema, "load")
mocker.spy(SimpleString, "_validate")
mocker.spy(Markdown, "_validate")
MessageConversation(session_user, session_user2, 'Subject', 'Message')
MessageConversation(session_user, session_user2, "Subject", "Message")
assert MessageConversationSchema.load.called
assert SimpleString._validate.call_args[0][1] == 'Subject'
assert Markdown._validate.call_args[0][1] == 'Message'
assert SimpleString._validate.call_args[0][1] == "Subject"
assert Markdown._validate.call_args[0][1] == "Message"
def test_message_reply_validation(mocker, conversation, session_user2):
"""Ensure a new message reply goes through expected validation."""
mocker.spy(MessageReplySchema, 'load')
mocker.spy(Markdown, '_validate')
mocker.spy(MessageReplySchema, "load")
mocker.spy(Markdown, "_validate")
MessageReply(conversation, session_user2, 'A new reply')
MessageReply(conversation, session_user2, "A new reply")
assert MessageReplySchema.load.called
assert Markdown._validate.call_args[0][1] == 'A new reply'
assert Markdown._validate.call_args[0][1] == "A new reply"
def test_conversation_viewing_permission(conversation):
"""Ensure only the two involved users can view a message conversation."""
principals = principals_allowed_by_permission(conversation, 'view')
principals = principals_allowed_by_permission(conversation, "view")
users = {conversation.sender.user_id, conversation.recipient.user_id}
assert principals == users
@ -70,7 +68,7 @@ def test_conversation_other_user(conversation):
def test_conversation_other_user_invalid(conversation):
"""Ensure that "other user" method fails if the user isn't involved."""
new_user = User('SomeOutsider', 'super amazing password')
new_user = User("SomeOutsider", "super amazing password")
with raises(ValueError):
assert conversation.other_user(new_user)
@ -82,7 +80,7 @@ def test_replies_affect_num_replies(conversation, db):
# add replies and ensure each one increases the count
for num in range(5):
new_reply = MessageReply(conversation, conversation.recipient, 'hi')
new_reply = MessageReply(conversation, conversation.recipient, "hi")
db.add(new_reply)
db.commit()
db.refresh(conversation)
@ -94,7 +92,7 @@ def test_replies_update_activity_time(conversation, db):
assert conversation.last_activity_time == conversation.created_time
for _ in range(5):
new_reply = MessageReply(conversation, conversation.recipient, 'hi')
new_reply = MessageReply(conversation, conversation.recipient, "hi")
db.add(new_reply)
db.commit()

2
tildes/tests/test_metrics.py

@ -9,4 +9,4 @@ def test_all_metric_names_prefixed():
# this is ugly, but seems to be the "generic" way to get the name
metric_name = metric.describe()[0].name
assert metric_name.startswith('tildes_')
assert metric_name.startswith("tildes_")

40
tildes/tests/test_ratelimit.py

@ -24,13 +24,12 @@ def test_all_rate_limited_action_names_unique():
def test_action_with_all_types_disabled():
"""Ensure RateLimitedAction can't have both by_user and by_ip disabled."""
with raises(ValueError):
RateLimitedAction(
'test', timedelta(hours=1), 5, by_user=False, by_ip=False)
RateLimitedAction("test", timedelta(hours=1), 5, by_user=False, by_ip=False)
def test_check_by_user_id_disabled():
"""Ensure non-by_user RateLimitedAction can't be checked by user_id."""
action = RateLimitedAction('test', timedelta(hours=1), 5, by_user=False)
action = RateLimitedAction("test", timedelta(hours=1), 5, by_user=False)
with raises(RateLimitError):
action.check_for_user_id(1)
@ -38,10 +37,10 @@ def test_check_by_user_id_disabled():
def test_check_by_ip_disabled():
"""Ensure non-by_ip RateLimitedAction can't be checked by ip."""
action = RateLimitedAction('test', timedelta(hours=1), 5, by_ip=False)
action = RateLimitedAction("test", timedelta(hours=1), 5, by_ip=False)
with raises(RateLimitError):
action.check_for_ip('123.123.123.123')
action.check_for_ip("123.123.123.123")
def test_simple_rate_limiting_by_user_id(redis):
@ -51,7 +50,8 @@ def test_simple_rate_limiting_by_user_id(redis):
# define an action with max_burst equal to the full limit
action = RateLimitedAction(
'testaction', timedelta(hours=1), limit, max_burst=limit, redis=redis)
"testaction", timedelta(hours=1), limit, max_burst=limit, redis=redis
)
# run the action the full number of times, should all be allowed
for _ in range(limit):
@ -68,7 +68,7 @@ def test_different_user_ids_limited_separately(redis):
limit = 5
user_id = 1
action = RateLimitedAction('test', timedelta(hours=1), limit, redis=redis)
action = RateLimitedAction("test", timedelta(hours=1), limit, redis=redis)
# check the action for the first user_id until it's blocked
result = action.check_for_user_id(user_id)
@ -84,7 +84,7 @@ def test_max_burst_defaults_to_half(redis):
limit = 10
user_id = 1
action = RateLimitedAction('test', timedelta(days=1), limit, redis=redis)
action = RateLimitedAction("test", timedelta(days=1), limit, redis=redis)
# see how many times we can do the action until it gets blocked
count = 0
@ -104,10 +104,11 @@ def test_time_until_retry(redis):
period = timedelta(seconds=60)
limit = 2
# create an action with no burst allowed, which will force the actions to
# be spaced "evenly" across the limit
# create an action with no burst allowed, which will force the actions to be spaced
# "evenly" across the limit
action = RateLimitedAction(
'test', period=period, limit=limit, max_burst=1, redis=redis)
"test", period=period, limit=limit, max_burst=1, redis=redis
)
# first usage should be fine
result = action.check_for_user_id(user_id)
@ -126,7 +127,8 @@ def test_remaining_limit(redis):
# create an action allowing the full limit as a burst
action = RateLimitedAction(
'test', timedelta(days=1), limit, max_burst=limit, redis=redis)
"test", timedelta(days=1), limit, max_burst=limit, redis=redis
)
for count in range(1, limit + 1):
result = action.check_for_user_id(user_id)
@ -136,11 +138,12 @@ def test_remaining_limit(redis):
def test_simple_rate_limiting_by_ip(redis):
"""Ensure simple rate-limiting by IP address is working."""
limit = 5
ip = '123.123.123.123'
ip = "123.123.123.123"
# define an action with max_burst equal to the full limit
action = RateLimitedAction(
'testaction', timedelta(hours=1), limit, max_burst=limit, redis=redis)
"testaction", timedelta(hours=1), limit, max_burst=limit, redis=redis
)
# run the action the full number of times, should all be allowed
for _ in range(limit):
@ -154,9 +157,9 @@ def test_simple_rate_limiting_by_ip(redis):
def test_check_for_ip_invalid_address():
"""Ensure RateLimitedAction.check_for_ip can't take an invalid IP."""
ip = '123.456.789.123'
ip = "123.456.789.123"
action = RateLimitedAction('testaction', timedelta(hours=1), 10)
action = RateLimitedAction("testaction", timedelta(hours=1), 10)
with raises(ValueError):
action.check_for_ip(ip)
@ -164,9 +167,9 @@ def test_check_for_ip_invalid_address():
def test_reset_for_ip_invalid_address():
"""Ensure RateLimitedAction.reset_for_ip can't take an invalid IP."""
ip = '123.456.789.123'
ip = "123.456.789.123"
action = RateLimitedAction('testaction', timedelta(hours=1), 10)
action = RateLimitedAction("testaction", timedelta(hours=1), 10)
with raises(ValueError):
action.reset_for_ip(ip)
@ -224,6 +227,7 @@ def test_merged_results():
def test_merged_all_allowed():
"""Ensure a merged result from all allowed results is also allowed."""
def random_allowed_result():
"""Return a RateLimitResult with is_allowed=True, otherwise random."""
return RateLimitResult(

26
tildes/tests/test_simplestring_field.py

@ -13,43 +13,43 @@ class SimpleStringTestSchema(Schema):
def process_string(string):
"""Deserialize a string with the field and return the "final" version.
This also works for testing validation since .load() will raise a
ValidationError if an invalid string is attempted.
This also works for testing validation since .load() will raise a ValidationError if
an invalid string is attempted.
"""
schema = SimpleStringTestSchema(strict=True)
result = schema.load({'subject': string})
result = schema.load({"subject": string})
return result.data['subject']
return result.data["subject"]
def test_changing_max_length():
"""Ensure changing the max_length argument works."""
test_string = 'Just some text to try'
test_string = "Just some text to try"
# should normally validate
assert SimpleString()._validate(test_string) is None
# but fails if you set a too-short max_length
with raises(ValidationError):
SimpleString(max_length=len(test_string)-1)._validate(test_string)
SimpleString(max_length=len(test_string) - 1)._validate(test_string)
def test_long_string():
"""Ensure a long string fails validation."""
with raises(ValidationError):
process_string('A' * 10_000)
process_string("A" * 10_000)
def test_empty_string():
"""Ensure an empty string fails validation."""
with raises(ValidationError):
process_string('')
process_string("")
def test_all_whitespace_string():
"""Ensure a string that's entirely whitespace fails validation."""
with raises(ValidationError):
process_string('\n \t \r\n ')
process_string("\n \t \r\n ")
def test_normal_string_untouched():
@ -76,11 +76,11 @@ def test_control_chars_removed():
def test_leading_trailing_spaces_removed():
"""Ensure leading/trailing spaces are removed from the string."""
original = ' Centered! '
assert process_string(original) == 'Centered!'
original = " Centered! "
assert process_string(original) == "Centered!"
def test_consecutive_spaces_collapsed():
"""Ensure runs of consecutive spaces are "collapsed" inside the string."""
original = 'I wanted to space this out'
assert process_string(original) == 'I wanted to space this out'
original = "I wanted to space this out"
assert process_string(original) == "I wanted to space this out"

68
tildes/tests/test_string.py

@ -8,71 +8,69 @@ from tildes.lib.string import (
def test_simple_truncate():
"""Ensure a simple truncation by length works correctly."""
truncated = truncate_string('123456789', 5, overflow_str=None)
assert truncated == '12345'
truncated = truncate_string("123456789", 5, overflow_str=None)
assert truncated == "12345"
def test_simple_truncate_with_overflow():
"""Ensure a simple truncation by length with an overflow string works."""
truncated = truncate_string('123456789', 5)
assert truncated == '12...'
truncated = truncate_string("123456789", 5)
assert truncated == "12..."
def test_truncate_same_length():
"""Ensure truncation doesn't happen if the string is the desired length."""
original = '123456789'
original = "123456789"
assert truncate_string(original, len(original)) == original
def test_truncate_at_char():
"""Ensure truncation at a particular character works."""
original = 'asdf zxcv'
assert truncate_string_at_char(original, ' ') == 'asdf'
original = "asdf zxcv"
assert truncate_string_at_char(original, " ") == "asdf"
def test_truncate_at_last_char():
"""Ensure truncation happens at the last occurrence of the character."""
original = 'as df zx cv'
assert truncate_string_at_char(original, ' ') == 'as df zx'
original = "as df zx cv"
assert truncate_string_at_char(original, " ") == "as df zx"
def test_truncate_at_nonexistent_char():
"""Ensure truncation-at-character doesn't apply if char isn't present."""
original = 'asdfzxcv'
assert truncate_string_at_char(original, ' ') == original
original = "asdfzxcv"
assert truncate_string_at_char(original, " ") == original
def test_truncate_at_multiple_chars():
"""Ensure truncation with multiple characters uses the rightmost one."""
original = 'as-df=zx_cv'
assert truncate_string_at_char(original, '-=') == 'as-df'
original = "as-df=zx_cv"
assert truncate_string_at_char(original, "-=") == "as-df"
def test_truncate_length_and_char():
"""Ensure combined length+char truncation works as expected."""
original = '12345-67890-12345'
truncated = truncate_string(
original, 8, truncate_at_chars='-', overflow_str=None)
assert truncated == '12345'
original = "12345-67890-12345"
truncated = truncate_string(original, 8, truncate_at_chars="-", overflow_str=None)
assert truncated == "12345"
def test_truncate_length_and_nonexistent_char():
"""Ensure length+char truncation works if the char isn't present."""
original = '1234567890-12345'
truncated = truncate_string(
original, 8, truncate_at_chars='-', overflow_str=None)
assert truncated == '12345678'
original = "1234567890-12345"
truncated = truncate_string(original, 8, truncate_at_chars="-", overflow_str=None)
assert truncated == "12345678"
def test_simple_url_slug_conversion():
"""Ensure that a simple url slug conversion works as expected."""
assert convert_to_url_slug("A Simple Test") == 'a_simple_test'
assert convert_to_url_slug("A Simple Test") == "a_simple_test"
def test_url_slug_with_punctuation():
"""Ensure url slug conversion with punctuation works as expected."""
original = "Here's a string. It has (some) punctuation!"
expected = 'heres_a_string_it_has_some_punctuation'
expected = "heres_a_string_it_has_some_punctuation"
assert convert_to_url_slug(original) == expected
@ -86,37 +84,37 @@ def test_url_slug_with_apostrophes():
def test_url_slug_truncation():
"""Ensure a simple url slug truncates as expected."""
original = "Here's another string to truncate."
assert convert_to_url_slug(original, 15) == 'heres_another'
assert convert_to_url_slug(original, 15) == "heres_another"
def test_multibyte_url_slug():
"""Ensure converting/truncating a slug with encoded characters works."""
original = 'Python ist eine üblicherweise höhere Programmiersprache'
expected = 'python_ist_eine_%C3%BCblicherweise'
original = "Python ist eine üblicherweise höhere Programmiersprache"
expected = "python_ist_eine_%C3%BCblicherweise"
assert convert_to_url_slug(original, 45) == expected
def test_multibyte_conservative_truncation():
"""Ensure truncating a multibyte url slug won't massively shorten it."""
# this string has a comma as the 6th char which will be converted to an
# underscore, so if truncation amount isn't restricted, it would result in
# a 46-char slug instead of the full 100.
original = 'パイソンは、汎用のプログラミング言語である'
# this string has a comma as the 6th char which will be converted to an underscore,
# so if truncation amount isn't restricted, it would result in a 46-char slug
# instead of the full 100.
original = "パイソンは、汎用のプログラミング言語である"
assert len(convert_to_url_slug(original, 100)) == 100
def test_multibyte_whole_character_truncation():
"""Ensure truncation happens at the edge of a multibyte character."""
# each of these characters url-encodes to 3 bytes = 9 characters each, so
# only the first character should be included for all lengths from 9 - 17
original = 'コード'
# each of these characters url-encodes to 3 bytes = 9 characters each, so only the
# first character should be included for all lengths from 9 - 17
original = "コード"
for limit in range(9, 18):
assert convert_to_url_slug(original, limit) == '%E3%82%B3'
assert convert_to_url_slug(original, limit) == "%E3%82%B3"
def test_simple_word_count():
"""Ensure word-counting a simple string works as expected."""
string = 'Here is a simple string of words, nothing fancy.'
string = "Here is a simple string of words, nothing fancy."
assert word_count(string) == 9

36
tildes/tests/test_title.py

@ -7,57 +7,57 @@ from tildes.schemas.topic import TITLE_MAX_LENGTH, TopicSchema
@fixture
def title_schema():
"""Fixture for generating a title-only TopicSchema."""
return TopicSchema(only=('title',))
return TopicSchema(only=("title",))
def test_typical_title_valid(title_schema):
"""Test a "normal-looking" title to make sure it's valid."""
title = "[Something] Here's an article that I'm sure 100 people will like."
assert title_schema.validate({'title': title}) == {}
assert title_schema.validate({"title": title}) == {}
def test_too_long_title_invalid(title_schema):
"""Ensure a too-long title is invalid."""
title = 'x' * (TITLE_MAX_LENGTH + 1)
title = "x" * (TITLE_MAX_LENGTH + 1)
with raises(ValidationError):
title_schema.validate({'title': title})
title_schema.validate({"title": title})
def test_empty_title_invalid(title_schema):
"""Ensure an empty title is invalid."""
with raises(ValidationError):
title_schema.validate({'title': ''})
title_schema.validate({"title": ""})
def test_whitespace_only_title_invalid(title_schema):
"""Ensure a whitespace-only title is invalid."""
with raises(ValidationError):
title_schema.validate({'title': ' \n '})
title_schema.validate({"title": " \n "})
def test_whitespace_trimmed(title_schema):
"""Ensure leading/trailing whitespace on a title is removed."""
title = ' actual title '
result = title_schema.load({'title': title})
assert result.data['title'] == 'actual title'
title = " actual title "
result = title_schema.load({"title": title})
assert result.data["title"] == "actual title"
def test_consecutive_whitespace_removed(title_schema):
"""Ensure consecutive whitespace in a title is compressed."""
title = 'sure are \n a lot of spaces'
result = title_schema.load({'title': title})
assert result.data['title'] == 'sure are a lot of spaces'
title = "sure are \n a lot of spaces"
result = title_schema.load({"title": title})
assert result.data["title"] == "sure are a lot of spaces"
def test_unicode_spaces_normalized(title_schema):
"""Test that some unicode space characters are converted to normal ones."""
title = 'some\u2009weird\u00a0spaces\u205fin\u00a0here'
result = title_schema.load({'title': title})
assert result.data['title'] == 'some weird spaces in here'
title = "some\u2009weird\u00a0spaces\u205fin\u00a0here"
result = title_schema.load({"title": title})
assert result.data["title"] == "some weird spaces in here"
def test_unicode_control_chars_removed(title_schema):
"""Test that some unicode control characters are stripped from titles."""
title = 'nothing\u0000strange\u0085going\u009con\u007fhere'
result = title_schema.load({'title': title})
assert result.data['title'] == 'nothingstrangegoingonhere'
title = "nothing\u0000strange\u0085going\u009con\u007fhere"
result = title_schema.load({"title": title})
assert result.data["title"] == "nothingstrangegoingonhere"

44
tildes/tests/test_topic.py

@ -12,41 +12,37 @@ from tildes.schemas.topic import TopicSchema
def test_text_creation_validations(mocker, session_user, session_group):
"""Ensure that text topic creation goes through expected validation."""
mocker.spy(TopicSchema, 'load')
mocker.spy(Markdown, '_validate')
mocker.spy(SimpleString, '_validate')
mocker.spy(TopicSchema, "load")
mocker.spy(Markdown, "_validate")
mocker.spy(SimpleString, "_validate")
Topic.create_text_topic(
session_group, session_user, 'a title', 'the text')
Topic.create_text_topic(session_group, session_user, "a title", "the text")
assert TopicSchema.load.called
assert SimpleString._validate.call_args[0][1] == 'a title'
assert Markdown._validate.call_args[0][1] == 'the text'
assert SimpleString._validate.call_args[0][1] == "a title"
assert Markdown._validate.call_args[0][1] == "the text"
def test_link_creation_validations(mocker, session_user, session_group):
"""Ensure that link topic creation goes through expected validation."""
mocker.spy(TopicSchema, 'load')
mocker.spy(SimpleString, '_validate')
mocker.spy(URL, '_validate')
mocker.spy(TopicSchema, "load")
mocker.spy(SimpleString, "_validate")
mocker.spy(URL, "_validate")
Topic.create_link_topic(
session_group,
session_user,
'the title',
'http://example.com',
session_group, session_user, "the title", "http://example.com"
)
assert TopicSchema.load.called
assert SimpleString._validate.call_args[0][1] == 'the title'
assert URL._validate.call_args[0][1] == 'http://example.com'
assert SimpleString._validate.call_args[0][1] == "the title"
assert URL._validate.call_args[0][1] == "http://example.com"
def test_text_topic_edit_uses_markdown_field(mocker, text_topic):
"""Ensure editing a text topic is validated by the Markdown field class."""
mocker.spy(Markdown, '_validate')
mocker.spy(Markdown, "_validate")
text_topic.markdown = 'Some new text after edit'
text_topic.markdown = "Some new text after edit"
assert Markdown._validate.called
@ -82,19 +78,19 @@ def test_link_domain_errors_on_text_topic(text_topic):
def test_link_domain_on_link_topic(link_topic):
"""Ensure getting the domain of a link topic works."""
assert link_topic.link_domain == 'example.com'
assert link_topic.link_domain == "example.com"
def test_edit_markdown_errors_on_link_topic(link_topic):
"""Ensure trying to edit the markdown of a link topic is an error."""
with raises(AttributeError):
link_topic.markdown = 'Some new markdown'
link_topic.markdown = "Some new markdown"
def test_edit_markdown_on_text_topic(text_topic):
"""Ensure editing the markdown of a text topic works and updates html."""
original_html = text_topic.rendered_html
text_topic.markdown = 'Some new markdown'
text_topic.markdown = "Some new markdown"
assert text_topic.rendered_html != original_html
@ -104,7 +100,7 @@ def test_edit_grace_period(text_topic):
edit_time = text_topic.created_time + EDIT_GRACE_PERIOD - one_sec
with freeze_time(edit_time):
text_topic.markdown = 'some new markdown'
text_topic.markdown = "some new markdown"
assert not text_topic.last_edited_time
@ -115,7 +111,7 @@ def test_edit_after_grace_period(text_topic):
edit_time = text_topic.created_time + EDIT_GRACE_PERIOD + one_sec
with freeze_time(edit_time):
text_topic.markdown = 'some new markdown'
text_topic.markdown = "some new markdown"
assert text_topic.last_edited_time == utc_now()
@ -127,7 +123,7 @@ def test_multiple_edits_update_time(text_topic):
for minutes in range(0, 4):
edit_time = initial_time + timedelta(minutes=minutes)
with freeze_time(edit_time):
text_topic.markdown = f'edit #{minutes}'
text_topic.markdown = f"edit #{minutes}"
assert text_topic.last_edited_time == utc_now()

37
tildes/tests/test_topic_permissions.py

@ -1,13 +1,9 @@
from pyramid.security import (
Authenticated,
Everyone,
principals_allowed_by_permission,
)
from pyramid.security import Authenticated, Everyone, principals_allowed_by_permission
def test_topic_viewing_permission(text_topic):
"""Ensure that anyone can view a topic by default."""
principals = principals_allowed_by_permission(text_topic, 'view')
principals = principals_allowed_by_permission(text_topic, "view")
assert Everyone in principals
@ -15,71 +11,70 @@ def test_deleted_topic_permissions_removed(topic):
"""Ensure that deleted topics lose all permissions except "view"."""
topic.is_deleted = True
assert principals_allowed_by_permission(topic, 'view') == {Everyone}
assert principals_allowed_by_permission(topic, "view") == {Everyone}
all_permissions = [
perm for (_, _, perm) in topic.__acl__() if perm != 'view']
all_permissions = [perm for (_, _, perm) in topic.__acl__() if perm != "view"]
for permission in all_permissions:
assert not principals_allowed_by_permission(topic, permission)
def test_text_topic_editing_permission(text_topic):
"""Ensure a text topic's owner (and nobody else) is able to edit it."""
principals = principals_allowed_by_permission(text_topic, 'edit')
principals = principals_allowed_by_permission(text_topic, "edit")
assert principals == {text_topic.user.user_id}
def test_link_topic_editing_permission(link_topic):
"""Ensure that nobody has edit permission on a link topic."""
principals = principals_allowed_by_permission(link_topic, 'edit')
principals = principals_allowed_by_permission(link_topic, "edit")
assert not principals
def test_topic_deleting_permission(text_topic):
"""Ensure that the topic's owner (and nobody else) is able to delete it."""
principals = principals_allowed_by_permission(text_topic, 'delete')
principals = principals_allowed_by_permission(text_topic, "delete")
assert principals == {text_topic.user.user_id}
def test_topic_view_author_permission(text_topic):
"""Ensure anyone can view a topic's author normally."""
principals = principals_allowed_by_permission(text_topic, 'view_author')
principals = principals_allowed_by_permission(text_topic, "view_author")
assert Everyone in principals
def test_removed_topic_view_author_permission(topic):
"""Ensure only admins and the author can view a removed topic's author."""
topic.is_removed = True
principals = principals_allowed_by_permission(topic, 'view_author')
assert principals == {'admin', topic.user_id}
principals = principals_allowed_by_permission(topic, "view_author")
assert principals == {"admin", topic.user_id}
def test_topic_view_content_permission(text_topic):
"""Ensure anyone can view a topic's content normally."""
principals = principals_allowed_by_permission(text_topic, 'view_content')
principals = principals_allowed_by_permission(text_topic, "view_content")
assert Everyone in principals
def test_removed_topic_view_content_permission(topic):
"""Ensure only admins and the author can view a removed topic's content."""
topic.is_removed = True
principals = principals_allowed_by_permission(topic, 'view_content')
assert principals == {'admin', topic.user_id}
principals = principals_allowed_by_permission(topic, "view_content")
assert principals == {"admin", topic.user_id}
def test_topic_comment_permission(text_topic):
"""Ensure authed users have comment perms on a topic by default."""
principals = principals_allowed_by_permission(text_topic, 'comment')
principals = principals_allowed_by_permission(text_topic, "comment")
assert Authenticated in principals
def test_locked_topic_comment_permission(topic):
"""Ensure only admins can post (top-level) comments on locked topics."""
topic.is_locked = True
assert principals_allowed_by_permission(topic, 'comment') == {'admin'}
assert principals_allowed_by_permission(topic, "comment") == {"admin"}
def test_removed_topic_comment_permission(topic):
"""Ensure only admins can post (top-level) comments on removed topics."""
topic.is_removed = True
assert principals_allowed_by_permission(topic, 'comment') == {'admin'}
assert principals_allowed_by_permission(topic, "comment") == {"admin"}

22
tildes/tests/test_topic_tags.py

@ -1,34 +1,34 @@
def test_tags_whitespace_stripped(text_topic):
"""Ensure excess whitespace around tags gets stripped."""
text_topic.tags = [' one', 'two ', ' three ']
assert text_topic.tags == ['one', 'two', 'three']
text_topic.tags = [" one", "two ", " three "]
assert text_topic.tags == ["one", "two", "three"]
def test_tag_space_replacement(text_topic):
"""Ensure spaces in tags are converted to underscores internally."""
text_topic.tags = ['one two', 'three four five']
assert text_topic._tags == ['one_two', 'three_four_five']
text_topic.tags = ["one two", "three four five"]
assert text_topic._tags == ["one_two", "three_four_five"]
def test_tag_consecutive_spaces(text_topic):
"""Ensure consecutive spaces/underscores in tags are removed."""
text_topic.tags = ["one two", "three four", "five __ six"]
assert text_topic.tags == ['one two', 'three four', 'five six']
assert text_topic.tags == ["one two", "three four", "five six"]
def test_duplicate_tags_removed(text_topic):
"""Ensure duplicate tags are removed (case-insensitive)."""
text_topic.tags = ['one', 'one', 'One', 'ONE', 'two', 'TWO']
assert text_topic.tags == ['one', 'two']
text_topic.tags = ["one", "one", "One", "ONE", "two", "TWO"]
assert text_topic.tags == ["one", "two"]
def test_empty_tags_removed(text_topic):
"""Ensure empty tags are removed."""
text_topic.tags = ['', ' ', '_', 'one']
assert text_topic.tags == ['one']
text_topic.tags = ["", " ", "_", "one"]
assert text_topic.tags == ["one"]
def test_tags_lowercased(text_topic):
"""Ensure tags get converted to lowercase."""
text_topic.tags = ['ONE', 'Two', 'thRee']
assert text_topic.tags == ['one', 'two', 'three']
text_topic.tags = ["ONE", "Two", "thRee"]
assert text_topic.tags == ["one", "two", "three"]

6
tildes/tests/test_triggers_comments.py

@ -8,7 +8,7 @@ def test_comments_affect_topic_num_comments(session_user, topic, db):
# Insert some comments, ensure each one increments the count
comments = []
for num in range(0, 5):
new_comment = Comment(topic, session_user, 'comment')
new_comment = Comment(topic, session_user, "comment")
comments.append(new_comment)
db.add(new_comment)
db.commit()
@ -62,8 +62,8 @@ def test_remove_sets_removed_time(db, comment):
def test_remove_delete_single_decrement(db, topic, session_user):
"""Ensure that remove+delete doesn't double-decrement num_comments."""
# add 2 comments
comment1 = Comment(topic, session_user, 'Comment 1')
comment2 = Comment(topic, session_user, 'Comment 2')
comment1 = Comment(topic, session_user, "Comment 1")
comment2 = Comment(topic, session_user, "Comment 2")
db.add_all([comment1, comment2])
db.commit()
db.refresh(topic)

22
tildes/tests/test_url.py

@ -5,13 +5,13 @@ from tildes.lib.url import get_domain_from_url
def test_simple_get_domain():
"""Ensure getting the domain from a normal URL works."""
url = 'http://example.com/some/path?query=param&query2=val2'
assert get_domain_from_url(url) == 'example.com'
url = "http://example.com/some/path?query=param&query2=val2"
assert get_domain_from_url(url) == "example.com"
def test_get_domain_non_url():
"""Ensure attempting to get the domain for a non-url is an error."""
url = 'this is not a url'
url = "this is not a url"
with raises(ValueError):
get_domain_from_url(url)
@ -19,27 +19,27 @@ def test_get_domain_non_url():
def test_get_domain_no_scheme():
"""Ensure getting domain on a url with no scheme is an error."""
with raises(ValueError):
get_domain_from_url('example.com/something')
get_domain_from_url("example.com/something")
def test_get_domain_explicit_no_scheme():
"""Ensure getting domain works if url is explicit about lack of scheme."""
assert get_domain_from_url('//example.com/something') == 'example.com'
assert get_domain_from_url("//example.com/something") == "example.com"
def test_get_domain_strip_www():
"""Ensure stripping the "www." from the domain works as expected."""
url = 'http://www.example.com/a/path/to/something'
assert get_domain_from_url(url) == 'example.com'
url = "http://www.example.com/a/path/to/something"
assert get_domain_from_url(url) == "example.com"
def test_get_domain_no_strip_www():
"""Ensure stripping the "www." can be disabled."""
url = 'http://www.example.com/a/path/to/something'
assert get_domain_from_url(url, strip_www=False) == 'www.example.com'
url = "http://www.example.com/a/path/to/something"
assert get_domain_from_url(url, strip_www=False) == "www.example.com"
def test_get_domain_subdomain_not_stripped():
"""Ensure a non-www subdomain isn't stripped."""
url = 'http://something.example.com/path/x/y/z'
assert get_domain_from_url(url) == 'something.example.com'
url = "http://something.example.com/path/x/y/z"
assert get_domain_from_url(url) == "something.example.com"

51
tildes/tests/test_user.py

@ -8,55 +8,55 @@ from tildes.schemas.user import PASSWORD_MIN_LENGTH, UserSchema
def test_creation_validates_schema(mocker):
"""Ensure that model creation goes through schema validation."""
mocker.spy(UserSchema, 'validate')
User('testing', 'testpassword')
mocker.spy(UserSchema, "validate")
User("testing", "testpassword")
call_args = [call[0] for call in UserSchema.validate.call_args_list]
expected_args = {'username': 'testing', 'password': 'testpassword'}
expected_args = {"username": "testing", "password": "testpassword"}
assert any(expected_args in call for call in call_args)
def test_too_short_password():
"""Ensure a new user can't be created with a too-short password."""
password = 'x' * (PASSWORD_MIN_LENGTH - 1)
password = "x" * (PASSWORD_MIN_LENGTH - 1)
with raises(ValidationError):
User('ShortPasswordGuy', password)
User("ShortPasswordGuy", password)
def test_matching_password_and_username():
"""Ensure a new user can't be created with same username and password."""
with raises(ValidationError):
User('UnimaginativePassword', 'UnimaginativePassword')
User("UnimaginativePassword", "UnimaginativePassword")
def test_username_and_password_differ_in_casing():
"""Ensure a user can't be created with name/pass the same except case."""
with raises(ValidationError):
User('NobodyWillGuess', 'nobodywillguess')
User("NobodyWillGuess", "nobodywillguess")
def test_username_contained_in_password():
"""Ensure a user can't be created with the username in the password."""
with raises(ValidationError):
User('MyUsername', 'iputmyusernameinmypassword')
User("MyUsername", "iputmyusernameinmypassword")
def test_password_contained_in_username():
"""Ensure a user can't be created with the password in the username."""
with raises(ValidationError):
User('PasswordIsVeryGood', 'VeryGood')
User("PasswordIsVeryGood", "VeryGood")
def test_user_password_check():
"""Ensure checking the password for a new user works correctly."""
new_user = User('myusername', 'mypassword')
assert new_user.is_correct_password('mypassword')
new_user = User("myusername", "mypassword")
assert new_user.is_correct_password("mypassword")
def test_duplicate_username(db):
"""Ensure two users with the same name can't be created."""
original = User('Inimitable', 'securepassword')
original = User("Inimitable", "securepassword")
db.add(original)
duplicate = User('Inimitable', 'adifferentpassword')
duplicate = User("Inimitable", "adifferentpassword")
db.add(duplicate)
with raises(IntegrityError):
@ -65,10 +65,10 @@ def test_duplicate_username(db):
def test_duplicate_username_case_insensitive(db):
"""Ensure usernames only differing in casing can't be created."""
test_username = 'test_user'
original = User(test_username.lower(), 'hackproof')
test_username = "test_user"
original = User(test_username.lower(), "hackproof")
db.add(original)
duplicate = User(test_username.upper(), 'sosecure')
duplicate = User(test_username.upper(), "sosecure")
db.add(duplicate)
with raises(IntegrityError):
@ -77,20 +77,20 @@ def test_duplicate_username_case_insensitive(db):
def test_change_password():
"""Ensure changing a user password works as expected."""
new_user = User('A_New_User', 'lovesexsecretgod')
new_user = User("A_New_User", "lovesexsecretgod")
new_user.change_password('lovesexsecretgod', 'lovesexsecretgod1')
new_user.change_password("lovesexsecretgod", "lovesexsecretgod1")
# the old one shouldn't work
assert not new_user.is_correct_password('lovesexsecretgod')
assert not new_user.is_correct_password("lovesexsecretgod")
# the new one should
assert new_user.is_correct_password('lovesexsecretgod1')
assert new_user.is_correct_password("lovesexsecretgod1")
def test_change_password_to_same(session_user):
"""Ensure users can't "change" to the same password."""
password = 'session user password'
password = "session user password"
with raises(ValueError):
session_user.change_password(password, password)
@ -98,18 +98,17 @@ def test_change_password_to_same(session_user):
def test_change_password_wrong_old_one(session_user):
"""Ensure changing password doesn't work if the old one is wrong."""
with raises(ValueError):
session_user.change_password('definitely not right', 'some new one')
session_user.change_password("definitely not right", "some new one")
def test_change_password_too_short(session_user):
"""Ensure users can't change password to a too-short one."""
new_password = 'x' * (PASSWORD_MIN_LENGTH - 1)
new_password = "x" * (PASSWORD_MIN_LENGTH - 1)
with raises(ValidationError):
session_user.change_password('session user password', new_password)
session_user.change_password("session user password", new_password)
def test_change_password_to_username(session_user):
"""Ensure users can't change password to the same as their username."""
with raises(ValidationError):
session_user.change_password(
'session user password', session_user.username)
session_user.change_password("session user password", session_user.username)

16
tildes/tests/test_username.py

@ -10,7 +10,7 @@ from tildes.schemas.user import (
def test_too_short_invalid():
"""Ensure too-short username is invalid."""
length = USERNAME_MIN_LENGTH - 1
username = 'x' * length
username = "x" * length
assert not is_valid_username(username)
@ -18,7 +18,7 @@ def test_too_short_invalid():
def test_too_long_invalid():
"""Ensure too-long username is invalid."""
length = USERNAME_MAX_LENGTH + 1
username = 'x' * length
username = "x" * length
assert not is_valid_username(username)
@ -26,22 +26,22 @@ def test_too_long_invalid():
def test_valid_length_range():
"""Ensure the entire range of valid lengths work."""
for length in range(USERNAME_MIN_LENGTH, USERNAME_MAX_LENGTH + 1):
username = 'x' * length
username = "x" * length
assert is_valid_username(username)
def test_consecutive_spacer_chars_invalid():
"""Ensure that a username with consecutive "spacer chars" is invalid."""
spacer_chars = '_-'
spacer_chars = "_-"
for char1, char2 in product(spacer_chars, spacer_chars):
username = f'abc{char1}{char2}xyz'
username = f"abc{char1}{char2}xyz"
assert not is_valid_username(username)
def test_typical_username_valid():
"""Ensure a "normal-looking" username is considered valid."""
assert is_valid_username('someTypical_user-85')
assert is_valid_username("someTypical_user-85")
def test_invalid_characters():
@ -49,11 +49,11 @@ def test_invalid_characters():
invalid_chars = ' ~!@#$%^&*()+={}[]|\\:;"<>,.?/'
for char in invalid_chars:
username = f'abc{char}xyz'
username = f"abc{char}xyz"
assert not is_valid_username(username)
def test_unicode_characters():
"""Ensure that unicode chars can't be included (not comprehensive)."""
for username in ('pokémon', 'ポケモン', 'møøse'):
for username in ("pokémon", "ポケモン", "møøse"):
assert not is_valid_username(username)

10
tildes/tests/test_webassets.py

@ -1,22 +1,22 @@
from webassets.loaders import YAMLLoader
WEBASSETS_ENV = YAMLLoader('webassets.yaml').load_environment()
WEBASSETS_ENV = YAMLLoader("webassets.yaml").load_environment()
def test_scripts_file_first_in_bundle():
"""Ensure that the main scripts.js file will be at the top."""
js_bundle = WEBASSETS_ENV['javascript']
js_bundle = WEBASSETS_ENV["javascript"]
first_filename = js_bundle.resolve_contents()[0][0]
assert first_filename == 'js/scripts.js'
assert first_filename == "js/scripts.js"
def test_styles_file_last_in_bundle():
"""Ensure that the main styles.css file will be at the bottom."""
css_bundle = WEBASSETS_ENV['css']
css_bundle = WEBASSETS_ENV["css"]
last_filename = css_bundle.resolve_contents()[-1][0]
assert last_filename == 'css/styles.css'
assert last_filename == "css/styles.css"

12
tildes/tests/webtests/test_user_page.py

@ -1,13 +1,15 @@
def test_loggedout_username_leak(webtest_loggedout, session_user):
"""Ensure responses from existing and nonexistent users are the same.
Since logged-out users are currently blocked from seeing user pages, this
makes sure that there isn't a data leak where it's possible to tell if a
particular username exists or not.
Since logged-out users are currently blocked from seeing user pages, this makes sure
that there isn't a data leak where it's possible to tell if a particular username
exists or not.
"""
existing_user = webtest_loggedout.get(
'/user/' + session_user.username, expect_errors=True)
"/user/" + session_user.username, expect_errors=True
)
nonexistent_user = webtest_loggedout.get(
'/user/thisdoesntexist', expect_errors=True)
"/user/thisdoesntexist", expect_errors=True
)
assert existing_user.status == nonexistent_user.status

101
tildes/tildes/__init__.py

@ -16,51 +16,47 @@ def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware:
"""Configure and return a Pyramid WSGI application."""
config = Configurator(settings=settings)
config.include('cornice')
config.include('pyramid_session_redis')
config.include('pyramid_webassets')
config.include("cornice")
config.include("pyramid_session_redis")
config.include("pyramid_webassets")
# include database first so the session and querying are available
config.include('tildes.database')
config.include('tildes.auth')
config.include('tildes.jinja')
config.include('tildes.json')
config.include('tildes.routes')
config.include("tildes.database")
config.include("tildes.auth")
config.include("tildes.jinja")
config.include("tildes.json")
config.include("tildes.routes")
config.add_webasset('javascript', Bundle(output='js/tildes.js'))
config.add_webasset(
'javascript-third-party', Bundle(output='js/third_party.js'))
config.add_webasset('css', Bundle(output='css/tildes.css'))
config.add_webasset('site-icons-css', Bundle(output='css/site-icons.css'))
config.add_webasset("javascript", Bundle(output="js/tildes.js"))
config.add_webasset("javascript-third-party", Bundle(output="js/third_party.js"))
config.add_webasset("css", Bundle(output="css/tildes.css"))
config.add_webasset("site-icons-css", Bundle(output="css/site-icons.css"))
config.scan('tildes.views')
config.scan("tildes.views")
config.add_tween('tildes.http_method_tween_factory')
config.add_tween("tildes.http_method_tween_factory")
config.add_request_method(
is_safe_request_method, 'is_safe_method', reify=True)
config.add_request_method(is_safe_request_method, "is_safe_method", reify=True)
# Add the request.redis request method to access a redis connection. This
# is done in a bit of a strange way to support being overridden in tests.
config.registry['redis_connection_factory'] = get_redis_connection
# Add the request.redis request method to access a redis connection. This is done in
# a bit of a strange way to support being overridden in tests.
config.registry["redis_connection_factory"] = get_redis_connection
# pylint: disable=unnecessary-lambda
config.add_request_method(
lambda request: config.registry['redis_connection_factory'](request),
'redis',
lambda request: config.registry["redis_connection_factory"](request),
"redis",
reify=True,
)
# pylint: enable=unnecessary-lambda
config.add_request_method(check_rate_limit, 'check_rate_limit')
config.add_request_method(check_rate_limit, "check_rate_limit")
config.add_request_method(
current_listing_base_url, 'current_listing_base_url')
config.add_request_method(
current_listing_normal_url, 'current_listing_normal_url')
config.add_request_method(current_listing_base_url, "current_listing_base_url")
config.add_request_method(current_listing_normal_url, "current_listing_normal_url")
app = config.make_wsgi_app()
force_port = global_config.get('prefixmiddleware_force_port')
force_port = global_config.get("prefixmiddleware_force_port")
if force_port:
prefixed_app = PrefixMiddleware(app, force_port=force_port)
else:
@ -69,19 +65,17 @@ def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware:
return prefixed_app
def http_method_tween_factory(
handler: Callable,
registry: Registry,
) -> Callable:
def http_method_tween_factory(handler: Callable, registry: Registry) -> Callable:
# pylint: disable=unused-argument
"""Return a tween function that can override the request's HTTP method."""
def method_override_tween(request: Request) -> Request:
"""Override HTTP method with one specified in header."""
valid_overrides_by_method = {'POST': ['DELETE', 'PATCH', 'PUT']}
valid_overrides_by_method = {"POST": ["DELETE", "PATCH", "PUT"]}
original_method = request.method.upper()
valid_overrides = valid_overrides_by_method.get(original_method, [])
override = request.headers.get('X-HTTP-Method-Override', '').upper()
override = request.headers.get("X-HTTP-Method-Override", "").upper()
if override in valid_overrides:
request.method = override
@ -93,13 +87,13 @@ def http_method_tween_factory(
def get_redis_connection(request: Request) -> StrictRedis:
"""Return a StrictRedis connection to the Redis server."""
socket = request.registry.settings['redis.unix_socket_path']
socket = request.registry.settings["redis.unix_socket_path"]
return StrictRedis(unix_socket_path=socket)
def is_safe_request_method(request: Request) -> bool:
"""Return whether the request method is "safe" (is GET or HEAD)."""
return request.method in {'GET', 'HEAD'}
return request.method in {"GET", "HEAD"}
def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
@ -107,7 +101,7 @@ def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
try:
action = RATE_LIMITED_ACTIONS[action_name]
except KeyError:
raise ValueError('Invalid action name: %s' % action_name)
raise ValueError("Invalid action name: %s" % action_name)
action.redis = request.redis
@ -127,24 +121,21 @@ def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
def current_listing_base_url(
request: Request,
query: Optional[Dict[str, Any]] = None,
request: Request, query: Optional[Dict[str, Any]] = None
) -> str:
"""Return the "base" url for the current listing route.
The "base" url represents the current listing, including any filtering
options (or the fact that filters are disabled).
The "base" url represents the current listing, including any filtering options (or
the fact that filters are disabled).
The `query` argument allows adding query variables to the generated url.
"""
if request.matched_route.name not in ('home', 'group', 'user'):
raise AttributeError('Current route is not supported.')
if request.matched_route.name not in ("home", "group", "user"):
raise AttributeError("Current route is not supported.")
base_view_vars = (
'order', 'period', 'per_page', 'tag', 'type', 'unfiltered')
base_view_vars = ("order", "period", "per_page", "tag", "type", "unfiltered")
query_vars = {
key: val for key, val in request.GET.copy().items()
if key in base_view_vars
key: val for key, val in request.GET.copy().items() if key in base_view_vars
}
if query:
query_vars.update(query)
@ -152,12 +143,11 @@ def current_listing_base_url(
url = request.current_route_url(_query=query_vars)
# Pyramid seems to %-encode tilde characters unnecessarily, fix that
return url.replace('%7E', '~')
return url.replace("%7E", "~")
def current_listing_normal_url(
request: Request,
query: Optional[Dict[str, Any]] = None,
request: Request, query: Optional[Dict[str, Any]] = None
) -> str:
"""Return the "normal" url for the current listing route.
@ -166,13 +156,12 @@ def current_listing_normal_url(
The `query` argument allows adding query variables to the generated url.
"""
if request.matched_route.name not in ('home', 'group', 'user'):
raise AttributeError('Current route is not supported.')
if request.matched_route.name not in ("home", "group", "user"):
raise AttributeError("Current route is not supported.")
normal_view_vars = ('order', 'period', 'per_page')
normal_view_vars = ("order", "period", "per_page")
query_vars = {
key: val for key, val in request.GET.copy().items()
if key in normal_view_vars
key: val for key, val in request.GET.copy().items() if key in normal_view_vars
}
if query:
query_vars.update(query)
@ -180,4 +169,4 @@ def current_listing_normal_url(
url = request.current_route_url(_query=query_vars)
# Pyramid seems to %-encode tilde characters unnecessarily, fix that
return url.replace('%7E', '~')
return url.replace("%7E", "~")

10
tildes/tildes/api.py

@ -9,8 +9,8 @@ import venusian
class APIv0(Service):
"""Service wrapper class for v0 of the API."""
name_prefix = 'apiv0_'
base_path = '/api/v0'
name_prefix = "apiv0_"
base_path = "/api/v0"
def __init__(self, name: str, path: str, **kwargs: Any) -> None:
"""Create a new service."""
@ -19,8 +19,8 @@ class APIv0(Service):
super().__init__(name=name, path=path, **kwargs)
# Service.__init__ does this setup to support config.scan(), but it
# doesn't seem to inherit properly, so it needs to be done again here
# Service.__init__ does this setup to support config.scan(), but it doesn't seem
# to inherit properly, so it needs to be done again here
def callback(context: Any, name: Any, obj: Any) -> None:
# pylint: disable=unused-argument
config = context.config.with_package(info.module) # noqa
@ -28,4 +28,4 @@ class APIv0(Service):
# TEMP: disable API until I can fix the private-fields issue
# config.add_cornice_service(self)
info = venusian.attach(self, callback, category='pyramid')
info = venusian.attach(self, callback, category="pyramid")

70
tildes/tildes/auth.py

@ -7,13 +7,7 @@ from pyramid.authorization import ACLAuthorizationPolicy
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPFound
from pyramid.request import Request
from pyramid.security import (
ACLDenied,
ACLPermitsResult,
Allow,
Authenticated,
Everyone,
)
from pyramid.security import ACLDenied, ACLPermitsResult, Allow, Authenticated, Everyone
from tildes.models.user import User
@ -21,13 +15,13 @@ from tildes.models.user import User
class DefaultRootFactory:
"""Default root factory to grant everyone 'view' permission by default.
Note that this will only be applied in cases where a view does not have a
factory specified at all (so request.context doesn't have a meaningful
value). Any classes that could be returned by a root factory must have
an __acl__ defined, they will not "fall back" to this one.
Note that this will only be applied in cases where a view does not have a factory
specified at all (so request.context doesn't have a meaningful value). Any classes
that could be returned by a root factory must have an __acl__ defined, they will not
"fall back" to this one.
"""
__acl__ = ((Allow, Everyone, 'view'),)
__acl__ = ((Allow, Everyone, "view"),)
def __init__(self, request: Request) -> None:
"""Root factory constructor - must take a request argument."""
@ -40,10 +34,7 @@ def get_authenticated_user(request: Request) -> Optional[User]:
if not user_id:
return None
query = (
request.query(User)
.filter_by(user_id=user_id)
)
query = request.query(User).filter_by(user_id=user_id)
return query.one_or_none()
@ -51,8 +42,8 @@ def get_authenticated_user(request: Request) -> Optional[User]:
def auth_callback(user_id: int, request: Request) -> Optional[Sequence[str]]:
"""Return authorization principals for a user_id from the session.
This is a callback function needed by SessionAuthenticationPolicy. It
should return None if the user_id does not exist (such as a deleted user).
This is a callback function needed by SessionAuthenticationPolicy. It should return
None if the user_id does not exist (such as a deleted user).
"""
if not request.user:
return None
@ -60,15 +51,15 @@ def auth_callback(user_id: int, request: Request) -> Optional[Sequence[str]]:
# if the user is banned, log them out - is there a better place to do this?
if request.user.is_banned:
request.session.invalidate()
raise HTTPFound('/')
raise HTTPFound("/")
if user_id != request.user.user_id:
raise AssertionError('auth_callback called with different user_id')
raise AssertionError("auth_callback called with different user_id")
principals = []
if request.user.is_admin:
principals.append('admin')
principals.append("admin")
return principals
@ -76,45 +67,43 @@ def auth_callback(user_id: int, request: Request) -> Optional[Sequence[str]]:
def includeme(config: Configurator) -> None:
"""Config updates related to authentication/authorization."""
# make all views require "view" permission unless specifically overridden
config.set_default_permission('view')
config.set_default_permission("view")
# replace the default root factory with a custom one to more easily support
# the default permission
# replace the default root factory with a custom one to more easily support the
# default permission
config.set_root_factory(DefaultRootFactory)
# Set the authorization policy to a custom one that always returns a
# "denied" result if the user isn't logged in. When overall site access is
# no longer being restricted, the AuthorizedOnlyPolicy class can just be
# replaced with the standard ACLAuthorizationPolicy
# Set the authorization policy to a custom one that always returns a "denied" result
# if the user isn't logged in. When overall site access is no longer being
# restricted, the AuthorizedOnlyPolicy class can just be replaced with the standard
# ACLAuthorizationPolicy
config.set_authorization_policy(AuthorizedOnlyPolicy())
config.set_authentication_policy(
SessionAuthenticationPolicy(callback=auth_callback))
SessionAuthenticationPolicy(callback=auth_callback)
)
# enable CSRF checking globally by default
config.set_default_csrf_options(require_csrf=True)
# make the logged-in User object available as request.user
config.add_request_method(get_authenticated_user, 'user', reify=True)
config.add_request_method(get_authenticated_user, "user", reify=True)
# add has_any_permission method for easily checking multiple permissions
config.add_request_method(has_any_permission, 'has_any_permission')
config.add_request_method(has_any_permission, "has_any_permission")
class AuthorizedOnlyPolicy(ACLAuthorizationPolicy):
"""ACLAuthorizationPolicy override that always denies logged-out users."""
def permits(
self,
context: Any,
principals: Sequence[Any],
permission: str,
self, context: Any, principals: Sequence[Any], permission: str
) -> ACLPermitsResult:
"""Deny logged-out users, otherwise pass up to normal policy."""
if Authenticated not in principals:
return ACLDenied(
'<authorized only>',
'<no ACLs checked yet>',
"<authorized only>",
"<no ACLs checked yet>",
permission,
principals,
context,
@ -124,12 +113,9 @@ class AuthorizedOnlyPolicy(ACLAuthorizationPolicy):
def has_any_permission(
request: Request,
permissions: Sequence[str],
context: Any,
request: Request, permissions: Sequence[str], context: Any
) -> bool:
"""Return whether the user has any of the permissions on the item."""
return any(
request.has_permission(permission, context)
for permission in permissions
request.has_permission(permission, context) for permission in permissions
)

47
tildes/tildes/database.py

@ -28,10 +28,7 @@ def obtain_lock(request: Request, lock_space: str, lock_value: int) -> None:
obtain_transaction_lock(request.db_session, lock_space, lock_value)
def query_factory(
request: Request,
model_cls: Type[DatabaseModel],
) -> ModelQuery:
def query_factory(request: Request, model_cls: Type[DatabaseModel]) -> ModelQuery:
"""Return a ModelQuery or subclass depending on model_cls specified."""
if model_cls == Comment:
return CommentQuery(request)
@ -46,8 +43,7 @@ def query_factory(
def get_tm_session(
session_factory: Callable,
transaction_manager: ThreadTransactionManager,
session_factory: Callable, transaction_manager: ThreadTransactionManager
) -> Session:
"""Return a db session being managed by the transaction manager."""
db_session = session_factory()
@ -60,40 +56,39 @@ def includeme(config: Configurator) -> None:
Currently adds:
* request.db_session - the db session for the current request, managed by
pyramid_tm.
* request.query() - a factory method that will return a ModelQuery or
subclass for querying the model class supplied. This will generally be
used generatively, similar to standard SQLALchemy session.query(...).
* request.obtain_lock() - obtains a transaction-level advisory lock from
PostgreSQL.
* request.db_session - db session for the current request, managed by pyramid_tm.
* request.query() - a factory method that will return a ModelQuery or subclass for
querying the model class supplied. This will generally be used generatively,
similar to standard SQLALchemy session.query(...).
* request.obtain_lock() - obtains a transaction-level advisory lock from PostgreSQL.
"""
settings = config.get_settings()
# Enable pyramid_tm's default_commit_veto behavior, which will abort the
# transaction if the response code starts with 4 or 5. The main benefit of
# this is to avoid aborting on exceptions that don't actually indicate a
# problem, such as a HTTPFound 302 redirect.
settings['tm.commit_veto'] = 'pyramid_tm.default_commit_veto'
# Enable pyramid_tm's default_commit_veto behavior, which will abort the transaction
# if the response code starts with 4 or 5. The main benefit of this is to avoid
# aborting on exceptions that don't actually indicate a problem, such as a HTTPFound
# 302 redirect.
settings["tm.commit_veto"] = "pyramid_tm.default_commit_veto"
config.include('pyramid_tm')
config.include("pyramid_tm")
# disable SQLAlchemy connection pooling since pgbouncer will handle it
settings['sqlalchemy.poolclass'] = NullPool
settings["sqlalchemy.poolclass"] = NullPool
engine = engine_from_config(settings, 'sqlalchemy.')
engine = engine_from_config(settings, "sqlalchemy.")
session_factory = sessionmaker(bind=engine, expire_on_commit=False)
config.registry['db_session_factory'] = session_factory
config.registry["db_session_factory"] = session_factory
# attach the session to each request as request.db_session
config.add_request_method(
lambda request: get_tm_session(
config.registry['db_session_factory'], request.tm),
'db_session',
config.registry["db_session_factory"], request.tm
),
"db_session",
reify=True,
)
config.add_request_method(query_factory, 'query')
config.add_request_method(query_factory, "query")
config.add_request_method(obtain_lock, 'obtain_lock')
config.add_request_method(obtain_lock, "obtain_lock")

26
tildes/tildes/enums.py

@ -21,12 +21,12 @@ class CommentSortOption(enum.Enum):
@property
def description(self) -> str:
"""Describe this sort option."""
if self.name == 'NEWEST':
return 'newest first'
elif self.name == 'POSTED':
return 'order posted'
if self.name == "NEWEST":
return "newest first"
elif self.name == "POSTED":
return "order posted"
return 'most {}'.format(self.name.lower()) # noqa
return "most {}".format(self.name.lower()) # noqa
class CommentTagOption(enum.Enum):
@ -68,16 +68,16 @@ class TopicSortOption(enum.Enum):
def descending_description(self) -> str:
"""Describe this sort option when used in a "descending" order.
For example, the "votes" sort has a description of "most votes", since
using that sort in descending order means that topics with the most
votes will be listed first.
For example, the "votes" sort has a description of "most votes", since using
that sort in descending order means that topics with the most votes will be
listed first.
"""
if self.name == 'NEW':
return 'newest'
elif self.name == 'ACTIVITY':
return 'activity'
if self.name == "NEW":
return "newest"
elif self.name == "ACTIVITY":
return "activity"
return 'most {}'.format(self.name.lower()) # noqa
return "most {}".format(self.name.lower()) # noqa
class TopicType(enum.Enum):

26
tildes/tildes/jinja.py

@ -29,28 +29,26 @@ def includeme(config: Configurator) -> None:
"""Configure Jinja2 template renderer."""
settings = config.get_settings()
settings['jinja2.lstrip_blocks'] = True
settings['jinja2.trim_blocks'] = True
settings['jinja2.undefined'] = 'strict'
settings["jinja2.lstrip_blocks"] = True
settings["jinja2.trim_blocks"] = True
settings["jinja2.undefined"] = "strict"
# add custom jinja filters
settings['jinja2.filters'] = {
'ago': descriptive_timedelta,
}
settings["jinja2.filters"] = {"ago": descriptive_timedelta}
# add custom jinja tests
settings['jinja2.tests'] = {
'comment': is_comment,
'group': is_group,
'topic': is_topic,
settings["jinja2.tests"] = {
"comment": is_comment,
"group": is_group,
"topic": is_topic,
}
config.include('pyramid_jinja2')
config.include("pyramid_jinja2")
config.add_jinja2_search_path('tildes:templates/')
config.add_jinja2_search_path("tildes:templates/")
config.add_jinja2_extension('jinja2.ext.do')
config.add_jinja2_extension('webassets.ext.jinja2.AssetsExtension')
config.add_jinja2_extension("jinja2.ext.do")
config.add_jinja2_extension("webassets.ext.jinja2.AssetsExtension")
# attach webassets to jinja2 environment (via scheduled action)
def attach_webassets_to_jinja2() -> None:

10
tildes/tildes/json.py

@ -13,8 +13,8 @@ from tildes.models.user import User
def serialize_model(model_item: DatabaseModel, request: Request) -> dict:
"""Return serializable data for a DatabaseModel item.
Uses the .schema class attribute to serialize a model by using its
corresponding marshmallow schema.
Uses the .schema class attribute to serialize a model by using its corresponding
marshmallow schema.
"""
# pylint: disable=unused-argument
return model_item.schema.dump(model_item)
@ -23,8 +23,8 @@ def serialize_model(model_item: DatabaseModel, request: Request) -> dict:
def serialize_topic(topic: Topic, request: Request) -> dict:
"""Return serializable data for a Topic."""
context = {}
if not request.has_permission('view_author', topic):
context['hide_username'] = True
if not request.has_permission("view_author", topic):
context["hide_username"] = True
return topic.schema_class(context=context).dump(topic)
@ -40,4 +40,4 @@ def includeme(config: Configurator) -> None:
# add specific adapters
json_renderer.add_adapter(Topic, serialize_topic)
config.add_renderer('json', json_renderer)
config.add_renderer("json", json_renderer)

9
tildes/tildes/lib/__init__.py

@ -1,9 +1,8 @@
"""Contains the overall "library" for the application.
Defining constants, behavior, etc. inside modules here (as opposed to other
locations such as in models) is encouraged, since it often makes it simpler to
import elsewhere for tests, when only a specific constant value is needed, etc.
Defining constants, behavior, etc. inside modules here (as opposed to other locations
such as in models) is encouraged, since it often makes it simpler to import elsewhere
for tests, when only a specific constant value is needed, etc.
Modules here should *never* import anything from models, to avoid circular
dependencies.
Modules here should *never* import anything from models, to avoid circular dependencies.
"""

33
tildes/tildes/lib/amqp.py

@ -13,41 +13,34 @@ from tildes.lib.database import get_session_from_config
class PgsqlQueueConsumer(AbstractConsumer):
"""Base class for consumers of events sent from PostgreSQL via rabbitmq.
This class is intended to be used in a completely "stand-alone" manner,
such as inside a script being run separately as a background job. As such,
it also includes connecting to rabbitmq, declaring the underlying queue and
bindings, and (optionally) connecting to the database to be able to fetch
and modify data as necessary. It relies on the environment variable
INI_FILE being set.
Note that all messages received by these consumers are expected to be in
JSON format.
This class is intended to be used in a completely "stand-alone" manner, such as
inside a script being run separately as a background job. As such, it also includes
connecting to rabbitmq, declaring the underlying queue and bindings, and
(optionally) connecting to the database to be able to fetch and modify data as
necessary. It relies on the environment variable INI_FILE being set.
Note that all messages received by these consumers are expected to be in JSON
format.
"""
PGSQL_EXCHANGE_NAME = 'pgsql_events'
PGSQL_EXCHANGE_NAME = "pgsql_events"
def __init__(
self,
queue_name: str,
routing_keys: Sequence[str],
uses_db: bool = True,
self, queue_name: str, routing_keys: Sequence[str], uses_db: bool = True
) -> None:
"""Initialize a new queue, bindings, and consumer for it."""
self.connection = Connection()
self.channel = self.connection.channel()
self.channel.queue_declare(
queue_name, durable=True, auto_delete=False)
self.channel.queue_declare(queue_name, durable=True, auto_delete=False)
for routing_key in routing_keys:
self.channel.queue_bind(
queue_name,
exchange=self.PGSQL_EXCHANGE_NAME,
routing_key=routing_key,
queue_name, exchange=self.PGSQL_EXCHANGE_NAME, routing_key=routing_key
)
if uses_db:
self.db_session = get_session_from_config(os.environ['INI_FILE'])
self.db_session = get_session_from_config(os.environ["INI_FILE"])
else:
self.db_session = None

12
tildes/tildes/lib/cmark.py

@ -4,14 +4,14 @@
from ctypes import CDLL, c_char_p, c_int, c_size_t, c_void_p
CMARK_DLL = CDLL('/usr/local/lib/libcmark-gfm.so')
CMARK_EXT_DLL = CDLL('/usr/local/lib/libcmark-gfmextensions.so')
CMARK_DLL = CDLL("/usr/local/lib/libcmark-gfm.so")
CMARK_EXT_DLL = CDLL("/usr/local/lib/libcmark-gfmextensions.so")
# enables the --hardbreaks option for cmark
# (can I import this? it's defined in cmark.h as CMARK_OPT_HARDBREAKS)
CMARK_OPTS = 4
CMARK_EXTENSIONS = (b'strikethrough', b'table')
CMARK_EXTENSIONS = (b"strikethrough", b"table")
cmark_parser_new = CMARK_DLL.cmark_parser_new
cmark_parser_new.restype = c_void_p
@ -25,13 +25,11 @@ cmark_parser_finish = CMARK_DLL.cmark_parser_finish
cmark_parser_finish.restype = c_void_p
cmark_parser_finish.argtypes = (c_void_p,)
cmark_parser_attach_syntax_extension = (
CMARK_DLL.cmark_parser_attach_syntax_extension)
cmark_parser_attach_syntax_extension = CMARK_DLL.cmark_parser_attach_syntax_extension
cmark_parser_attach_syntax_extension.restype = c_int
cmark_parser_attach_syntax_extension.argtypes = (c_void_p, c_void_p)
cmark_parser_get_syntax_extensions = (
CMARK_DLL.cmark_parser_get_syntax_extensions)
cmark_parser_get_syntax_extensions = CMARK_DLL.cmark_parser_get_syntax_extensions
cmark_parser_get_syntax_extensions.restype = c_void_p
cmark_parser_get_syntax_extensions.argtypes = (c_void_p,)

36
tildes/tildes/lib/database.py

@ -20,7 +20,7 @@ NOT_NULL_ERROR_CODE = 23502
def get_session_from_config(config_path: str) -> Session:
"""Get a database session from a config file (specified by path)."""
env = bootstrap(config_path)
session_factory = env['registry']['db_session_factory']
session_factory = env["registry"]["db_session_factory"]
return session_factory()
@ -31,25 +31,21 @@ class LockSpaces(enum.Enum):
def obtain_transaction_lock(
session: Session,
lock_space: Optional[str],
lock_value: int,
session: Session, lock_space: Optional[str], lock_value: int
) -> None:
"""Obtain a transaction-level advisory lock from PostgreSQL.
The lock_space arg must be either None or the name of one of the members of
the LockSpaces enum (case-insensitive). Contention for a lock will only
occur when both lock_space and lock_value have the same values.
The lock_space arg must be either None or the name of one of the members of the
LockSpaces enum (case-insensitive). Contention for a lock will only occur when both
lock_space and lock_value have the same values.
"""
if lock_space:
try:
lock_space_value = LockSpaces[lock_space.upper()].value
except KeyError:
raise ValueError('Invalid lock space: %s' % lock_space)
raise ValueError("Invalid lock space: %s" % lock_space)
session.query(
func.pg_advisory_xact_lock(lock_space_value, lock_value)
).one()
session.query(func.pg_advisory_xact_lock(lock_space_value, lock_value)).one()
else:
session.query(func.pg_advisory_xact_lock(lock_value)).one()
@ -66,10 +62,11 @@ class CIText(UserDefinedType):
def get_col_spec(self, **kw: Any) -> str:
"""Return the type name (for creating columns and so on)."""
# pylint: disable=no-self-use,unused-argument
return 'CITEXT'
return "CITEXT"
def bind_processor(self, dialect: Dialect) -> Callable:
"""Return a conversion function for processing bind values."""
def process(value: Any) -> Any:
return value
@ -77,6 +74,7 @@ class CIText(UserDefinedType):
def result_processor(self, dialect: Dialect, coltype: Any) -> Callable:
"""Return a conversion function for processing result row values."""
def process(value: Any) -> Any:
return value
@ -103,8 +101,8 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors
super_rp = super().result_processor(dialect, coltype)
def handle_raw_string(value: str) -> List[str]:
if not (value.startswith('{') and value.endswith('}')):
raise ValueError('%s is not an array value' % value)
if not (value.startswith("{") and value.endswith("}")):
raise ValueError("%s is not an array value" % value)
# trim off the surrounding braces
value = value[1:-1]
@ -113,7 +111,7 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors
if not value:
return []
return value.split(',')
return value.split(",")
def process(value: Optional[str]) -> Optional[List[str]]:
if value is None:
@ -127,14 +125,14 @@ class ArrayOfLtree(ARRAY): # pylint: disable=too-many-ancestors
class comparator_factory(ARRAY.comparator_factory):
"""Add custom comparison functions.
The ancestor_of and descendant_of functions are supported by LtreeType,
so this duplicates them here so they can be used on ArrayOfLtree too.
The ancestor_of and descendant_of functions are supported by LtreeType, so this
duplicates them here so they can be used on ArrayOfLtree too.
"""
def ancestor_of(self, other): # type: ignore
"""Return whether the array contains any ancestor of `other`."""
return self.op('@>')(other)
return self.op("@>")(other)
def descendant_of(self, other): # type: ignore
"""Return whether the array contains any descendant of `other`."""
return self.op('<@')(other)
return self.op("<@")(other)

52
tildes/tildes/lib/datetime.py

@ -10,32 +10,32 @@ from ago import human
class SimpleHoursPeriod:
"""A simple class that represents a time period of hours or days."""
_SHORT_FORM_REGEX = re.compile(r'\d+[hd]', re.IGNORECASE)
_SHORT_FORM_REGEX = re.compile(r"\d+[hd]", re.IGNORECASE)
def __init__(self, hours: int) -> None:
"""Initialize a SimpleHoursPeriod from a number of hours."""
if hours <= 0:
raise ValueError('Period must be at least 1 hour.')
raise ValueError("Period must be at least 1 hour.")
self.hours = hours
try:
self.timedelta = timedelta(hours=hours)
except OverflowError:
raise ValueError('Time period is too large')
raise ValueError("Time period is too large")
@classmethod
def from_short_form(cls, short_form: str) -> 'SimpleHoursPeriod':
def from_short_form(cls, short_form: str) -> "SimpleHoursPeriod":
"""Initialize a period from a "short form" string (e.g. "2h", "4d")."""
if not cls._SHORT_FORM_REGEX.match(short_form):
raise ValueError('Invalid time period')
raise ValueError("Invalid time period")
unit = short_form[-1].lower()
count = int(short_form[:-1])
if unit == 'h':
if unit == "h":
hours = count
elif unit == 'd':
elif unit == "d":
hours = count * 24
return cls(hours=hours)
@ -43,13 +43,12 @@ class SimpleHoursPeriod:
def __str__(self) -> str:
"""Return a representation of the period as a string.
Will be of the form "4 hours", "2 days", "1 day, 6 hours", etc. except
for the special case of exactly "1 day", which is replaced with "24
hours".
Will be of the form "4 hours", "2 days", "1 day, 6 hours", etc. except for the
special case of exactly "1 day", which is replaced with "24 hours".
"""
string = human(self.timedelta, past_tense='{}')
if string == '1 day':
string = '24 hours'
string = human(self.timedelta, past_tense="{}")
if string == "1 day":
string = "24 hours"
return string
@ -63,13 +62,13 @@ class SimpleHoursPeriod:
def as_short_form(self) -> str:
"""Return a representation of the period as a "short form" string.
Uses "hours" representation unless the period is an exact multiple of
24 hours (except for 24 hours itself).
Uses "hours" representation unless the period is an exact multiple of 24 hours
(except for 24 hours itself).
"""
if self.hours % 24 == 0 and self.hours != 24:
return '{}d'.format(self.hours // 24)
return "{}d".format(self.hours // 24)
return f'{self.hours}h'
return f"{self.hours}h"
def utc_now() -> datetime:
@ -80,20 +79,19 @@ def utc_now() -> datetime:
def descriptive_timedelta(target: datetime, abbreviate: bool = False) -> str:
"""Return a descriptive string for how long ago a datetime was.
The returned string will be of a format like "4 hours ago" or
"3 hours, 21 minutes ago". The second "precision level" is only added if
it will be at least minutes, and only one "level" below the first unit.
That is, you'd never see anything like "4 hours, 5 seconds ago" or
"2 years, 3 hours ago".
The returned string will be of a format like "4 hours ago" or "3 hours, 21 minutes
ago". The second "precision level" is only added if it will be at least minutes, and
only one "level" below the first unit. That is, you'd never see anything like "4
hours, 5 seconds ago" or "2 years, 3 hours ago".
If `abbreviate` is true, the units will be shortened to return a string
like "12h 28m ago" instead of "12 hours, 28 minutes ago".
If `abbreviate` is true, the units will be shortened to return a string like
"12h 28m ago" instead of "12 hours, 28 minutes ago".
A time of less than a second returns "a moment ago".
"""
seconds_ago = (utc_now() - target).total_seconds()
if seconds_ago < 1:
return 'a moment ago'
return "a moment ago"
# determine whether one or two precision levels is appropriate
if seconds_ago < 3600:
@ -103,7 +101,7 @@ def descriptive_timedelta(target: datetime, abbreviate: bool = False) -> str:
# try a precision=2 version, and check the units it ends up with
result = human(target, precision=2)
units = ('year', 'day', 'hour', 'minute', 'second')
units = ("year", "day", "hour", "minute", "second")
unit_indices = [i for (i, unit) in enumerate(units) if unit in result]
# if there was only one unit in it, or they're adjacent, this is fine
@ -117,6 +115,6 @@ def descriptive_timedelta(target: datetime, abbreviate: bool = False) -> str:
# remove commas if abbreviating ("3d 2h ago", not "3d, 2h ago")
if abbreviate:
result = result.replace(',', '')
result = result.replace(",", "")
return result

11
tildes/tildes/lib/hash.py

@ -3,15 +3,16 @@
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
# These parameter values were chosen to achieve a hash-verification time of
# about 10ms on the current production server. They can be updated to different
# values if the server changes (consider upgrading old password hashes on login
# as well if that happens).
# These parameter values were chosen to achieve a hash-verification time of about 10ms
# on the current production server. They can be updated to different values if the
# server changes (consider upgrading old password hashes on login as well if that
# happens).
ARGON2_TIME_COST = 4
ARGON2_MEMORY_COST = 8092
ARGON2_HASHER = PasswordHasher(
time_cost=ARGON2_TIME_COST, memory_cost=ARGON2_MEMORY_COST)
time_cost=ARGON2_TIME_COST, memory_cost=ARGON2_MEMORY_COST
)
def hash_string(string: str) -> str:

19
tildes/tildes/lib/id.py

@ -4,24 +4,23 @@ import re
import string
ID36_REGEX = re.compile('^[a-z0-9]+$', re.IGNORECASE)
ID36_REGEX = re.compile("^[a-z0-9]+$", re.IGNORECASE)
def id_to_id36(id_val: int) -> str:
"""Convert an integer ID to the string ID36 representation."""
if id_val < 1:
raise ValueError('ID values should never be zero or negative')
raise ValueError("ID values should never be zero or negative")
reversed_chars = []
# the "alphabet" of our ID36s - 0-9 followed by a-z
alphabet = string.digits + string.ascii_lowercase
# Repeatedly use divmod() on the value, which returns the quotient and
# remainder of each integer division - divmod(a, b) == (a // b, a % b).
# The remainder of each division works as an index into the alphabet, and
# doing this repeatedly will build up our ID36 string in reverse order
# (with the least-significant "digit" first).
# Repeatedly use divmod() on the value, which returns the quotient and remainder of
# each integer division - divmod(a, b) == (a // b, a % b). The remainder of each
# division works as an index into the alphabet, and doing this repeatedly will build
# up our ID36 string in reverse order (with the least-significant "digit" first).
quotient, index = divmod(id_val, 36)
while quotient != 0:
reversed_chars.append(alphabet[index])
@ -29,13 +28,13 @@ def id_to_id36(id_val: int) -> str:
reversed_chars.append(alphabet[index])
# join the characters in reversed order and return as the result
return ''.join(reversed(reversed_chars))
return "".join(reversed(reversed_chars))
def id36_to_id(id36_val: str) -> int:
"""Convert a string ID36 to the integer ID representation."""
if id36_val.startswith('-') or id36_val == '0':
raise ValueError('ID values should never be zero or negative')
if id36_val.startswith("-") or id36_val == "0":
raise ValueError("ID values should never be zero or negative")
# Python's stdlib can handle this, much simpler in this direction
return int(id36_val, 36)

303
tildes/tildes/lib/markdown.py

@ -40,67 +40,66 @@ from .cmark import (
HTML_TAG_WHITELIST = (
'a',
'b',
'blockquote',
'br',
'code',
'del',
'em',
'h1',
'h2',
'h3',
'h4',
'h5',
'h6',
'hr',
'i',
'ins',
'li',
'ol',
'p',
'pre',
'strong',
'sub',
'sup',
'table',
'tbody',
'td',
'th',
'thead',
'tr',
'ul',
"a",
"b",
"blockquote",
"br",
"code",
"del",
"em",
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"hr",
"i",
"ins",
"li",
"ol",
"p",
"pre",
"strong",
"sub",
"sup",
"table",
"tbody",
"td",
"th",
"thead",
"tr",
"ul",
)
HTML_ATTRIBUTE_WHITELIST = {
'a': ['href', 'title'],
'ol': ['start'],
'td': ['align'],
'th': ['align'],
"a": ["href", "title"],
"ol": ["start"],
"td": ["align"],
"th": ["align"],
}
PROTOCOL_WHITELIST = ('http', 'https')
PROTOCOL_WHITELIST = ("http", "https")
# Regex that finds ordered list markdown that was probably accidental - ones
# being initiated by anything except "1."
# Regex that finds ordered list markdown that was probably accidental - ones being
# initiated by anything except "1."
BAD_ORDERED_LIST_REGEX = re.compile(
r'((?:\A|\n\n)' # Either the start of the entire text, or a new paragraph
r'(?!1\.)\d+)' # A number that isn't "1"
r'\.\s', # Followed by a period and a space
r"((?:\A|\n\n)" # Either the start of the entire text, or a new paragraph
r"(?!1\.)\d+)" # A number that isn't "1"
r"\.\s" # Followed by a period and a space
)
# Type alias for the "namespaced attr dict" used inside bleach.linkify
# callbacks. This looks pretty ridiculous, but it's a dict where the keys are
# namespaced attr names, like `(None, 'href')`, and there's also a `_text`
# key for getting the innerText of the <a> tag.
# Type alias for the "namespaced attr dict" used inside bleach.linkify callbacks. This
# looks pretty ridiculous, but it's a dict where the keys are namespaced attr names,
# like `(None, 'href')`, and there's also a `_text` key for getting the innerText of the
# <a> tag.
NamespacedAttrDict = Dict[Union[Tuple[Optional[str], str], str], str] # noqa
def linkify_protocol_whitelist(
attrs: NamespacedAttrDict,
new: bool = False,
attrs: NamespacedAttrDict, new: bool = False
) -> Optional[NamespacedAttrDict]:
"""bleach.linkify callback: prevent links to non-whitelisted protocols."""
# pylint: disable=unused-argument
href = attrs.get((None, 'href'))
href = attrs.get((None, "href"))
if not href:
return attrs
@ -112,13 +111,13 @@ def linkify_protocol_whitelist(
return attrs
@histogram_timer('markdown_processing')
@histogram_timer("markdown_processing")
def convert_markdown_to_safe_html(markdown: str) -> str:
"""Convert markdown to sanitized HTML."""
# apply custom pre-processing to markdown
markdown = preprocess_markdown(markdown)
markdown_bytes = markdown.encode('utf8')
markdown_bytes = markdown.encode("utf8")
parser = cmark_parser_new(CMARK_OPTS)
for name in CMARK_EXTENSIONS:
@ -134,7 +133,7 @@ def convert_markdown_to_safe_html(markdown: str) -> str:
cmark_parser_free(parser)
cmark_node_free(doc)
html = html_bytes.decode('utf8')
html = html_bytes.decode("utf8")
# apply custom post-processing to HTML
html = postprocess_markdown_html(html)
@ -148,7 +147,7 @@ def preprocess_markdown(markdown: str) -> str:
markdown = escape_accidental_ordered_lists(markdown)
# fix the "shrug" emoji ¯\_(ツ)_/¯ to prevent markdown mangling it
markdown = markdown.replace(r'¯\_(ツ)_/¯', r'¯\\\_(ツ)\_/¯')
markdown = markdown.replace(r"¯\_(ツ)_/¯", r"¯\\\_(ツ)\_/¯")
return markdown
@ -156,29 +155,26 @@ def preprocess_markdown(markdown: str) -> str:
def escape_accidental_ordered_lists(markdown: str) -> str:
"""Escape markdown that's probably an accidental ordered list.
It's a common markdown mistake to accidentally start a numbered list, by
beginning a post or paragraph with a number followed by a period. For
example, someone might try to write "1975. It was a long time ago.", and
the result will be a comment that says "1. It was a long time ago." since
that gets parsed into a numbered list.
It's a common markdown mistake to accidentally start a numbered list, by beginning a
post or paragraph with a number followed by a period. For example, someone might try
to write "1975. It was a long time ago.", and the result will be a comment that says
"1. It was a long time ago." since that gets parsed into a numbered list.
This fixes that quirk of markdown by escaping anything that would start a
numbered list except for "1. ". This will cause a few other edge cases, but
I believe they're less common/important than fixing this common error.
This fixes that quirk of markdown by escaping anything that would start a numbered
list except for "1. ". This will cause a few other edge cases, but I believe they're
less common/important than fixing this common error.
"""
return BAD_ORDERED_LIST_REGEX.sub(r'\1\\. ', markdown)
return BAD_ORDERED_LIST_REGEX.sub(r"\1\\. ", markdown)
def postprocess_markdown_html(html: str) -> str:
"""Apply post-processing to HTML generated by markdown parser."""
# list of tag names to exclude from linkification
linkify_skipped_tags = ['pre']
linkify_skipped_tags = ["pre"]
# search for text that looks like urls and convert to actual links
html = bleach.linkify(
html,
callbacks=[linkify_protocol_whitelist],
skip_tags=linkify_skipped_tags,
html, callbacks=[linkify_protocol_whitelist], skip_tags=linkify_skipped_tags
)
# run the HTML through our custom linkification process as well
@ -187,20 +183,17 @@ def postprocess_markdown_html(html: str) -> str:
return html
def apply_linkification(
html: str,
skip_tags: Optional[List[str]] = None,
) -> str:
def apply_linkification(html: str, skip_tags: Optional[List[str]] = None) -> str:
"""Apply custom linkification filter to convert text patterns to links."""
parser = HTMLParser(namespaceHTMLElements=False)
html_tree = parser.parseFragment(html)
walker_stream = html5lib.getTreeWalker('etree')(html_tree)
walker_stream = html5lib.getTreeWalker("etree")(html_tree)
filtered_html_tree = LinkifyFilter(walker_stream, skip_tags)
serializer = HTMLSerializer(
quote_attr_values='always',
quote_attr_values="always",
omit_optional_tags=False,
sanitize=False,
alphabetical_attributes=False,
@ -211,71 +204,70 @@ def apply_linkification(
class LinkifyFilter(Filter):
"""html5lib Filter to convert custom text patterns to links.
This replaces references to group paths and usernames with links to the
relevant pages.
This replaces references to group paths and usernames with links to the relevant
pages.
This implementation is based heavily on the linkify implementation from
the Bleach library.
This implementation is based heavily on the linkify implementation from the Bleach
library.
"""
# Regex that finds probable references to groups. This isn't "perfect",
# just a first pass to find likely candidates. The validity of the group
# path is checked more carefully later.
# Note: currently specifically excludes paths immediately followed by a
# tilde, but this may be possible to remove once strikethrough is
# implemented (since that's probably what they were trying to do)
GROUP_REFERENCE_REGEX = re.compile(r'(?<!\w)~([\w.]+)\b(?!~)')
# Regex that finds probable references to groups. This isn't "perfect", just a first
# pass to find likely candidates. The validity of the group path is checked more
# carefully later.
# Note: currently specifically excludes paths immediately followed by a tilde, but
# this may be possible to remove once strikethrough is implemented (since that's
# probably what they were trying to do)
GROUP_REFERENCE_REGEX = re.compile(r"(?<!\w)~([\w.]+)\b(?!~)")
# Regex that finds probable references to users. As above, this isn't
# "perfect" either but works as an initial pass with the validity of
# the username checked more carefully later.
USERNAME_REFERENCE_REGEX = re.compile(r'(?<!\w)(?:/?u/|@)([\w-]+)\b')
# Regex that finds probable references to users. As above, this isn't "perfect"
# either but works as an initial pass with the validity of the username checked more
# carefully later.
USERNAME_REFERENCE_REGEX = re.compile(r"(?<!\w)(?:/?u/|@)([\w-]+)\b")
def __init__(
self,
source: NonRecursiveTreeWalker,
skip_tags: Optional[List[str]] = None,
self, source: NonRecursiveTreeWalker, skip_tags: Optional[List[str]] = None
) -> None:
"""Initialize a linkification filter to apply to HTML.
The skip_tags argument can be a list of tag names, and the contents of
any of those tags will be excluded from linkification.
The skip_tags argument can be a list of tag names, and the contents of any of
those tags will be excluded from linkification.
"""
super().__init__(source)
self.skip_tags = skip_tags or []
# always skip the contents of <a> tags in addition to any others
self.skip_tags.append('a')
self.skip_tags.append("a")
def __iter__(self) -> Iterator[dict]:
"""Iterate over the tree, modifying it as necessary before yielding."""
inside_skipped_tags = []
for token in super().__iter__():
if (token['type'] in ('StartTag', 'EmptyTag') and
token['name'] in self.skip_tags):
# if this is the start of a tag we want to skip, add it to the
# list of skipped tags that we're currently inside
inside_skipped_tags.append(token['name'])
if (
token["type"] in ("StartTag", "EmptyTag")
and token["name"] in self.skip_tags
):
# if this is the start of a tag we want to skip, add it to the list of
# skipped tags that we're currently inside
inside_skipped_tags.append(token["name"])
elif inside_skipped_tags:
# if we're currently inside any skipped tags, the only thing we
# want to do is look for all the end tags we need to be able to
# finish skipping
if token['type'] == 'EndTag':
# if we're currently inside any skipped tags, the only thing we want to
# do is look for all the end tags we need to be able to finish skipping
if token["type"] == "EndTag":
try:
inside_skipped_tags.remove(token['name'])
inside_skipped_tags.remove(token["name"])
except ValueError:
pass
elif token['type'] == 'Characters':
# this is only reachable if inside_skipped_tags is empty, so
# this is a text token not inside a skipped tag - do the actual
# linkification replacements
# Note: doing the two replacements "iteratively" like this only
# works because they are "disjoint" and we know they're not
# competing to replace the same text. If more replacements are
# added in the future that might conflict with each other, this
# will need to be reworked somehow.
elif token["type"] == "Characters":
# this is only reachable if inside_skipped_tags is empty, so this is a
# text token not inside a skipped tag - do the actual linkification
# replacements
# Note: doing the two replacements "iteratively" like this only works
# because they are "disjoint" and we know they're not competing to
# replace the same text. If more replacements are added in the future
# that might conflict with each other, this will need to be reworked
# somehow.
replaced_tokens = self._linkify_tokens(
[token],
filter_regex=self.GROUP_REFERENCE_REGEX,
@ -287,50 +279,50 @@ class LinkifyFilter(Filter):
linkify_function=self._tokenize_username_match,
)
# yield all the tokens returned from the replacement process
# (will be just the original token if nothing was replaced)
# yield all the tokens returned from the replacement process (will be
# just the original token if nothing was replaced)
for new_token in replaced_tokens:
yield new_token
# we either yielded new tokens or the original one already, so
# we don't want to fall through and yield the original again
# we either yielded new tokens or the original one already, so we don't
# want to fall through and yield the original again
continue
yield token
@staticmethod
def _linkify_tokens(
tokens: List[dict],
filter_regex: Pattern,
linkify_function: Callable,
tokens: List[dict], filter_regex: Pattern, linkify_function: Callable
) -> List[dict]:
"""Check tokens for text that matches a regex and linkify it.
The `filter_regex` argument should be a compiled pattern that will be
applied to the text in all of the supplied tokens. If any matches are
found, they will each be used to call `linkify_function`, which will
validate the match and convert it back into tokens (representing an <a>
tag if it is valid for linkifying, or just text if not).
The `filter_regex` argument should be a compiled pattern that will be applied to
the text in all of the supplied tokens. If any matches are found, they will each
be used to call `linkify_function`, which will validate the match and convert it
back into tokens (representing an <a> tag if it is valid for linkifying, or just
text if not).
"""
new_tokens = []
for token in tokens:
# we don't want to touch any tokens other than character ones
if token['type'] != 'Characters':
if token["type"] != "Characters":
new_tokens.append(token)
continue
original_text = token['data']
original_text = token["data"]
current_index = 0
for match in filter_regex.finditer(original_text):
# if there were some characters between the previous match and
# this one, add a token containing those first
# if there were some characters between the previous match and this one,
# add a token containing those first
if match.start() > current_index:
new_tokens.append({
'type': 'Characters',
'data': original_text[current_index:match.start()],
})
new_tokens.append(
{
"type": "Characters",
"data": original_text[current_index : match.start()],
}
)
# call the linkify function to convert this match into tokens
linkified_tokens = linkify_function(match)
@ -339,43 +331,42 @@ class LinkifyFilter(Filter):
# move the progress marker up to the end of this match
current_index = match.end()
# if there's still some text left over, add one more token for it
# (this will be the entire thing if there weren't any matches)
# if there's still some text left over, add one more token for it (this will
# be the entire thing if there weren't any matches)
if current_index < len(original_text):
new_tokens.append({
'type': 'Characters',
'data': original_text[current_index:],
})
new_tokens.append(
{"type": "Characters", "data": original_text[current_index:]}
)
return new_tokens
@staticmethod
def _tokenize_group_match(match: Match) -> List[dict]:
"""Convert a potential group reference into HTML tokens."""
# convert the potential group path to lowercase to allow people to use
# incorrect casing but still have it link properly
# convert the potential group path to lowercase to allow people to use incorrect
# casing but still have it link properly
group_path = match[1].lower()
# Even though they're technically valid paths, we don't want to linkify
# things like "~10" or "~4.5" since that's just going to be someone
# using it in the "approximately" sense. So if the path consists of
# only numbers and/or periods, we won't linkify it
is_numeric = all(char in '0123456789.' for char in group_path)
# Even though they're technically valid paths, we don't want to linkify things
# like "~10" or "~4.5" since that's just going to be someone using it in the
# "approximately" sense. So if the path consists of only numbers and/or periods,
# we won't linkify it
is_numeric = all(char in "0123456789." for char in group_path)
# if it's a valid group path and not totally numeric, convert to <a>
if is_valid_group_path(group_path) and not is_numeric:
return [
{
'type': 'StartTag',
'name': 'a',
'data': {(None, 'href'): f'/~{group_path}'},
"type": "StartTag",
"name": "a",
"data": {(None, "href"): f"/~{group_path}"},
},
{'type': 'Characters', 'data': match[0]},
{'type': 'EndTag', 'name': 'a'},
{"type": "Characters", "data": match[0]},
{"type": "EndTag", "name": "a"},
]
# one of the checks failed, so just keep it as the original text
return [{'type': 'Characters', 'data': match[0]}]
return [{"type": "Characters", "data": match[0]}]
@staticmethod
def _tokenize_username_match(match: Match) -> List[dict]:
@ -384,16 +375,16 @@ class LinkifyFilter(Filter):
if is_valid_username(match[1]):
return [
{
'type': 'StartTag',
'name': 'a',
'data': {(None, 'href'): f'/user/{match[1]}'},
"type": "StartTag",
"name": "a",
"data": {(None, "href"): f"/user/{match[1]}"},
},
{'type': 'Characters', 'data': match[0]},
{'type': 'EndTag', 'name': 'a'},
{"type": "Characters", "data": match[0]},
{"type": "EndTag", "name": "a"},
]
# the username wasn't valid, so just keep it as the original text
return [{'type': 'Characters', 'data': match[0]}]
return [{"type": "Characters", "data": match[0]}]
def sanitize_html(html: str) -> str:

2
tildes/tildes/lib/message.py

@ -1,6 +1,6 @@
"""Functions/constants related to messages."""
WELCOME_MESSAGE_SUBJECT = 'Welcome to the Tildes alpha'
WELCOME_MESSAGE_SUBJECT = "Welcome to the Tildes alpha"
# pylama:ignore=E501
WELCOME_MESSAGE_TEXT = """

15
tildes/tildes/lib/password.py

@ -5,22 +5,23 @@ from hashlib import sha1
from redis import ConnectionError, ResponseError, StrictRedis # noqa
# unix socket path for redis server with the breached passwords bloom filter
BREACHED_PASSWORDS_REDIS_SOCKET = '/run/redis_breached_passwords/socket'
BREACHED_PASSWORDS_REDIS_SOCKET = "/run/redis_breached_passwords/socket"
# Key where the bloom filter of password hashes from data breaches is stored
BREACHED_PASSWORDS_BF_KEY = 'breached_passwords_bloom'
BREACHED_PASSWORDS_BF_KEY = "breached_passwords_bloom"
def is_breached_password(password: str) -> bool:
"""Return whether the password is in the breached-passwords list."""
redis = StrictRedis(unix_socket_path=BREACHED_PASSWORDS_REDIS_SOCKET)
hashed = sha1(password.encode('utf-8')).hexdigest()
hashed = sha1(password.encode("utf-8")).hexdigest()
try:
return bool(redis.execute_command(
'BF.EXISTS', BREACHED_PASSWORDS_BF_KEY, hashed))
return bool(
redis.execute_command("BF.EXISTS", BREACHED_PASSWORDS_BF_KEY, hashed)
)
except (ConnectionError, ResponseError):
# server isn't running, bloom filter doesn't exist or the key is a
# different data type
# server isn't running, bloom filter doesn't exist or the key is a different
# data type
return False

144
tildes/tildes/lib/ratelimit.py

@ -19,24 +19,22 @@ class RateLimitError(Exception):
class RateLimitResult:
"""The result from a rate-limit check.
Includes data relating to whether the action should be allowed or blocked,
how much of the limit is remaining, how long until the action can be
retried, etc.
Includes data relating to whether the action should be allowed or blocked, how much
of the limit is remaining, how long until the action can be retried, etc.
"""
def __init__(
self,
is_allowed: bool,
total_limit: int,
remaining_limit: int,
time_until_max: timedelta,
time_until_retry: Optional[timedelta] = None,
self,
is_allowed: bool,
total_limit: int,
remaining_limit: int,
time_until_max: timedelta,
time_until_retry: Optional[timedelta] = None,
) -> None:
"""Initialize a RateLimitResult."""
# pylint: disable=too-many-arguments
if is_allowed and time_until_retry is not None:
raise ValueError(
'time_until_retry must be None if is_allowed is True')
raise ValueError("time_until_retry must be None if is_allowed is True")
self.is_allowed = is_allowed
self.total_limit = total_limit
@ -58,7 +56,7 @@ class RateLimitResult:
)
@classmethod
def unlimited_result(cls) -> 'RateLimitResult':
def unlimited_result(cls) -> "RateLimitResult":
"""Return a "blank" result representing an unlimited action."""
return cls(
is_allowed=True,
@ -68,7 +66,7 @@ class RateLimitResult:
)
@classmethod
def from_redis_cell_result(cls, result: List[int]) -> 'RateLimitResult':
def from_redis_cell_result(cls, result: List[int]) -> "RateLimitResult":
"""Convert the response from CL.THROTTLE command to a RateLimitResult.
CL.THROTTLE responds with an array of 5 integers:
@ -98,34 +96,31 @@ class RateLimitResult:
)
@classmethod
def merged_result(
cls,
results: Sequence['RateLimitResult'],
) -> 'RateLimitResult':
def merged_result(cls, results: Sequence["RateLimitResult"]) -> "RateLimitResult":
"""Merge any number of RateLimitResults into a single result.
Basically, the merged result should be the "most restrictive"
combination of all the source results. That is, it should only allow
the action if *all* of the source results would allow it, the limit
counts should be the lowest of the set, and the waiting times should
be the highest of the set.
Basically, the merged result should be the "most restrictive" combination of all
the source results. That is, it should only allow the action if *all* of the
source results would allow it, the limit counts should be the lowest of the set,
and the waiting times should be the highest of the set.
Note: I think the behavior for time_until_max is not truly correct, but
it should be reasonable for now. Consider a situation like two
"overlapping" limits of 10/min and 100/hour and what the time_until_max
value of the combination should be. It might be a bit tricky.
Note: I think the behavior for time_until_max is not truly correct, but it
should be reasonable for now. Consider a situation like two "overlapping" limits
of 10/min and 100/hour and what the time_until_max value of the combination
should be. It might be a bit tricky.
"""
# if there's only one result, just return that one
if len(results) == 1:
return results[0]
# time_until_retry is a bit trickier than the others because some/all
# of the source values might be None
# time_until_retry is a bit trickier than the others because some/all of the
# source values might be None
if all(r.time_until_retry is None for r in results):
time_until_retry = None
else:
time_until_retry = max(
r.time_until_retry for r in results if r.time_until_retry)
r.time_until_retry for r in results if r.time_until_retry
)
return cls(
is_allowed=all(r.is_allowed for r in results),
@ -140,18 +135,18 @@ class RateLimitResult:
# Retry-After: seconds the client should wait until retrying
if self.time_until_retry:
retry_seconds = int(self.time_until_retry.total_seconds())
response.headers['Retry-After'] = str(retry_seconds)
response.headers["Retry-After"] = str(retry_seconds)
# X-RateLimit-Limit: the total action limit (including used)
response.headers['X-RateLimit-Limit'] = str(self.total_limit)
response.headers["X-RateLimit-Limit"] = str(self.total_limit)
# X-RateLimit-Remaining: remaining actions before client hits the limit
response.headers['X-RateLimit-Remaining'] = str(self.remaining_limit)
response.headers["X-RateLimit-Remaining"] = str(self.remaining_limit)
# X-RateLimit-Reset: epoch timestamp when limit will be back to full
reset_time = utc_now() + self.time_until_max
reset_timestamp = int(reset_time.timestamp())
response.headers['X-RateLimit-Reset'] = str(reset_timestamp)
response.headers["X-RateLimit-Reset"] = str(reset_timestamp)
return response
@ -159,38 +154,37 @@ class RateLimitResult:
class RateLimitedAction:
"""Represents a particular action and the limits on its usage.
This class uses the redis-cell Redis module to implement a Generic Cell
Rate Algorithm (GCRA) for rate-limiting, which includes several desirable
characteristics including a rolling time window and support for "bursts".
This class uses the redis-cell Redis module to implement a Generic Cell Rate
Algorithm (GCRA) for rate-limiting, which includes several desirable characteristics
including a rolling time window and support for "bursts".
"""
def __init__(
self,
name: str,
period: timedelta,
limit: int,
max_burst: Optional[int] = None,
by_user: bool = True,
by_ip: bool = True,
redis: Optional[StrictRedis] = None,
self,
name: str,
period: timedelta,
limit: int,
max_burst: Optional[int] = None,
by_user: bool = True,
by_ip: bool = True,
redis: Optional[StrictRedis] = None,
) -> None:
"""Initialize the limits on a particular action.
The action will be limited to a maximum of `limit` calls over the time
period specified in `period`. By default, up to half of the actions
inside a period may be used in a "burst", in which no specific time
restrictions are applied. This behavior is controlled by the
`max_burst` argument, which can range from 1 (no burst allowed,
requests must wait at least `period / limit` time between them), up to
the same value as `limit` (the full limit may be used at any rate, but
The action will be limited to a maximum of `limit` calls over the time period
specified in `period`. By default, up to half of the actions inside a period may
be used in a "burst", in which no specific time restrictions are applied. This
behavior is controlled by the `max_burst` argument, which can range from 1 (no
burst allowed, requests must wait at least `period / limit` time between them),
up to the same value as `limit` (the full limit may be used at any rate, but
never more than `limit` inside any given period).
"""
# pylint: disable=too-many-arguments
if max_burst and not 1 <= max_burst <= limit:
raise ValueError('max_burst must be at least 1 and <= limit')
raise ValueError("max_burst must be at least 1 and <= limit")
if not (by_user or by_ip):
raise ValueError('At least one of by_user or by_ip must be True')
raise ValueError("At least one of by_user or by_ip must be True")
self.name = name
self.period = period
@ -205,15 +199,15 @@ class RateLimitedAction:
self.by_user = by_user
self.by_ip = by_ip
# if a redis connection wasn't specified, it will need to be
# initialized before any checks or resets for this action can be done
# if a redis connection wasn't specified, it will need to be initialized before
# any checks or resets for this action can be done
self._redis = redis
@property
def redis(self) -> StrictRedis:
"""Return the redis connection."""
if not self._redis:
raise RateLimitError('No redis connection set')
raise RateLimitError("No redis connection set")
return self._redis
@ -224,19 +218,14 @@ class RateLimitedAction:
def _build_redis_key(self, by_type: str, value: Any) -> str:
"""Build the Redis key where this rate limit is maintained."""
parts = [
'ratelimit',
self.name,
by_type,
str(value),
]
parts = ["ratelimit", self.name, by_type, str(value)]
return ':'.join(parts)
return ":".join(parts)
def _call_redis_command(self, key: str) -> List[int]:
"""Call the redis-cell CL.THROTTLE command for this action."""
return self.redis.execute_command(
'CL.THROTTLE',
"CL.THROTTLE",
key,
self.max_burst - 1,
self.limit,
@ -246,10 +235,9 @@ class RateLimitedAction:
def check_for_user_id(self, user_id: int) -> RateLimitResult:
"""Check whether a particular user_id can perform this action."""
if not self.by_user:
raise RateLimitError(
'check_for_user_id called on non-user-limited action')
raise RateLimitError("check_for_user_id called on non-user-limited action")
key = self._build_redis_key('user', user_id)
key = self._build_redis_key("user", user_id)
result = self._call_redis_command(key)
return RateLimitResult.from_redis_cell_result(result)
@ -257,22 +245,20 @@ class RateLimitedAction:
def reset_for_user_id(self, user_id: int) -> None:
"""Reset the ratelimit on this action for a particular user_id."""
if not self.by_user:
raise RateLimitError(
'reset_for_user_id called on non-user-limited action')
raise RateLimitError("reset_for_user_id called on non-user-limited action")
key = self._build_redis_key('user', user_id)
key = self._build_redis_key("user", user_id)
self.redis.delete(key)
def check_for_ip(self, ip_str: str) -> RateLimitResult:
"""Check whether a particular IP can perform this action."""
if not self.by_ip:
raise RateLimitError(
'check_for_ip called on non-IP-limited action')
raise RateLimitError("check_for_ip called on non-IP-limited action")
# check if ip_str is a valid address, will ValueError if not
ip_address(ip_str)
key = self._build_redis_key('ip', ip_str)
key = self._build_redis_key("ip", ip_str)
result = self._call_redis_command(key)
return RateLimitResult.from_redis_cell_result(result)
@ -280,23 +266,21 @@ class RateLimitedAction:
def reset_for_ip(self, ip_str: str) -> None:
"""Reset the ratelimit on this action for a particular IP."""
if not self.by_ip:
raise RateLimitError(
'reset_for_ip called on non-user-limited action')
raise RateLimitError("reset_for_ip called on non-user-limited action")
# check if ip_str is a valid address, will ValueError if not
ip_address(ip_str)
key = self._build_redis_key('ip', ip_str)
key = self._build_redis_key("ip", ip_str)
self.redis.delete(key)
# the actual list of actions with rate-limit restrictions
# each action must have a unique name to prevent key collisions
_RATE_LIMITED_ACTIONS = (
RateLimitedAction('login', timedelta(hours=1), 20),
RateLimitedAction('register', timedelta(hours=1), 50),
RateLimitedAction("login", timedelta(hours=1), 20),
RateLimitedAction("register", timedelta(hours=1), 50),
)
# (public) dict to be able to look up the actions by name
RATE_LIMITED_ACTIONS = {
action.name: action for action in _RATE_LIMITED_ACTIONS}
RATE_LIMITED_ACTIONS = {action.name: action for action in _RATE_LIMITED_ACTIONS}

115
tildes/tildes/lib/string.py

@ -20,29 +20,29 @@ def convert_to_url_slug(original: str, max_length: int = 100) -> str:
slug = original.lower()
# remove apostrophes so contractions don't get broken up by underscores
slug = re.sub("['’]", '', slug)
slug = re.sub("['’]", "", slug)
# replace all remaining non-word characters with underscores
slug = re.sub(r'\W+', '_', slug)
slug = re.sub(r"\W+", "_", slug)
# remove any consecutive underscores
slug = re.sub('_{2,}', '_', slug)
slug = re.sub("_{2,}", "_", slug)
# remove "hanging" underscores on the start and/or end
slug = slug.strip('_')
slug = slug.strip("_")
# url-encode the slug
encoded_slug = quote(slug)
# if the slug's already short enough, just return without worrying about
# how it will need to be truncated
# if the slug's already short enough, just return without worrying about how it will
# need to be truncated
if len(encoded_slug) <= max_length:
return encoded_slug
# Truncating a url-encoded slug can be tricky if there are any multi-byte
# unicode characters, since the %-encoded forms of them can be quite long.
# Check to see if the slug looks like it might contain any of those.
maybe_multi_bytes = bool(re.search('%..%', encoded_slug))
# Truncating a url-encoded slug can be tricky if there are any multi-byte unicode
# characters, since the %-encoded forms of them can be quite long. Check to see if
# the slug looks like it might contain any of those.
maybe_multi_bytes = bool(re.search("%..%", encoded_slug))
# if that matched, we need to take a more complicated approach
if maybe_multi_bytes:
@ -50,19 +50,15 @@ def convert_to_url_slug(original: str, max_length: int = 100) -> str:
# simple truncate - break at underscore if possible, no overflow string
return truncate_string(
encoded_slug,
max_length,
truncate_at_chars='_',
overflow_str=None,
encoded_slug, max_length, truncate_at_chars="_", overflow_str=None
)
def _truncate_multibyte_slug(original: str, max_length: int) -> str:
"""URL-encode and truncate a slug known to contain multi-byte chars."""
# instead of the normal method of truncating "backwards" from the end of
# the string, build it up one encoded character at a time from the start
# until it's too long
encoded_slug = ''
# instead of the normal method of truncating "backwards" from the end of the string,
# build it up one encoded character at a time from the start until it's too long
encoded_slug = ""
for character in original:
encoded_char = quote(character)
@ -72,17 +68,16 @@ def _truncate_multibyte_slug(original: str, max_length: int) -> str:
encoded_slug += encoded_char
# Now we know that the string is made up of "whole" characters and is close
# to the maximum length. We'd still like to truncate it at an underscore if
# possible, but some languages like Japanese and Chinese won't have many
# (or any) underscores in the slug, and we could end up losing a lot of the
# characters. So try breaking it at an underscore, but if it means more
# than 30% of the slug gets cut off, just leave it alone. This means that
# some url slugs in other languages will end in partial words, but
# determining the word edges is not simple.
# Now we know that the string is made up of "whole" characters and is close to the
# maximum length. We'd still like to truncate it at an underscore if possible, but
# some languages like Japanese and Chinese won't have many (or any) underscores in
# the slug, and we could end up losing a lot of the characters. So try breaking it
# at an underscore, but if it means more than 30% of the slug gets cut off, just
# leave it alone. This means that some url slugs in other languages will end in
# partial words, but determining the word edges is not simple.
acceptable_truncation = 0.7
truncated_slug = truncate_string_at_char(encoded_slug, '_')
truncated_slug = truncate_string_at_char(encoded_slug, "_")
if len(truncated_slug) / len(encoded_slug) >= acceptable_truncation:
return truncated_slug
@ -91,33 +86,31 @@ def _truncate_multibyte_slug(original: str, max_length: int) -> str:
def truncate_string(
original: str,
length: int,
truncate_at_chars: Optional[str] = None,
overflow_str: Optional[str] = '...',
original: str,
length: int,
truncate_at_chars: Optional[str] = None,
overflow_str: Optional[str] = "...",
) -> str:
"""Truncate a string to be no longer than a specified length.
If `truncate_at_chars` is specified (as a string, one or more characters),
the truncation will happen at the last occurrence of any of those chars
inside the remaining string after it has been initially cut down to the
desired length.
If `truncate_at_chars` is specified (as a string, one or more characters), the
truncation will happen at the last occurrence of any of those chars inside the
remaining string after it has been initially cut down to the desired length.
`overflow_str` will be appended to the result, and its length is included
in the final string length. So for example, if `overflow_str` has a length
of 3 and the target length is 10, at most 7 characters of the original
string will be kept.
`overflow_str` will be appended to the result, and its length is included in the
final string length. So for example, if `overflow_str` has a length of 3 and the
target length is 10, at most 7 characters of the original string will be kept.
"""
if overflow_str is None:
overflow_str = ''
overflow_str = ""
# no need to do anything if the string is already short enough
if len(original) <= length:
return original
# cut the string down to the max desired length (leaving space for the
# overflow string if one is specified)
truncated = original[:length - len(overflow_str)]
# cut the string down to the max desired length (leaving space for the overflow
# string if one is specified)
truncated = original[: length - len(overflow_str)]
# if we don't want to truncate at particular characters, we're done
if not truncate_at_chars:
@ -132,17 +125,17 @@ def truncate_string(
def truncate_string_at_char(original: str, valid_chars: str) -> str:
"""Truncate a string at the last occurrence of a particular character.
Supports passing multiple valid characters (as a string) for `valid_chars`,
for example valid_chars='.?!' would truncate at the "right-most" occurrence
of any of those 3 characters in the string.
Supports passing multiple valid characters (as a string) for `valid_chars`, for
example valid_chars='.?!' would truncate at the "right-most" occurrence of any of
those 3 characters in the string.
"""
# work backwards through the string until we find one of the chars we want
for num_from_end, char in enumerate(reversed(original), start=1):
if char in valid_chars:
break
else:
# the loop didn't break, so we looked through the entire string and
# didn't find any of the desired characters - can't do anything
# the loop didn't break, so we looked through the entire string and didn't find
# any of the desired characters - can't do anything
return original
# a truncation char was found, so -num_from_end is the position to stop at
@ -153,21 +146,21 @@ def truncate_string_at_char(original: str, valid_chars: str) -> str:
def simplify_string(original: str) -> str:
"""Sanitize a string for usage in places where we need a "simple" one.
This function is useful for sanitizing strings so that they're suitable to
be used in places like topic titles, message subjects, and so on.
This function is useful for sanitizing strings so that they're suitable to be used
in places like topic titles, message subjects, and so on.
Strings processed by this function:
* have unicode chars from the "separator" category replaced with spaces
* have unicode chars from the "other" category stripped out, except for
newlines, which are replaced with spaces
* have unicode chars from the "other" category stripped out, except for newlines,
which are replaced with spaces
* have leading and trailing whitespace removed
* have multiple consecutive spaces collapsed into a single space
"""
simplified = _sanitize_characters(original)
# replace consecutive spaces with a single space
simplified = re.sub(r'\s{2,}', ' ', simplified)
simplified = re.sub(r"\s{2,}", " ", simplified)
# remove any remaining leading/trailing whitespace
simplified = simplified.strip()
@ -182,16 +175,16 @@ def _sanitize_characters(original: str) -> str:
for char in original:
category = unicodedata.category(char)
if category.startswith('Z'):
if category.startswith("Z"):
# "separator" chars - replace with a normal space
final_characters.append(' ')
elif category.startswith('C'):
# "other" chars (control, formatting, etc.) - filter them out
# except for newlines, which are replaced with normal spaces
if char == '\n':
final_characters.append(' ')
final_characters.append(" ")
elif category.startswith("C"):
# "other" chars (control, formatting, etc.) - filter them out except for
# newlines, which are replaced with normal spaces
if char == "\n":
final_characters.append(" ")
else:
# any other type of character, just keep it
final_characters.append(char)
return ''.join(final_characters)
return "".join(final_characters)

4
tildes/tildes/lib/url.py

@ -8,9 +8,9 @@ def get_domain_from_url(url: str, strip_www: bool = True) -> str:
domain = urlparse(url).netloc
if not domain:
raise ValueError('Invalid url or domain could not be determined')
raise ValueError("Invalid url or domain could not be determined")
if strip_www and domain.startswith('www.'):
if strip_www and domain.startswith("www."):
domain = domain[4:]
return domain

64
tildes/tildes/metrics.py

@ -1,7 +1,7 @@
"""Contains Prometheus metric objects and functions for instrumentation."""
# The prometheus_client classes work in a pretty crazy way, need to disable
# these pylint checks to avoid errors
# The prometheus_client classes work in a pretty crazy way, need to disable these pylint
# checks to avoid errors
# pylint: disable=no-value-for-parameter,redundant-keyword-arg
from typing import Callable
@ -11,50 +11,32 @@ from prometheus_client.core import _LabelWrapper
_COUNTERS = {
'votes': Counter(
'tildes_votes_total',
'Votes',
labelnames=['target_type'],
),
'comments': Counter('tildes_comments_total', 'Comments'),
'invite_code_failures': Counter(
'tildes_invite_code_failures_total',
'Invite Code Failures',
),
'logins': Counter('tildes_logins_total', 'Login Attempts'),
'login_failures': Counter(
'tildes_login_failures_total',
'Login Failures',
),
'messages': Counter(
'tildes_messages_total',
'Messages',
labelnames=['type'],
),
'registrations': Counter(
'tildes_registrations_total',
'User Registrations',
),
'topics': Counter('tildes_topics_total', 'Topics', labelnames=['type']),
'subscriptions': Counter('tildes_subscriptions_total', 'Subscriptions'),
'unsubscriptions': Counter(
'tildes_unsubscriptions_total',
'Unsubscriptions',
"votes": Counter("tildes_votes_total", "Votes", labelnames=["target_type"]),
"comments": Counter("tildes_comments_total", "Comments"),
"invite_code_failures": Counter(
"tildes_invite_code_failures_total", "Invite Code Failures"
),
"logins": Counter("tildes_logins_total", "Login Attempts"),
"login_failures": Counter("tildes_login_failures_total", "Login Failures"),
"messages": Counter("tildes_messages_total", "Messages", labelnames=["type"]),
"registrations": Counter("tildes_registrations_total", "User Registrations"),
"topics": Counter("tildes_topics_total", "Topics", labelnames=["type"]),
"subscriptions": Counter("tildes_subscriptions_total", "Subscriptions"),
"unsubscriptions": Counter("tildes_unsubscriptions_total", "Unsubscriptions"),
}
_HISTOGRAMS = {
'markdown_processing': Histogram(
'tildes_markdown_processing_seconds',
'Markdown processing',
"markdown_processing": Histogram(
"tildes_markdown_processing_seconds",
"Markdown processing",
buckets=[.001, .0025, .005, .01, 0.025, .05, .1, .5, 1.0],
),
'comment_tree_sorting': Histogram(
'tildes_comment_tree_sorting_seconds',
'Comment tree sorting time',
labelnames=['num_comments_range', 'order'],
"comment_tree_sorting": Histogram(
"tildes_comment_tree_sorting_seconds",
"Comment tree sorting time",
labelnames=["num_comments_range", "order"],
buckets=[.00001, .0001, .001, .01, .05, .1, .5, 1.0],
)
),
}
@ -63,7 +45,7 @@ def incr_counter(name: str, amount: int = 1, **labels: str) -> None:
try:
counter = _COUNTERS[name]
except KeyError:
raise ValueError('Invalid counter name')
raise ValueError("Invalid counter name")
if isinstance(counter, _LabelWrapper):
counter = counter.labels(**labels)
@ -76,7 +58,7 @@ def get_histogram(name: str, **labels: str) -> Histogram:
try:
hist = _HISTOGRAMS[name]
except KeyError:
raise ValueError('Invalid histogram name')
raise ValueError("Invalid histogram name")
if isinstance(hist, _LabelWrapper):
hist = hist.labels(**labels)

125
tildes/tildes/models/comment/comment.py

@ -5,14 +5,7 @@ from datetime import datetime, timedelta
from typing import Any, Optional, Sequence, Tuple
from pyramid.security import Allow, Authenticated, Deny, DENY_ALL, Everyone
from sqlalchemy import (
Boolean,
Column,
ForeignKey,
Integer,
Text,
TIMESTAMP,
)
from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text, TIMESTAMP
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import deferred, relationship
from sqlalchemy.sql.expression import text
@ -36,23 +29,23 @@ class Comment(DatabaseModel):
Trigger behavior:
Incoming:
- num_votes will be incremented and decremented by insertions and
deletions in comment_votes.
- num_votes will be incremented and decremented by insertions and deletions in
comment_votes.
Outgoing:
- Inserting or deleting rows, or updating is_deleted/is_removed to
change visibility will increment or decrement num_comments
accordingly on the relevant topic.
- Inserting a row will increment num_comments on any topic_visit rows
for the comment's author and the relevant topic.
- Inserting a new comment or updating is_deleted or is_removed will
update last_activity_time on the relevant topic.
- Inserting or deleting rows, or updating is_deleted/is_removed to change
visibility will increment or decrement num_comments accordingly on the
relevant topic.
- Inserting a row will increment num_comments on any topic_visit rows for the
comment's author and the relevant topic.
- Inserting a new comment or updating is_deleted or is_removed will update
last_activity_time on the relevant topic.
- Setting is_deleted or is_removed to true will delete any rows in
comment_notifications related to the comment.
- Changing is_deleted or is_removed will adjust num_comments on all
topic_visit rows for the relevant topic, where the visit_time was
after the time the comment was originally posted.
- Inserting a row or updating markdown will send a rabbitmq message
for "comment.created" or "comment.edited" respectively.
- Changing is_deleted or is_removed will adjust num_comments on all topic_visit
rows for the relevant topic, where the visit_time was after the time the
comment was originally posted.
- Inserting a row or updating markdown will send a rabbitmq message for
"comment.created" or "comment.edited" respectively.
Internal:
- deleted_time will be set or unset when is_deleted is changed
- removed_time will be set or unset when is_removed is changed
@ -60,47 +53,40 @@ class Comment(DatabaseModel):
schema_class = CommentSchema
__tablename__ = 'comments'
__tablename__ = "comments"
comment_id: int = Column(Integer, primary_key=True)
topic_id: int = Column(
Integer,
ForeignKey('topics.topic_id'),
nullable=False,
index=True,
Integer, ForeignKey("topics.topic_id"), nullable=False, index=True
)
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
parent_comment_id: Optional[int] = Column(
Integer,
ForeignKey('comments.comment_id'),
index=True,
Integer, ForeignKey("comments.comment_id"), index=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
is_deleted: bool = Column(
Boolean, nullable=False, server_default='false', index=True)
Boolean, nullable=False, server_default="false", index=True
)
deleted_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
is_removed: bool = Column(Boolean, nullable=False, server_default='false')
is_removed: bool = Column(Boolean, nullable=False, server_default="false")
removed_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
last_edited_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
_markdown: str = deferred(Column('markdown', Text, nullable=False))
_markdown: str = deferred(Column("markdown", Text, nullable=False))
rendered_html: str = Column(Text, nullable=False)
num_votes: int = Column(
Integer, nullable=False, server_default='0', index=True)
num_votes: int = Column(Integer, nullable=False, server_default="0", index=True)
user: User = relationship('User', lazy=False, innerjoin=True)
topic: Topic = relationship('Topic', innerjoin=True)
parent_comment: Optional['Comment'] = relationship(
'Comment', uselist=False, remote_side=[comment_id])
user: User = relationship("User", lazy=False, innerjoin=True)
topic: Topic = relationship("Topic", innerjoin=True)
parent_comment: Optional["Comment"] = relationship(
"Comment", uselist=False, remote_side=[comment_id]
)
@hybrid_property
def markdown(self) -> str:
@ -117,20 +103,19 @@ class Comment(DatabaseModel):
self._markdown = new_markdown
self.rendered_html = convert_markdown_to_safe_html(new_markdown)
if (self.created_time and
utc_now() - self.created_time > EDIT_GRACE_PERIOD):
if self.created_time and utc_now() - self.created_time > EDIT_GRACE_PERIOD:
self.last_edited_time = utc_now()
def __repr__(self) -> str:
"""Display the comment's ID as its repr format."""
return f'<Comment ({self.comment_id})>'
return f"<Comment ({self.comment_id})>"
def __init__(
self,
topic: Topic,
author: User,
markdown: str,
parent_comment: Optional['Comment'] = None,
self,
topic: Topic,
author: User,
markdown: str,
parent_comment: Optional["Comment"] = None,
) -> None:
"""Create a new comment."""
self.topic = topic
@ -142,7 +127,7 @@ class Comment(DatabaseModel):
self.markdown = markdown
incr_counter('comments')
incr_counter("comments")
def __acl__(self) -> Sequence[Tuple[str, Any, str]]:
"""Pyramid security ACL."""
@ -156,49 +141,49 @@ class Comment(DatabaseModel):
# - removed comments can only be viewed by admins and the author
# - otherwise, everyone can view
if self.is_removed:
acl.append((Allow, 'admin', 'view'))
acl.append((Allow, self.user_id, 'view'))
acl.append((Deny, Everyone, 'view'))
acl.append((Allow, "admin", "view"))
acl.append((Allow, self.user_id, "view"))
acl.append((Deny, Everyone, "view"))
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# vote:
# - removed comments can't be voted on by anyone
# - otherwise, logged-in users except the author can vote
if self.is_removed:
acl.append((Deny, Everyone, 'vote'))
acl.append((Deny, Everyone, "vote"))
acl.append((Deny, self.user_id, 'vote'))
acl.append((Allow, Authenticated, 'vote'))
acl.append((Deny, self.user_id, "vote"))
acl.append((Allow, Authenticated, "vote"))
# tag:
# - temporary: nobody can tag comments
acl.append((Deny, Everyone, 'tag'))
acl.append((Deny, Everyone, "tag"))
# reply:
# - removed comments can't be replied to by anyone
# - if the topic is locked, only admins can reply
# - otherwise, logged-in users can reply
if self.is_removed:
acl.append((Deny, Everyone, 'reply'))
acl.append((Deny, Everyone, "reply"))
if self.topic.is_locked:
acl.append((Allow, 'admin', 'reply'))
acl.append((Deny, Everyone, 'reply'))
acl.append((Allow, "admin", "reply"))
acl.append((Deny, Everyone, "reply"))
acl.append((Allow, Authenticated, 'reply'))
acl.append((Allow, Authenticated, "reply"))
# edit:
# - only the author can edit
acl.append((Allow, self.user_id, 'edit'))
acl.append((Allow, self.user_id, "edit"))
# delete:
# - only the author can delete
acl.append((Allow, self.user_id, 'delete'))
acl.append((Allow, self.user_id, "delete"))
# mark_read:
# - logged-in users can mark comments read
acl.append((Allow, Authenticated, 'mark_read'))
acl.append((Allow, Authenticated, "mark_read"))
acl.append(DENY_ALL)
@ -220,7 +205,7 @@ class Comment(DatabaseModel):
@property
def permalink(self) -> str:
"""Return the permalink for this comment."""
return f'{self.topic.permalink}#comment-{self.comment_id36}'
return f"{self.topic.permalink}#comment-{self.comment_id36}"
@property
def parent_comment_permalink(self) -> str:
@ -228,7 +213,7 @@ class Comment(DatabaseModel):
if not self.parent_comment_id:
raise AttributeError
return f'{self.topic.permalink}#comment-{self.parent_comment_id36}'
return f"{self.topic.permalink}#comment-{self.parent_comment_id36}"
@property
def tag_counts(self) -> Counter:

94
tildes/tildes/models/comment/comment_notification.py

@ -22,45 +22,33 @@ class CommentNotification(DatabaseModel):
Trigger behavior:
Incoming:
- Rows will be deleted if the relevant comment has is_deleted set to
true.
- Rows will be deleted if the relevant comment has is_deleted set to true.
Outgoing:
- Inserting, deleting, or updating is_unread will increment or
decrement num_unread_notifications for the relevant user.
- Inserting, deleting, or updating is_unread will increment or decrement
num_unread_notifications for the relevant user.
"""
__tablename__ = 'comment_notifications'
__tablename__ = "comment_notifications"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
comment_id: int = Column(
Integer,
ForeignKey('comments.comment_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("comments.comment_id"), nullable=False, primary_key=True
)
notification_type: CommentNotificationType = Column(
ENUM(CommentNotificationType), nullable=False)
ENUM(CommentNotificationType), nullable=False
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=text('NOW()'),
TIMESTAMP(timezone=True), nullable=False, server_default=text("NOW()")
)
is_unread: bool = Column(
Boolean, nullable=False, server_default='true', index=True)
is_unread: bool = Column(Boolean, nullable=False, server_default="true", index=True)
user: User = relationship('User', innerjoin=True)
comment: Comment = relationship('Comment', innerjoin=True)
user: User = relationship("User", innerjoin=True)
comment: Comment = relationship("Comment", innerjoin=True)
def __init__(
self,
user: User,
comment: Comment,
notification_type: CommentNotificationType,
self, user: User, comment: Comment, notification_type: CommentNotificationType
) -> None:
"""Create a new notification for a user from a comment."""
self.user = user
@ -70,7 +58,7 @@ class CommentNotification(DatabaseModel):
def __acl__(self) -> Sequence[Tuple[str, Any, str]]:
"""Pyramid security ACL."""
acl = []
acl.append((Allow, self.user_id, 'mark_read'))
acl.append((Allow, self.user_id, "mark_read"))
acl.append(DENY_ALL)
return acl
@ -91,17 +79,12 @@ class CommentNotification(DatabaseModel):
@classmethod
def get_mentions_for_comment(
cls,
db_session: Session,
comment: Comment,
) -> List['CommentNotification']:
cls, db_session: Session, comment: Comment
) -> List["CommentNotification"]:
"""Get a list of notifications for user mentions in the comment."""
notifications = []
raw_names = re.findall(
LinkifyFilter.USERNAME_REFERENCE_REGEX,
comment.markdown,
)
raw_names = re.findall(LinkifyFilter.USERNAME_REFERENCE_REGEX, comment.markdown)
users_to_mention = (
db_session.query(User)
.filter(User.username.in_(raw_names)) # type: ignore
@ -124,38 +107,37 @@ class CommentNotification(DatabaseModel):
continue
mention_notification = cls(
user, comment, CommentNotificationType.USER_MENTION)
user, comment, CommentNotificationType.USER_MENTION
)
notifications.append(mention_notification)
return notifications
@staticmethod
def prevent_duplicate_notifications(
db_session: Session,
comment: Comment,
new_notifications: List['CommentNotification'],
) -> Tuple[List['CommentNotification'], List['CommentNotification']]:
db_session: Session,
comment: Comment,
new_notifications: List["CommentNotification"],
) -> Tuple[List["CommentNotification"], List["CommentNotification"]]:
"""Filter new notifications for edited comments.
Protect against sending a notification for the same comment to
the same user twice. Edits can sent notifications to users
now mentioned in the content, but only if they weren't sent
a notification for that comment before.
Protect against sending a notification for the same comment to the same user
twice. Edits can sent notifications to users now mentioned in the content, but
only if they weren't sent a notification for that comment before.
This method returns a tuple of lists of this class. The first
item is the notifications that were previously sent for this
comment but need to be deleted (i.e. mentioned username was edited
out of the comment), and the second item is the notifications
that need to be added, as they're new.
This method returns a tuple of lists of this class. The first item is the
notifications that were previously sent for this comment but need to be deleted
(i.e. mentioned username was edited out of the comment), and the second item is
the notifications that need to be added, as they're new.
"""
previous_notifications = (
db_session
.query(CommentNotification)
db_session.query(CommentNotification)
.filter(
CommentNotification.comment_id == comment.comment_id,
CommentNotification.notification_type ==
CommentNotificationType.USER_MENTION,
).all()
CommentNotification.notification_type
== CommentNotificationType.USER_MENTION,
)
.all()
)
new_mention_user_ids = [
@ -167,12 +149,14 @@ class CommentNotification(DatabaseModel):
]
to_delete = [
notification for notification in previous_notifications
notification
for notification in previous_notifications
if notification.user.user_id not in new_mention_user_ids
]
to_add = [
notification for notification in new_notifications
notification
for notification in new_notifications
if notification.user.user_id not in previous_mention_user_ids
]

10
tildes/tildes/models/comment/comment_notification_query.py

@ -17,7 +17,7 @@ class CommentNotificationQuery(ModelQuery):
"""Initialize a CommentNotificationQuery for the request."""
super().__init__(CommentNotification, request)
def _attach_extra_data(self) -> 'CommentNotificationQuery':
def _attach_extra_data(self) -> "CommentNotificationQuery":
"""Attach the user's comment votes to the query."""
vote_subquery = (
self.request.query(CommentVote)
@ -26,16 +26,16 @@ class CommentNotificationQuery(ModelQuery):
CommentVote.user == self.request.user,
)
.exists()
.label('user_voted')
.label("user_voted")
)
return self.add_columns(vote_subquery)
def join_all_relationships(self) -> 'CommentNotificationQuery':
def join_all_relationships(self) -> "CommentNotificationQuery":
"""Eagerly join the comment, topic, and group to the notification."""
self = self.options(
joinedload(CommentNotification.comment)
.joinedload('topic')
.joinedload('group')
.joinedload("topic")
.joinedload("group")
)
return self

11
tildes/tildes/models/comment/comment_query.py

@ -15,20 +15,19 @@ class CommentQuery(PaginatedQuery):
def __init__(self, request: Request) -> None:
"""Initialize a CommentQuery for the request.
If the user is logged in, additional user-specific data will be fetched
along with the comments. For the moment, this is whether the user has
voted on them.
If the user is logged in, additional user-specific data will be fetched along
with the comments. For the moment, this is whether the user has voted on them.
"""
super().__init__(Comment, request)
def _attach_extra_data(self) -> 'CommentQuery':
def _attach_extra_data(self) -> "CommentQuery":
"""Attach the extra user data to the query."""
if not self.request.user:
return self
return self._attach_vote_data()
def _attach_vote_data(self) -> 'CommentQuery':
def _attach_vote_data(self) -> "CommentQuery":
"""Add a subquery to include whether the user has voted."""
vote_subquery = (
self.request.query(CommentVote)
@ -37,7 +36,7 @@ class CommentQuery(PaginatedQuery):
CommentVote.user_id == self.request.user.user_id,
)
.exists()
.label('user_voted')
.label("user_voted")
)
return self.add_columns(vote_subquery)

29
tildes/tildes/models/comment/comment_tag.py

@ -16,37 +16,24 @@ from .comment import Comment
class CommentTag(DatabaseModel):
"""Model for the tags attached to comments by users."""
__tablename__ = 'comment_tags'
__tablename__ = "comment_tags"
comment_id: int = Column(
Integer,
ForeignKey('comments.comment_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("comments.comment_id"), nullable=False, primary_key=True
)
tag: CommentTagOption = Column(
ENUM(CommentTagOption), nullable=False, primary_key=True)
ENUM(CommentTagOption), nullable=False, primary_key=True
)
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=text('NOW()'),
TIMESTAMP(timezone=True), nullable=False, server_default=text("NOW()")
)
comment: Comment = relationship(
Comment, backref=backref('tags', lazy=False))
comment: Comment = relationship(Comment, backref=backref("tags", lazy=False))
def __init__(
self,
comment: Comment,
user: User,
tag: CommentTagOption,
) -> None:
def __init__(self, comment: Comment, user: User, tag: CommentTagOption) -> None:
"""Add a new tag to a comment."""
self.comment_id = comment.comment_id
self.user_id = user.user_id

46
tildes/tildes/models/comment/comment_tree.py

@ -14,21 +14,17 @@ class CommentTree:
The Comment objects held by this class have additional attributes added:
- `replies`: the list of all immediate children to that comment
- `has_visible_descendant`: whether the comment has any visible
descendants (if not, it can be pruned from the tree)
- `has_visible_descendant`: whether the comment has any visible descendants (if
not, it can be pruned from the tree)
"""
def __init__(
self,
comments: Sequence[Comment],
sort: CommentSortOption,
) -> None:
def __init__(self, comments: Sequence[Comment], sort: CommentSortOption) -> None:
"""Create a sorted CommentTree from a flat list of Comments."""
self.tree: List[Comment] = []
self.sort = sort
# sort the comments by date, since replies will always be posted later
# this will ensure that parent comments are always processed first
# sort the comments by date, since replies will always be posted later this will
# ensure that parent comments are always processed first
self.comments = sorted(comments, key=lambda c: c.created_time)
# if there aren't any comments, we can just bail out here
@ -37,11 +33,10 @@ class CommentTree:
self._build_tree()
# The method of building the tree already sorts it by posting time, so
# there's no need to sort again if that's the desired sorting. Note
# also that because _sort_tree() uses sorted() which is a stable sort,
# this means that the "secondary sort" will always be by posting time
# as well.
# The method of building the tree already sorts it by posting time, so there's
# no need to sort again if that's the desired sorting. Note also that because
# _sort_tree() uses sorted() which is a stable sort, this means that the
# "secondary sort" will always be by posting time as well.
if sort != CommentSortOption.POSTED:
with self._sorting_histogram().time():
self.tree = self._sort_tree(self.tree, self.sort)
@ -76,15 +71,12 @@ class CommentTree:
self.tree.append(comment)
@staticmethod
def _sort_tree(
tree: List[Comment],
sort: CommentSortOption,
) -> List[Comment]:
def _sort_tree(tree: List[Comment], sort: CommentSortOption) -> List[Comment]:
"""Sort the tree by the desired ordering (recursively).
Because Python's sorted() function is stable, the ordering of any
comments that compare equal on the sorting method will be the same as
the order that they were originally in when passed to this function.
Because Python's sorted() function is stable, the ordering of any comments that
compare equal on the sorting method will be the same as the order that they were
originally in when passed to this function.
"""
if sort == CommentSortOption.NEWEST:
tree = sorted(tree, key=lambda c: c.created_time, reverse=True)
@ -149,18 +141,18 @@ class CommentTree:
# make an "order of magnitude" label based on the number of comments
if num_comments == 0:
raise ValueError('Attempting to time an empty comment tree sort')
raise ValueError("Attempting to time an empty comment tree sort")
if num_comments < 10:
num_comments_range = '1 - 9'
num_comments_range = "1 - 9"
elif num_comments < 100:
num_comments_range = '10 - 99'
num_comments_range = "10 - 99"
elif num_comments < 1000:
num_comments_range = '100 - 999'
num_comments_range = "100 - 999"
else:
num_comments_range = '1000+'
num_comments_range = "1000+"
return get_histogram(
'comment_tree_sorting',
"comment_tree_sorting",
num_comments_range=num_comments_range,
order=self.sort.name,
)

24
tildes/tildes/models/comment/comment_vote.py

@ -17,37 +17,31 @@ class CommentVote(DatabaseModel):
Trigger behavior:
Outgoing:
- Inserting or deleting a row will increment or decrement the num_votes
column for the relevant comment.
- Inserting or deleting a row will increment or decrement the num_votes column
for the relevant comment.
"""
__tablename__ = 'comment_votes'
__tablename__ = "comment_votes"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
comment_id: int = Column(
Integer,
ForeignKey('comments.comment_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("comments.comment_id"), nullable=False, primary_key=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
user: User = relationship('User', innerjoin=True)
comment: Comment = relationship('Comment', innerjoin=True)
user: User = relationship("User", innerjoin=True)
comment: Comment = relationship("Comment", innerjoin=True)
def __init__(self, user: User, comment: Comment) -> None:
"""Create a new vote on a comment."""
self.user = user
self.comment = comment
incr_counter('votes', target_type='comment')
incr_counter("votes", target_type="comment")

71
tildes/tildes/models/database_model.py

@ -11,36 +11,31 @@ from sqlalchemy.schema import MetaData
from sqlalchemy.sql.schema import Table
ModelType = TypeVar('ModelType') # pylint: disable=invalid-name
ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
# SQLAlchemy naming convention for constraints and indexes
NAMING_CONVENTION = {
'pk': 'pk_%(table_name)s',
'fk': 'fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s',
'ix': 'ix_%(table_name)s_%(column_0_name)s',
'ck': 'ck_%(table_name)s_%(constraint_name)s',
'uq': 'uq_%(table_name)s_%(column_0_name)s',
"pk": "pk_%(table_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"ix": "ix_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
}
def attach_set_listener(
class_: Type['DatabaseModelBase'],
attribute: str,
instance: 'DatabaseModelBase',
class_: Type["DatabaseModelBase"], attribute: str, instance: "DatabaseModelBase"
) -> None:
"""Attach the SQLAlchemy ORM "set" attribute listener."""
# pylint: disable=unused-argument
def set_handler(
target: 'DatabaseModelBase',
value: Any,
oldvalue: Any,
initiator: Any,
target: "DatabaseModelBase", value: Any, oldvalue: Any, initiator: Any
) -> Any:
"""Handle an SQLAlchemy ORM "set" attribute event."""
# pylint: disable=protected-access
return target._validate_new_value(attribute, value)
event.listen(instance, 'set', set_handler, retval=True)
event.listen(instance, "set", set_handler, retval=True)
class DatabaseModelBase:
@ -56,8 +51,8 @@ class DatabaseModelBase:
if not isinstance(other, self.__class__):
return NotImplemented
# loop over all the columns in the primary key - if any don't match,
# return False, otherwise return True if we get through all of them
# loop over all the columns in the primary key - if any don't match, return
# False, otherwise return True if we get through all of them
for column in self.__table__.primary_key:
if getattr(self, column.name) != getattr(other, column.name):
return False
@ -67,12 +62,11 @@ class DatabaseModelBase:
def __hash__(self) -> int:
"""Return the hash value of the model.
This is implemented by mixing together the hash values of the primary
key columns used in __eq__, as recommended in the Python documentation.
This is implemented by mixing together the hash values of the primary key
columns used in __eq__, as recommended in the Python documentation.
"""
primary_key_values = tuple(
getattr(self, column.name)
for column in self.__table__.primary_key
getattr(self, column.name) for column in self.__table__.primary_key
)
return hash(primary_key_values)
@ -82,7 +76,7 @@ class DatabaseModelBase:
if not self.schema_class:
raise AttributeError
if not hasattr(self, '_schema'):
if not hasattr(self, "_schema"):
self._schema = self.schema_class(partial=True) # noqa
return self._schema
@ -90,29 +84,26 @@ class DatabaseModelBase:
def _validate_new_value(self, attribute: str, value: Any) -> Any:
"""Validate the new value for a column.
This function will be attached to the SQLAlchemy ORM attribute event
for "set" and will be called whenever a new value is assigned to any of
a model's column attributes. It works by deserializing/loading the new
value through the marshmallow schema associated with the model class
(by its `schema` class attribute).
This function will be attached to the SQLAlchemy ORM attribute event for "set"
and will be called whenever a new value is assigned to any of a model's column
attributes. It works by deserializing/loading the new value through the
marshmallow schema associated with the model class (by its `schema` class
attribute).
The deserialization process can modify the value if desired (for
sanitization), or raise an exception which will prevent the assignment
from happening at all.
The deserialization process can modify the value if desired (for sanitization),
or raise an exception which will prevent the assignment from happening at all.
Note that if the schema does not have a Field defined for the column,
or the Field is declared dump_only, no validation/sanitization will be
applied.
Note that if the schema does not have a Field defined for the column, or the
Field is declared dump_only, no validation/sanitization will be applied.
"""
if not self.schema_class:
return value
# This is a bit "magic", but simplifies the interaction between this
# validation and SQLAlchemy hybrid properties. If the attribute being
# set starts with an underscore, assume that it's due to being set up
# as a hybrid property, and remove the underscore prefix when looking
# for a field to validate against.
if attribute.startswith('_'):
# This is a bit "magic", but simplifies the interaction between this validation
# and SQLAlchemy hybrid properties. If the attribute being set starts with an
# underscore, assume that it's due to being set up as a hybrid property, and
# remove the underscore prefix when looking for a field to validate against.
if attribute.startswith("_"):
attribute = attribute[1:]
field = self.schema.fields.get(attribute)
@ -126,13 +117,13 @@ class DatabaseModelBase:
DatabaseModel = declarative_base( # pylint: disable=invalid-name
cls=DatabaseModelBase,
name='DatabaseModel',
name="DatabaseModel",
metadata=MetaData(naming_convention=NAMING_CONVENTION),
)
# attach the listener for SQLAlchemy ORM attribute "set" events to all models
event.listen(DatabaseModel, 'attribute_instrument', attach_set_listener)
event.listen(DatabaseModel, "attribute_instrument", attach_set_listener)
# associate JSONB columns with MutableDict so value changes are detected
MutableDict.associate_with(JSONB)

54
tildes/tildes/models/group/group.py

@ -4,15 +4,7 @@ from datetime import datetime
from typing import Any, Optional, Sequence, Tuple
from pyramid.security import Allow, Authenticated, Deny, DENY_ALL, Everyone
from sqlalchemy import (
Boolean,
CheckConstraint,
Column,
Index,
Integer,
Text,
TIMESTAMP,
)
from sqlalchemy import Boolean, CheckConstraint, Column, Index, Integer, Text, TIMESTAMP
from sqlalchemy.sql.expression import text
from sqlalchemy_utils import Ltree, LtreeType
@ -25,13 +17,13 @@ class Group(DatabaseModel):
Trigger behavior:
Incoming:
- num_subscriptions will be incremented and decremented by insertions
and deletions in group_subscriptions.
- num_subscriptions will be incremented and decremented by insertions and
deletions in group_subscriptions.
"""
schema_class = GroupSchema
__tablename__ = 'groups'
__tablename__ = "groups"
group_id: int = Column(Integer, primary_key=True)
path: Ltree = Column(LtreeType, nullable=False, index=True, unique=True)
@ -39,36 +31,34 @@ class Group(DatabaseModel):
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
short_description: Optional[str] = Column(
Text,
CheckConstraint(
f'LENGTH(short_description) <= {SHORT_DESCRIPTION_MAX_LENGTH}',
name='short_description_length',
)
f"LENGTH(short_description) <= {SHORT_DESCRIPTION_MAX_LENGTH}",
name="short_description_length",
),
)
num_subscriptions: int = Column(
Integer, nullable=False, server_default='0')
num_subscriptions: int = Column(Integer, nullable=False, server_default="0")
is_admin_posting_only: bool = Column(
Boolean, nullable=False, server_default='false')
# Create a GiST index on path as well as the btree one that will be created
# by the index=True/unique=True keyword args to Column above. The GiST
# index supports additional operators for ltree queries: @>, <@, @, ~, ?
__table_args__ = (
Index('ix_groups_path_gist', path, postgresql_using='gist'),
Boolean, nullable=False, server_default="false"
)
# Create a GiST index on path as well as the btree one that will be created by the
# index=True/unique=True keyword args to Column above. The GiST index supports
# additional operators for ltree queries: @>, <@, @, ~, ?
__table_args__ = (Index("ix_groups_path_gist", path, postgresql_using="gist"),)
def __repr__(self) -> str:
"""Display the group's path and ID as its repr format."""
return f'<Group {self.path} ({self.group_id})>'
return f"<Group {self.path} ({self.group_id})>"
def __str__(self) -> str:
"""Use the group path for the string representation."""
return str(self.path)
def __lt__(self, other: 'Group') -> bool:
def __lt__(self, other: "Group") -> bool:
"""Order groups by their string representation."""
return str(self) < str(other)
@ -83,20 +73,20 @@ class Group(DatabaseModel):
# view:
# - all groups can be viewed by everyone
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# subscribe:
# - all groups can be subscribed to by logged-in users
acl.append((Allow, Authenticated, 'subscribe'))
acl.append((Allow, Authenticated, "subscribe"))
# post_topic:
# - only admins can post in admin-posting-only groups
# - otherwise, all logged-in users can post
if self.is_admin_posting_only:
acl.append((Allow, 'admin', 'post_topic'))
acl.append((Deny, Everyone, 'post_topic'))
acl.append((Allow, "admin", "post_topic"))
acl.append((Deny, Everyone, "post_topic"))
acl.append((Allow, Authenticated, 'post_topic'))
acl.append((Allow, Authenticated, "post_topic"))
acl.append(DENY_ALL)

11
tildes/tildes/models/group/group_query.py

@ -15,20 +15,19 @@ class GroupQuery(ModelQuery):
def __init__(self, request: Request) -> None:
"""Initialize a GroupQuery for the request.
If the user is logged in, additional user-specific data will be fetched
along with the groups. For the moment, this is whether the user is
subscribed to them.
If the user is logged in, additional user-specific data will be fetched along
with the groups. For the moment, this is whether the user is subscribed to them.
"""
super().__init__(Group, request)
def _attach_extra_data(self) -> 'GroupQuery':
def _attach_extra_data(self) -> "GroupQuery":
"""Attach the extra user data to the query."""
if not self.request.user:
return self
return self._attach_subscription_data()
def _attach_subscription_data(self) -> 'GroupQuery':
def _attach_subscription_data(self) -> "GroupQuery":
"""Add a subquery to include whether the user is subscribed."""
subscription_subquery = (
self.request.query(GroupSubscription)
@ -37,7 +36,7 @@ class GroupQuery(ModelQuery):
GroupSubscription.user == self.request.user,
)
.exists()
.label('user_subscribed')
.label("user_subscribed")
)
return self.add_columns(subscription_subquery)

24
tildes/tildes/models/group/group_subscription.py

@ -17,37 +17,31 @@ class GroupSubscription(DatabaseModel):
Trigger behavior:
Outgoing:
- Inserting or deleting a row will increment or decrement the
num_subscriptions column for the relevant group.
- Inserting or deleting a row will increment or decrement the num_subscriptions
column for the relevant group.
"""
__tablename__ = 'group_subscriptions'
__tablename__ = "group_subscriptions"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
group_id: int = Column(
Integer,
ForeignKey('groups.group_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("groups.group_id"), nullable=False, primary_key=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
user: User = relationship('User', innerjoin=True, backref='subscriptions')
group: Group = relationship('Group', innerjoin=True, lazy=False)
user: User = relationship("User", innerjoin=True, backref="subscriptions")
group: Group = relationship("Group", innerjoin=True, lazy=False)
def __init__(self, user: User, group: Group) -> None:
"""Create a new subscription to a group."""
self.user = user
self.group = group
incr_counter('subscriptions')
incr_counter("subscriptions")

140
tildes/tildes/models/log/log.py

@ -3,15 +3,7 @@
from typing import Any, Dict, Optional
from pyramid.request import Request
from sqlalchemy import (
BigInteger,
Column,
event,
ForeignKey,
Integer,
Table,
TIMESTAMP,
)
from sqlalchemy import BigInteger, Column, event, ForeignKey, Integer, Table, TIMESTAMP
from sqlalchemy.dialects.postgresql import ENUM, INET, JSONB
from sqlalchemy.engine import Connection
from sqlalchemy.ext.declarative import declared_attr
@ -23,7 +15,7 @@ from tildes.models import DatabaseModel
from tildes.models.topic import Topic
class BaseLog():
class BaseLog:
"""Mixin class with the shared columns/relationships for log classes."""
@declared_attr
@ -34,7 +26,7 @@ class BaseLog():
@declared_attr
def user_id(self) -> Column:
"""Return the user_id column."""
return Column(Integer, ForeignKey('users.user_id'), index=True)
return Column(Integer, ForeignKey("users.user_id"), index=True)
@declared_attr
def event_type(self) -> Column:
@ -53,7 +45,7 @@ class BaseLog():
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
@declared_attr
@ -64,27 +56,26 @@ class BaseLog():
@declared_attr
def user(self) -> Any:
"""Return the user relationship."""
return relationship('User', lazy=False)
return relationship("User", lazy=False)
class Log(DatabaseModel, BaseLog):
"""Model for a basic log entry."""
__tablename__ = 'log'
__tablename__ = "log"
INHERITED_TABLES = ['log_topics']
INHERITED_TABLES = ["log_topics"]
def __init__(
self,
event_type: LogEventType,
request: Request,
info: Optional[Dict[str, Any]] = None,
self,
event_type: LogEventType,
request: Request,
info: Optional[Dict[str, Any]] = None,
) -> None:
"""Create a new log entry.
User and IP address info is extracted from the Request object.
`info` is an optional dict of arbitrary data that will be stored in
JSON form.
User and IP address info is extracted from the Request object. `info` is an
optional dict of arbitrary data that will be stored in JSON form.
"""
self.user = request.user
self.event_type = event_type
@ -97,19 +88,20 @@ class Log(DatabaseModel, BaseLog):
class LogTopic(DatabaseModel, BaseLog):
"""Model for a log entry related to a specific topic."""
__tablename__ = 'log_topics'
__tablename__ = "log_topics"
topic_id: int = Column(
Integer, ForeignKey('topics.topic_id'), index=True, nullable=False)
Integer, ForeignKey("topics.topic_id"), index=True, nullable=False
)
topic: Topic = relationship('Topic')
topic: Topic = relationship("Topic")
def __init__(
self,
event_type: LogEventType,
request: Request,
topic: Topic,
info: Optional[Dict[str, Any]] = None,
self,
event_type: LogEventType,
request: Request,
topic: Topic,
info: Optional[Dict[str, Any]] = None,
) -> None:
"""Create a new log entry related to a specific topic."""
# pylint: disable=non-parent-init-called
@ -122,57 +114,55 @@ class LogTopic(DatabaseModel, BaseLog):
if self.event_type == LogEventType.TOPIC_TAG:
return self._tag_event_description()
elif self.event_type == LogEventType.TOPIC_MOVE:
old_group = self.info['old'] # noqa
new_group = self.info['new'] # noqa
return f'moved from ~{old_group} to ~{new_group}'
old_group = self.info["old"] # noqa
new_group = self.info["new"] # noqa
return f"moved from ~{old_group} to ~{new_group}"
elif self.event_type == LogEventType.TOPIC_LOCK:
return 'locked comments'
return "locked comments"
elif self.event_type == LogEventType.TOPIC_UNLOCK:
return 'unlocked comments'
return "unlocked comments"
elif self.event_type == LogEventType.TOPIC_TITLE_EDIT:
old_title = self.info['old'] # noqa
new_title = self.info['new'] # noqa
old_title = self.info["old"] # noqa
new_title = self.info["new"] # noqa
return f'changed title from "{old_title}" to "{new_title}"'
return f'performed action {self.event_type.name}' # noqa
return f"performed action {self.event_type.name}" # noqa
def _tag_event_description(self) -> str:
"""Return a description of a TOPIC_TAG event as a string."""
if self.event_type != LogEventType.TOPIC_TAG:
raise TypeError
old_tags = set(self.info['old']) # noqa
new_tags = set(self.info['new']) # noqa
old_tags = set(self.info["old"]) # noqa
new_tags = set(self.info["new"]) # noqa
added_tags = new_tags - old_tags
removed_tags = old_tags - new_tags
description = ''
description = ""
if added_tags:
tag_str = ', '.join([f"'{tag}'" for tag in added_tags])
tag_str = ", ".join([f"'{tag}'" for tag in added_tags])
if len(added_tags) == 1:
description += f'added tag {tag_str}'
description += f"added tag {tag_str}"
else:
description += f'added tags {tag_str}'
description += f"added tags {tag_str}"
if removed_tags:
description += ' and '
description += " and "
if removed_tags:
tag_str = ', '.join([f"'{tag}'" for tag in removed_tags])
tag_str = ", ".join([f"'{tag}'" for tag in removed_tags])
if len(removed_tags) == 1:
description += f'removed tag {tag_str}'
description += f"removed tag {tag_str}"
else:
description += f'removed tags {tag_str}'
description += f"removed tags {tag_str}"
return description
@event.listens_for(Log.__table__, 'after_create')
@event.listens_for(Log.__table__, "after_create")
def create_inherited_tables(
target: Table,
connection: Connection,
**kwargs: Any,
target: Table, connection: Connection, **kwargs: Any
) -> None:
"""Create all the tables that inherit from the base "log" one."""
# pylint: disable=unused-argument
@ -180,43 +170,39 @@ def create_inherited_tables(
# log_topics
connection.execute(
'CREATE TABLE log_topics (topic_id integer not null) INHERITS (log)')
"CREATE TABLE log_topics (topic_id integer not null) INHERITS (log)"
)
fk_name = naming['fk'] % {
'table_name': 'log_topics',
'column_0_name': 'topic_id',
'referred_table_name': 'topics',
fk_name = naming["fk"] % {
"table_name": "log_topics",
"column_0_name": "topic_id",
"referred_table_name": "topics",
}
connection.execute(
f'ALTER TABLE log_topics ADD CONSTRAINT {fk_name} '
'FOREIGN KEY (topic_id) REFERENCES topics (topic_id)'
f"ALTER TABLE log_topics ADD CONSTRAINT {fk_name} "
"FOREIGN KEY (topic_id) REFERENCES topics (topic_id)"
)
ix_name = naming['ix'] % {
'table_name': 'log_topics',
'column_0_name': 'topic_id',
}
connection.execute(f'CREATE INDEX {ix_name} ON log_topics (topic_id)')
ix_name = naming["ix"] % {"table_name": "log_topics", "column_0_name": "topic_id"}
connection.execute(f"CREATE INDEX {ix_name} ON log_topics (topic_id)")
# duplicate all the indexes/constraints from the base log table
for table in Log.INHERITED_TABLES:
pk_name = naming['pk'] % {'table_name': table}
pk_name = naming["pk"] % {"table_name": table}
connection.execute(
f'ALTER TABLE {table} '
f'ADD CONSTRAINT {pk_name} PRIMARY KEY (log_id)'
f"ALTER TABLE {table} ADD CONSTRAINT {pk_name} PRIMARY KEY (log_id)"
)
for col in ('event_time', 'event_type', 'ip_address', 'user_id'):
ix_name = naming['ix'] % {
'table_name': table, 'column_0_name': col}
connection.execute(f'CREATE INDEX {ix_name} ON {table} ({col})')
for col in ("event_time", "event_type", "ip_address", "user_id"):
ix_name = naming["ix"] % {"table_name": table, "column_0_name": col}
connection.execute(f"CREATE INDEX {ix_name} ON {table} ({col})")
fk_name = naming['fk'] % {
'table_name': table,
'column_0_name': 'user_id',
'referred_table_name': 'users',
fk_name = naming["fk"] % {
"table_name": table,
"column_0_name": "user_id",
"referred_table_name": "users",
}
connection.execute(
f'ALTER TABLE {table} ADD CONSTRAINT {fk_name} '
'FOREIGN KEY (user_id) REFERENCES users (user_id)'
f"ALTER TABLE {table} ADD CONSTRAINT {fk_name} "
"FOREIGN KEY (user_id) REFERENCES users (user_id)"
)

136
tildes/tildes/models/message/message.py

@ -1,13 +1,12 @@
"""Contains the MessageConversation and MessageReply classes.
Note the difference between these two classes - MessageConversation represents
both the overall conversation and the initial message in a particular message
conversation/thread. Subsequent replies (if any) inside that same conversation
are represented by MessageReply.
This might feel a bit unusual since it splits "all messages" across two
tables/classes, but it simplifies a lot of things when organizing them into
threads.
Note the difference between these two classes - MessageConversation represents both the
overall conversation and the initial message in a particular message
conversation/thread. Subsequent replies (if any) inside that same conversation are
represented by MessageReply.
This might feel a bit unusual since it splits "all messages" across two tables/classes,
but it simplifies a lot of things when organizing them into threads.
"""
from datetime import datetime
@ -44,89 +43,81 @@ class MessageConversation(DatabaseModel):
Trigger behavior:
Incoming:
- num_replies, last_reply_time, and unread_user_ids are updated when a
new message_replies row is inserted for the conversation.
- num_replies and last_reply_time will be updated if a message_replies
row is deleted.
- num_replies, last_reply_time, and unread_user_ids are updated when a new
message_replies row is inserted for the conversation.
- num_replies and last_reply_time will be updated if a message_replies row is
deleted.
Outgoing:
- Inserting or updating unread_user_ids will update num_unread_messages
for all relevant users.
- Inserting or updating unread_user_ids will update num_unread_messages for all
relevant users.
"""
schema_class = MessageConversationSchema
__tablename__ = 'message_conversations'
__tablename__ = "message_conversations"
conversation_id: int = Column(Integer, primary_key=True)
sender_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
recipient_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
subject: str = Column(
Text,
CheckConstraint(
f'LENGTH(subject) <= {SUBJECT_MAX_LENGTH}',
name='subject_length',
f"LENGTH(subject) <= {SUBJECT_MAX_LENGTH}", name="subject_length"
),
nullable=False,
)
markdown: str = deferred(Column(Text, nullable=False))
rendered_html: str = Column(Text, nullable=False)
num_replies: int = Column(Integer, nullable=False, server_default='0')
last_reply_time: Optional[datetime] = Column(
TIMESTAMP(timezone=True), index=True)
num_replies: int = Column(Integer, nullable=False, server_default="0")
last_reply_time: Optional[datetime] = Column(TIMESTAMP(timezone=True), index=True)
unread_user_ids: List[int] = Column(
ARRAY(Integer), nullable=False, server_default='{}')
ARRAY(Integer), nullable=False, server_default="{}"
)
sender: User = relationship(
'User', lazy=False, innerjoin=True, foreign_keys=[sender_id])
"User", lazy=False, innerjoin=True, foreign_keys=[sender_id]
)
recipient: User = relationship(
'User', lazy=False, innerjoin=True, foreign_keys=[recipient_id])
replies: Sequence['MessageReply'] = relationship(
'MessageReply', order_by='MessageReply.created_time')
"User", lazy=False, innerjoin=True, foreign_keys=[recipient_id]
)
replies: Sequence["MessageReply"] = relationship(
"MessageReply", order_by="MessageReply.created_time"
)
# Create a GIN index on the unread_user_ids column using the gin__int_ops
# operator class supplied by the intarray module. This should be the best
# index for "array contains" queries.
# Create a GIN index on the unread_user_ids column using the gin__int_ops operator
# class supplied by the intarray module. This should be the best index for "array
# contains" queries.
__table_args__ = (
Index(
'ix_message_conversations_unread_user_ids_gin',
"ix_message_conversations_unread_user_ids_gin",
unread_user_ids,
postgresql_using='gin',
postgresql_ops={'unread_user_ids': 'gin__int_ops'},
postgresql_using="gin",
postgresql_ops={"unread_user_ids": "gin__int_ops"},
),
)
def __init__(
self,
sender: User,
recipient: User,
subject: str,
markdown: str,
self, sender: User, recipient: User, subject: str, markdown: str
) -> None:
"""Create a new message conversation between two users."""
self.sender_id = sender.user_id
self.recipient_id = recipient.user_id
self.unread_user_ids = ([self.recipient_id])
self.unread_user_ids = [self.recipient_id]
self.subject = subject
self.markdown = markdown
self.rendered_html = convert_markdown_to_safe_html(markdown)
incr_counter('messages', type='conversation')
incr_counter("messages", type="conversation")
def __acl__(self) -> Sequence[Tuple[str, Any, str]]:
"""Pyramid security ACL."""
@ -159,11 +150,11 @@ class MessageConversation(DatabaseModel):
def other_user(self, viewer: User) -> User:
"""Return the conversation's other user from viewer's perspective.
That is, if the viewer is the sender, this will be the recipient, and
vice versa.
That is, if the viewer is the sender, this will be the recipient, and vice
versa.
"""
if not self.is_participant(viewer):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
if viewer == self.sender:
return self.recipient
@ -173,35 +164,36 @@ class MessageConversation(DatabaseModel):
def is_unread_by_user(self, user: User) -> bool:
"""Return whether the conversation is unread by the specified user."""
if not self.is_participant(user):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
return user.user_id in self.unread_user_ids
def mark_unread_for_user(self, user: User) -> None:
"""Mark the conversation unread for the specified user.
Uses the postgresql intarray union operator `|`, so there's no need to
worry about duplicate values, race conditions, etc.
Uses the postgresql intarray union operator `|`, so there's no need to worry
about duplicate values, race conditions, etc.
"""
if not self.is_participant(user):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
union = MessageConversation.unread_user_ids.op('|') # type: ignore
union = MessageConversation.unread_user_ids.op("|") # type: ignore
self.unread_user_ids = union(user.user_id)
def mark_read_for_user(self, user: User) -> None:
"""Mark the conversation read for the specified user.
Uses the postgresql intarray "remove element from array" operation, so
there's no need to worry about whether the value is present or not,
race conditions, etc.
Uses the postgresql intarray "remove element from array" operation, so there's
no need to worry about whether the value is present or not, race conditions,
etc.
"""
if not self.is_participant(user):
raise ValueError('User is not a participant in this conversation.')
raise ValueError("User is not a participant in this conversation.")
user_id = user.user_id
self.unread_user_ids = ( # type: ignore
MessageConversation.unread_user_ids - user_id) # type: ignore
MessageConversation.unread_user_ids - user_id # type: ignore
)
class MessageReply(DatabaseModel):
@ -209,45 +201,39 @@ class MessageReply(DatabaseModel):
Trigger behavior:
Outgoing:
- Inserting will update num_replies, last_reply_time, and
unread_user_ids for the relevant conversation.
- Inserting will update num_replies, last_reply_time, and unread_user_ids for
the relevant conversation.
- Deleting will update num_replies and last_reply_time for the relevant
conversation.
"""
schema_class = MessageReplySchema
__tablename__ = 'message_replies'
__tablename__ = "message_replies"
reply_id: int = Column(Integer, primary_key=True)
conversation_id: int = Column(
Integer,
ForeignKey('message_conversations.conversation_id'),
ForeignKey("message_conversations.conversation_id"),
nullable=False,
index=True,
)
sender_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
markdown: str = deferred(Column(Text, nullable=False))
rendered_html: str = Column(Text, nullable=False)
sender: User = relationship('User', lazy=False, innerjoin=True)
sender: User = relationship("User", lazy=False, innerjoin=True)
def __init__(
self,
conversation: MessageConversation,
sender: User,
markdown: str,
self, conversation: MessageConversation, sender: User, markdown: str
) -> None:
"""Add a new reply to a message conversation."""
self.conversation_id = conversation.conversation_id
@ -255,7 +241,7 @@ class MessageReply(DatabaseModel):
self.markdown = markdown
self.rendered_html = convert_markdown_to_safe_html(markdown)
incr_counter('messages', type='reply')
incr_counter("messages", type="reply")
@property
def reply_id36(self) -> str:

75
tildes/tildes/models/model_query.py

@ -8,7 +8,7 @@ from sqlalchemy.orm import Load, undefer
from sqlalchemy.orm.query import Query
ModelType = TypeVar('ModelType') # pylint: disable=invalid-name
ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
class ModelQuery(Query):
@ -22,105 +22,104 @@ class ModelQuery(Query):
self.request = request
# can only filter deleted items if the table has an 'is_deleted' column
self.filter_deleted = bool('is_deleted' in model_cls.__table__.columns)
self.filter_deleted = bool("is_deleted" in model_cls.__table__.columns)
# can only filter removed items if the table has an 'is_removed' column
self.filter_removed = bool('is_removed' in model_cls.__table__.columns)
self.filter_removed = bool("is_removed" in model_cls.__table__.columns)
def __iter__(self) -> Iterator[ModelType]:
"""Iterate over the (processed) results of the query.
SQLAlchemy goes through __iter__ to execute the query and return the
results, so adding processing here should cover all the possibilities.
SQLAlchemy goes through __iter__ to execute the query and return the results, so
adding processing here should cover all the possibilities.
"""
results = super().__iter__()
return iter([self._process_result(result) for result in results])
def _attach_extra_data(self) -> 'ModelQuery':
def _attach_extra_data(self) -> "ModelQuery":
"""Override to attach extra data to query before execution."""
return self
def _finalize(self) -> 'ModelQuery':
def _finalize(self) -> "ModelQuery":
"""Finalize the query before it's executed."""
# pylint: disable=protected-access
# Assertions are disabled to allow these functions to add more filters
# even though .limit() or .offset() may have already been called. This
# is potentially dangerous, but should be fine with the existing
# straightforward usage patterns.
# Assertions are disabled to allow these functions to add more filters even
# though .limit() or .offset() may have already been called. This is potentially
# dangerous, but should be fine with the existing straightforward usage
# patterns.
return (
self
.enable_assertions(False)
self.enable_assertions(False)
._attach_extra_data()
._filter_deleted_if_necessary()
._filter_removed_if_necessary()
)
def _before_compile_listener(self) -> 'ModelQuery':
def _before_compile_listener(self) -> "ModelQuery":
"""Do any final adjustments to the query before it's compiled.
Note that this method cannot be overridden by subclasses because of
the way it is subscribed to the event. Subclasses should override the
_finalize() method instead if necessary.
Note that this method cannot be overridden by subclasses because of the way it
is subscribed to the event. Subclasses should override the _finalize() method
instead if necessary.
"""
return self._finalize()
def _filter_deleted_if_necessary(self) -> 'ModelQuery':
def _filter_deleted_if_necessary(self) -> "ModelQuery":
"""Filter out deleted rows unless they were explicitly included."""
if not self.filter_deleted:
return self
return self.filter(self.model_cls.is_deleted == False) # noqa
def _filter_removed_if_necessary(self) -> 'ModelQuery':
def _filter_removed_if_necessary(self) -> "ModelQuery":
"""Filter out removed rows unless they were explicitly included."""
if not self.filter_removed:
return self
return self.filter(self.model_cls.is_removed == False) # noqa
def lock_based_on_request_method(self) -> 'ModelQuery':
def lock_based_on_request_method(self) -> "ModelQuery":
"""Lock the rows if request method implies it's needed (generative).
Applying this function to a query will cause the database to acquire
a row-level FOR UPDATE lock on any rows the query retrieves. This is
only done if the request method is DELETE, PATCH, or PUT, which all
imply that the item(s) being fetched are going to be modified.
Applying this function to a query will cause the database to acquire a row-level
FOR UPDATE lock on any rows the query retrieves. This is only done if the
request method is DELETE, PATCH, or PUT, which all imply that the item(s) being
fetched are going to be modified.
Note that POST is specifically not included, because the item being
POSTed to is not usually modified in a "dangerous" way as a result.
Note that POST is specifically not included, because the item being POSTed to is
not usually modified in a "dangerous" way as a result.
"""
if self.request.method in {'DELETE', 'PATCH', 'PUT'}:
if self.request.method in {"DELETE", "PATCH", "PUT"}:
return self.with_for_update(of=self.model_cls)
return self
def include_deleted(self) -> 'ModelQuery':
def include_deleted(self) -> "ModelQuery":
"""Specify that deleted rows should be included (generative)."""
self.filter_deleted = False
return self
def include_removed(self) -> 'ModelQuery':
def include_removed(self) -> "ModelQuery":
"""Specify that removed rows should be included (generative)."""
self.filter_removed = False
return self
def join_all_relationships(self) -> 'ModelQuery':
def join_all_relationships(self) -> "ModelQuery":
"""Eagerly join all lazy relationships (generative).
This is useful for being able to load an item "fully" in a single
query and avoid needing to make additional queries for related items.
This is useful for being able to load an item "fully" in a single query and
avoid needing to make additional queries for related items.
"""
# pylint: disable=no-member
self = self.options(Load(self.model_cls).joinedload('*'))
self = self.options(Load(self.model_cls).joinedload("*"))
return self
def undefer_all_columns(self) -> 'ModelQuery':
def undefer_all_columns(self) -> "ModelQuery":
"""Undefer all columns (generative)."""
self = self.options(undefer('*'))
self = self.options(undefer("*"))
return self
@ -130,11 +129,11 @@ class ModelQuery(Query):
return result
# add a listener so the _finalize() function will be called automatically just
# before the query executes
# add a listener so the _finalize() function will be called automatically just before
# the query executes
event.listen(
ModelQuery,
'before_compile',
"before_compile",
ModelQuery._before_compile_listener, # pylint: disable=protected-access
retval=True,
)

80
tildes/tildes/models/pagination.py

@ -9,7 +9,7 @@ from tildes.lib.id import id_to_id36, id36_to_id
from .model_query import ModelQuery
ModelType = TypeVar('ModelType') # pylint: disable=invalid-name
ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
class PaginatedQuery(ModelQuery):
@ -18,7 +18,7 @@ class PaginatedQuery(ModelQuery):
def __init__(self, model_cls: Any, request: Request) -> None:
"""Initialize a PaginatedQuery for the specified model and request."""
if len(model_cls.__table__.primary_key) > 1:
raise TypeError('Only single-col primary key tables are supported')
raise TypeError("Only single-col primary key tables are supported")
super().__init__(model_cls, request)
@ -56,26 +56,24 @@ class PaginatedQuery(ModelQuery):
def is_reversed(self) -> bool:
"""Return whether the query is operating "in reverse".
This is a bit confusing. When moving "forward" through pages, items
will be queried in the same order that they are displayed. For example,
when displaying the newest topics, the query is simply for "newest N
topics" (where N is the number of items per page), with an optional
"after topic X" clause. Either way, the first result from the query
will have the highest created_time, and should be the first item
displayed.
However, things work differently when you are paging "backwards". Since
this is done by looking before a specific item, the query needs to
fetch items in the opposite order of how they will be displayed. For
the "newest" sort example, when paging backwards you need to query for
"*oldest* N items before topic X", so the query ordering is the exact
opposite of the desired display order. The first result from the query
will have the *lowest* created_time, so should be the last item
displayed. Because of this, the results need to be reversed.
This is a bit confusing. When moving "forward" through pages, items will be
queried in the same order that they are displayed. For example, when displaying
the newest topics, the query is simply for "newest N topics" (where N is the
number of items per page), with an optional "after topic X" clause. Either way,
the first result from the query will have the highest created_time, and should
be the first item displayed.
However, things work differently when you are paging "backwards". Since this is
done by looking before a specific item, the query needs to fetch items in the
opposite order of how they will be displayed. For the "newest" sort example,
when paging backwards you need to query for "*oldest* N items before topic X",
so the query ordering is the exact opposite of the desired display order. The
first result from the query will have the *lowest* created_time, so should be
the last item displayed. Because of this, the results need to be reversed.
"""
return bool(self.before_id)
def after_id36(self, id36: str) -> 'PaginatedQuery':
def after_id36(self, id36: str) -> "PaginatedQuery":
"""Restrict the query to results after an id36 (generative)."""
if self.before_id:
raise ValueError("Can't set both before and after restrictions")
@ -84,7 +82,7 @@ class PaginatedQuery(ModelQuery):
return self
def before_id36(self, id36: str) -> 'PaginatedQuery':
def before_id36(self, id36: str) -> "PaginatedQuery":
"""Restrict the query to results before an id36 (generative)."""
if self.after_id:
raise ValueError("Can't set both before and after restrictions")
@ -93,27 +91,27 @@ class PaginatedQuery(ModelQuery):
return self
def _apply_before_or_after(self) -> 'PaginatedQuery':
def _apply_before_or_after(self) -> "PaginatedQuery":
"""Apply the "before" or "after" restrictions if necessary."""
if not (self.after_id or self.before_id):
return self
query = self
# determine the ID of the "anchor item" that we're using as an upper or
# lower bound, and which type of bound it is
# determine the ID of the "anchor item" that we're using as an upper or lower
# bound, and which type of bound it is
if self.after_id:
anchor_id = self.after_id
# since we're looking for other items "after" the anchor item, it
# will act as an upper bound when the sort order is descending,
# otherwise it's a lower bound
# since we're looking for other items "after" the anchor item, it will act
# as an upper bound when the sort order is descending, otherwise it's a
# lower bound
is_anchor_upper_bound = self.sort_desc
elif self.before_id:
anchor_id = self.before_id
# opposite of "after" behavior - when looking "before" the anchor
# item, it's an upper bound if the sort order is *ascending*
# opposite of "after" behavior - when looking "before" the anchor item, it's
# an upper bound if the sort order is *ascending*
is_anchor_upper_bound = not self.sort_desc
# create a subquery to get comparison values for the anchor item
@ -132,12 +130,12 @@ class PaginatedQuery(ModelQuery):
return query
def _finalize(self) -> 'PaginatedQuery':
def _finalize(self) -> "PaginatedQuery":
"""Finalize the query before execution."""
query = super()._finalize()
# if the query is reversed, we need to sort in the opposite dir
# (basically self.sort_desc XOR self.is_reversed)
# if the query is reversed, we need to sort in the opposite dir (basically
# self.sort_desc XOR self.is_reversed)
desc = self.sort_desc
if self.is_reversed:
desc = not desc
@ -152,7 +150,7 @@ class PaginatedQuery(ModelQuery):
return query
def get_page(self, per_page: int) -> 'PaginatedResults':
def get_page(self, per_page: int) -> "PaginatedResults":
"""Get a page worth of results from the query (`per page` items)."""
return PaginatedResults(self, per_page)
@ -167,18 +165,18 @@ class PaginatedResults:
"""Fetch results from a PaginatedQuery."""
self.per_page = per_page
# if the query had `before` or `after` restrictions, there must be a
# page in that direction (it's where we came from)
# if the query had `before` or `after` restrictions, there must be a page in
# that direction (it's where we came from)
self.has_next_page = bool(query.before_id)
self.has_prev_page = bool(query.after_id)
# fetch the results - try to get one more than we're actually going to
# display, so that we know if there's another page
# fetch the results - try to get one more than we're actually going to display,
# so that we know if there's another page
self.results = query.limit(per_page + 1).all()
# if we managed to get one more item than the page size, there's
# another page in the same direction that we're going - set the
# relevant attr and remove the extra item so it's not displayed
# if we managed to get one more item than the page size, there's another page in
# the same direction that we're going - set the relevant attr and remove the
# extra item so it's not displayed
if len(self.results) > per_page:
if query.is_reversed:
self.results = self.results[1:]
@ -187,8 +185,8 @@ class PaginatedResults:
self.has_next_page = True
self.results = self.results[:-1]
# if the query came back empty for some reason, we won't be able to
# have next/prev pages since there are no items to base them on
# if the query came back empty for some reason, we won't be able to have
# next/prev pages since there are no items to base them on
if not self.results:
self.has_next_page = False
self.has_prev_page = False

196
tildes/tildes/models/topic/topic.py

@ -31,17 +31,14 @@ from tildes.metrics import incr_counter
from tildes.models import DatabaseModel
from tildes.models.group import Group
from tildes.models.user import User
from tildes.schemas.topic import (
TITLE_MAX_LENGTH,
TopicSchema,
)
from tildes.schemas.topic import TITLE_MAX_LENGTH, TopicSchema
# edits inside this period after creation will not mark the topic as edited
EDIT_GRACE_PERIOD = timedelta(minutes=5)
# special tags to put at the front of the tag list
SPECIAL_TAGS = ['nsfw', 'spoiler']
SPECIAL_TAGS = ["nsfw", "spoiler"]
class Topic(DatabaseModel):
@ -49,91 +46,82 @@ class Topic(DatabaseModel):
Trigger behavior:
Incoming:
- num_votes will be incremented and decremented by insertions and
deletions in topic_votes.
- num_comments will be incremented and decremented by insertions,
deletions, and updates to is_deleted in comments.
- last_activity_time will be updated by insertions, deletions, and
- num_votes will be incremented and decremented by insertions and deletions in
topic_votes.
- num_comments will be incremented and decremented by insertions, deletions, and
updates to is_deleted in comments.
- last_activity_time will be updated by insertions, deletions, and updates to
is_deleted in comments.
Outgoing:
- Inserting a row or updating markdown will send a rabbitmq message
for "topic.created" or "topic.edited" respectively.
- Inserting a row or updating markdown will send a rabbitmq message for
"topic.created" or "topic.edited" respectively.
Internal:
- deleted_time will be set when is_deleted is set to true
"""
schema_class = TopicSchema
__tablename__ = 'topics'
__tablename__ = "topics"
topic_id: int = Column(Integer, primary_key=True)
group_id: int = Column(
Integer,
ForeignKey('groups.group_id'),
nullable=False,
index=True,
Integer, ForeignKey("groups.group_id"), nullable=False, index=True
)
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
last_edited_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
last_activity_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
is_deleted: bool = Column(
Boolean, nullable=False, server_default='false', index=True)
Boolean, nullable=False, server_default="false", index=True
)
deleted_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
is_removed: bool = Column(
Boolean, nullable=False, server_default='false', index=True)
Boolean, nullable=False, server_default="false", index=True
)
removed_time: Optional[datetime] = Column(TIMESTAMP(timezone=True))
title: str = Column(
Text,
CheckConstraint(
f'LENGTH(title) <= {TITLE_MAX_LENGTH}',
name='title_length',
),
CheckConstraint(f"LENGTH(title) <= {TITLE_MAX_LENGTH}", name="title_length"),
nullable=False,
)
topic_type: TopicType = Column(
ENUM(TopicType), nullable=False, server_default='TEXT')
_markdown: Optional[str] = deferred(Column('markdown', Text))
ENUM(TopicType), nullable=False, server_default="TEXT"
)
_markdown: Optional[str] = deferred(Column("markdown", Text))
rendered_html: Optional[str] = Column(Text)
link: Optional[str] = Column(Text)
content_metadata: Dict[str, Any] = Column(JSONB)
num_comments: int = Column(
Integer, nullable=False, server_default='0', index=True)
num_votes: int = Column(
Integer, nullable=False, server_default='0', index=True)
num_comments: int = Column(Integer, nullable=False, server_default="0", index=True)
num_votes: int = Column(Integer, nullable=False, server_default="0", index=True)
_tags: List[Ltree] = Column(
'tags', ArrayOfLtree, nullable=False, server_default='{}')
is_official: bool = Column(Boolean, nullable=False, server_default='false')
is_locked: bool = Column(Boolean, nullable=False, server_default='false')
"tags", ArrayOfLtree, nullable=False, server_default="{}"
)
is_official: bool = Column(Boolean, nullable=False, server_default="false")
is_locked: bool = Column(Boolean, nullable=False, server_default="false")
user: User = relationship('User', lazy=False, innerjoin=True)
group: Group = relationship('Group', innerjoin=True)
user: User = relationship("User", lazy=False, innerjoin=True)
group: Group = relationship("Group", innerjoin=True)
# Create a GiST index on the tags column
__table_args__ = (
Index('ix_topics_tags_gist', _tags, postgresql_using='gist'),
)
__table_args__ = (Index("ix_topics_tags_gist", _tags, postgresql_using="gist"),)
@hybrid_property
def markdown(self) -> Optional[str]:
"""Return the topic's markdown."""
if not self.is_text_type:
raise AttributeError('Only text topics have markdown')
raise AttributeError("Only text topics have markdown")
return self._markdown
@ -141,7 +129,7 @@ class Topic(DatabaseModel):
def markdown(self, new_markdown: str) -> None:
"""Set the topic's markdown and render its HTML."""
if not self.is_text_type:
raise AttributeError('Can only set markdown for text topics')
raise AttributeError("Can only set markdown for text topics")
if new_markdown == self.markdown:
return
@ -149,21 +137,19 @@ class Topic(DatabaseModel):
self._markdown = new_markdown
self.rendered_html = convert_markdown_to_safe_html(new_markdown)
if (self.created_time and
utc_now() - self.created_time > EDIT_GRACE_PERIOD):
if self.created_time and utc_now() - self.created_time > EDIT_GRACE_PERIOD:
self.last_edited_time = utc_now()
@hybrid_property
def tags(self) -> List[str]:
"""Return the topic's tags."""
sorted_tags = [str(tag).replace('_', ' ') for tag in self._tags]
sorted_tags = [str(tag).replace("_", " ") for tag in self._tags]
# move special tags in front
# reverse so that tags at the start of the list appear first
for tag in reversed(SPECIAL_TAGS):
if tag in sorted_tags:
sorted_tags.insert(
0, sorted_tags.pop(sorted_tags.index(tag)))
sorted_tags.insert(0, sorted_tags.pop(sorted_tags.index(tag)))
return sorted_tags
@ -176,12 +162,7 @@ class Topic(DatabaseModel):
return f'<Topic "{self.title}" ({self.topic_id})>'
@classmethod
def _create_base_topic(
cls,
group: Group,
author: User,
title: str,
) -> 'Topic':
def _create_base_topic(cls, group: Group, author: User, title: str) -> "Topic":
"""Create the "base" for a new topic."""
new_topic = cls()
new_topic.group_id = group.group_id
@ -192,35 +173,27 @@ class Topic(DatabaseModel):
@classmethod
def create_text_topic(
cls,
group: Group,
author: User,
title: str,
markdown: str = '',
) -> 'Topic':
cls, group: Group, author: User, title: str, markdown: str = ""
) -> "Topic":
"""Create a new text topic."""
new_topic = cls._create_base_topic(group, author, title)
new_topic.topic_type = TopicType.TEXT
new_topic.markdown = markdown
incr_counter('topics', type='text')
incr_counter("topics", type="text")
return new_topic
@classmethod
def create_link_topic(
cls,
group: Group,
author: User,
title: str,
link: str,
) -> 'Topic':
cls, group: Group, author: User, title: str, link: str
) -> "Topic":
"""Create a new link topic."""
new_topic = cls._create_base_topic(group, author, title)
new_topic.topic_type = TopicType.LINK
new_topic.link = link
incr_counter('topics', type='link')
incr_counter("topics", type="link")
return new_topic
@ -230,74 +203,74 @@ class Topic(DatabaseModel):
# deleted topics allow "general" viewing, but nothing else
if self.is_deleted:
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
acl.append(DENY_ALL)
# view:
# - everyone gets "general" viewing permission for all topics
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# view_author:
# - removed topics' author is only visible to the author and admins
# - otherwise, everyone can view the author
if self.is_removed:
acl.append((Allow, 'admin', 'view_author'))
acl.append((Allow, self.user_id, 'view_author'))
acl.append((Deny, Everyone, 'view_author'))
acl.append((Allow, "admin", "view_author"))
acl.append((Allow, self.user_id, "view_author"))
acl.append((Deny, Everyone, "view_author"))
acl.append((Allow, Everyone, 'view_author'))
acl.append((Allow, Everyone, "view_author"))
# view_content:
# - removed topics' content is only visible to the author and admins
# - otherwise, everyone can view the content
if self.is_removed:
acl.append((Allow, 'admin', 'view_content'))
acl.append((Allow, self.user_id, 'view_content'))
acl.append((Deny, Everyone, 'view_content'))
acl.append((Allow, "admin", "view_content"))
acl.append((Allow, self.user_id, "view_content"))
acl.append((Deny, Everyone, "view_content"))
acl.append((Allow, Everyone, 'view_content'))
acl.append((Allow, Everyone, "view_content"))
# vote:
# - removed topics can't be voted on by anyone
# - otherwise, logged-in users except the author can vote
if self.is_removed:
acl.append((Deny, Everyone, 'vote'))
acl.append((Deny, Everyone, "vote"))
acl.append((Deny, self.user_id, 'vote'))
acl.append((Allow, Authenticated, 'vote'))
acl.append((Deny, self.user_id, "vote"))
acl.append((Allow, Authenticated, "vote"))
# comment:
# - removed topics can only be commented on by admins
# - locked topics can only be commented on by admins
# - otherwise, logged-in users can comment
if self.is_removed:
acl.append((Allow, 'admin', 'comment'))
acl.append((Deny, Everyone, 'comment'))
acl.append((Allow, "admin", "comment"))
acl.append((Deny, Everyone, "comment"))
if self.is_locked:
acl.append((Allow, 'admin', 'comment'))
acl.append((Deny, Everyone, 'comment'))
acl.append((Allow, "admin", "comment"))
acl.append((Deny, Everyone, "comment"))
acl.append((Allow, Authenticated, 'comment'))
acl.append((Allow, Authenticated, "comment"))
# edit:
# - only text topics can be edited, only by the author
if self.is_text_type:
acl.append((Allow, self.user_id, 'edit'))
acl.append((Allow, self.user_id, "edit"))
# delete:
# - only the author can delete
acl.append((Allow, self.user_id, 'delete'))
acl.append((Allow, self.user_id, "delete"))
# tag:
# - only the author and admins can tag topics
acl.append((Allow, self.user_id, 'tag'))
acl.append((Allow, 'admin', 'tag'))
acl.append((Allow, self.user_id, "tag"))
acl.append((Allow, "admin", "tag"))
# admin tools
acl.append((Allow, 'admin', 'lock'))
acl.append((Allow, 'admin', 'move'))
acl.append((Allow, 'admin', 'edit_title'))
acl.append((Allow, "admin", "lock"))
acl.append((Allow, "admin", "move"))
acl.append((Allow, "admin", "edit_title"))
acl.append(DENY_ALL)
@ -316,7 +289,7 @@ class Topic(DatabaseModel):
@property
def permalink(self) -> str:
"""Return the permalink for this topic."""
return f'/~{self.group.path}/{self.topic_id36}/{self.url_slug}'
return f"/~{self.group.path}/{self.topic_id36}/{self.url_slug}"
@property
def is_text_type(self) -> bool:
@ -332,33 +305,32 @@ class Topic(DatabaseModel):
def type_for_display(self) -> str:
"""Return a string of the topic's type, suitable for display."""
if self.is_text_type:
return 'Text'
return "Text"
elif self.is_link_type:
return 'Link'
return "Link"
return 'Topic'
return "Topic"
@property
def link_domain(self) -> str:
"""Return the link's domain (for link topics only)."""
if not self.is_link_type or not self.link:
raise ValueError('Non-link topics do not have a domain')
raise ValueError("Non-link topics do not have a domain")
# get the domain from the content metadata if possible, but fall back
# to just parsing it from the link if it's not present
return (self.get_content_metadata('domain')
or get_domain_from_url(self.link))
# get the domain from the content metadata if possible, but fall back to just
# parsing it from the link if it's not present
return self.get_content_metadata("domain") or get_domain_from_url(self.link)
@property
def is_spoiler(self) -> bool:
"""Return whether the topic is marked as a spoiler."""
return 'spoiler' in self.tags
return "spoiler" in self.tags
def get_content_metadata(self, key: str) -> Any:
"""Get a piece of content metadata "safely".
Will return None if the topic has no metadata defined, if this key
doesn't exist in the metadata, etc.
Will return None if the topic has no metadata defined, if this key doesn't exist
in the metadata, etc.
"""
if not isinstance(self.content_metadata, dict):
return None
@ -371,13 +343,13 @@ class Topic(DatabaseModel):
metadata_strings = []
if self.is_text_type:
word_count = self.get_content_metadata('word_count')
word_count = self.get_content_metadata("word_count")
if word_count is not None:
if word_count == 1:
metadata_strings.append('1 word')
metadata_strings.append("1 word")
else:
metadata_strings.append(f'{word_count} words')
metadata_strings.append(f"{word_count} words")
elif self.is_link_type:
metadata_strings.append(f'{self.link_domain}')
metadata_strings.append(f"{self.link_domain}")
return ', '.join(metadata_strings)
return ", ".join(metadata_strings)

65
tildes/tildes/models/topic/topic_query.py

@ -21,14 +21,14 @@ class TopicQuery(PaginatedQuery):
def __init__(self, request: Request) -> None:
"""Initialize a TopicQuery for the request.
If the user is logged in, additional user-specific data will be fetched
along with the topics. For the moment, this is whether the user has
voted on the topics, and data related to their last visit - what time
they last visited, and how many new comments have been posted since.
If the user is logged in, additional user-specific data will be fetched along
with the topics. For the moment, this is whether the user has voted on the
topics, and data related to their last visit - what time they last visited, and
how many new comments have been posted since.
"""
super().__init__(Topic, request)
def _attach_extra_data(self) -> 'TopicQuery':
def _attach_extra_data(self) -> "TopicQuery":
"""Attach the extra user data to the query."""
if not self.request.user:
return self
@ -36,7 +36,7 @@ class TopicQuery(PaginatedQuery):
# pylint: disable=protected-access
return self._attach_vote_data()._attach_visit_data()
def _attach_vote_data(self) -> 'TopicQuery':
def _attach_vote_data(self) -> "TopicQuery":
"""Add a subquery to include whether the user has voted."""
vote_subquery = (
self.request.query(TopicVote)
@ -45,24 +45,25 @@ class TopicQuery(PaginatedQuery):
TopicVote.user == self.request.user,
)
.exists()
.label('user_voted')
.label("user_voted")
)
return self.add_columns(vote_subquery)
def _attach_visit_data(self) -> 'TopicQuery':
def _attach_visit_data(self) -> "TopicQuery":
"""Join the data related to the user's last visit to the topic(s)."""
if self.request.user.track_comment_visits:
query = self.outerjoin(TopicVisit, and_(
TopicVisit.topic_id == Topic.topic_id,
TopicVisit.user == self.request.user,
))
query = query.add_columns(
TopicVisit.visit_time, TopicVisit.num_comments)
query = self.outerjoin(
TopicVisit,
and_(
TopicVisit.topic_id == Topic.topic_id,
TopicVisit.user == self.request.user,
),
)
query = query.add_columns(TopicVisit.visit_time, TopicVisit.num_comments)
else:
# if the user has the feature disabled, just add literal NULLs
query = self.add_columns(
null().label('visit_time'),
null().label('num_comments'),
null().label("visit_time"), null().label("num_comments")
)
return query
@ -90,10 +91,8 @@ class TopicQuery(PaginatedQuery):
return topic
def apply_sort_option(
self,
sort: TopicSortOption,
desc: bool = True,
) -> 'TopicQuery':
self, sort: TopicSortOption, desc: bool = True
) -> "TopicQuery":
"""Apply a TopicSortOption sorting method (generative)."""
if sort == TopicSortOption.VOTES:
self._sort_column = Topic.num_votes
@ -108,22 +107,20 @@ class TopicQuery(PaginatedQuery):
return self
def inside_groups(self, groups: Sequence[Group]) -> 'TopicQuery':
def inside_groups(self, groups: Sequence[Group]) -> "TopicQuery":
"""Restrict the topics to inside specific groups (generative)."""
query_paths = [group.path for group in groups]
subgroup_subquery = (
self.request.db_session.query(Group.group_id)
.filter(Group.path.descendant_of(query_paths))
subgroup_subquery = self.request.db_session.query(Group.group_id).filter(
Group.path.descendant_of(query_paths)
)
return self.filter(
Topic.group_id.in_(subgroup_subquery)) # type: ignore
return self.filter(Topic.group_id.in_(subgroup_subquery)) # type: ignore
def inside_time_period(self, period: SimpleHoursPeriod) -> 'TopicQuery':
def inside_time_period(self, period: SimpleHoursPeriod) -> "TopicQuery":
"""Restrict the topics to inside a time period (generative)."""
# if the time period is too long, this will crash by creating a
# datetime outside the valid range - catch that and just don't filter
# by time period at all if the range is that large
# if the time period is too long, this will crash by creating a datetime outside
# the valid range - catch that and just don't filter by time period at all if
# the range is that large
try:
start_time = utc_now() - period.timedelta
except OverflowError:
@ -131,11 +128,11 @@ class TopicQuery(PaginatedQuery):
return self.filter(Topic.created_time > start_time)
def has_tag(self, tag: Ltree) -> 'TopicQuery':
def has_tag(self, tag: Ltree) -> "TopicQuery":
"""Restrict the topics to ones with a specific tag (generative)."""
# casting tag to string really shouldn't be necessary, but some kind of
# strange interaction seems to be happening with the ArrayOfLtree
# class, this will need some investigation
# casting tag to string really shouldn't be necessary, but some kind of strange
# interaction seems to be happening with the ArrayOfLtree class, this will need
# some investigation
tag = str(tag)
# pylint: disable=protected-access

40
tildes/tildes/models/topic/topic_visit.py

@ -16,40 +16,31 @@ from .topic import Topic
class TopicVisit(DatabaseModel):
"""Model for a user's visit to a topic.
New visits should not be created through __init__(), but by executing the
statement returned by the `generate_insert_statement` method. This will
take advantage of postgresql's ability to update any existing visit.
New visits should not be created through __init__(), but by executing the statement
returned by the `generate_insert_statement` method. This will take advantage of
postgresql's ability to update any existing visit.
Trigger behavior:
Incoming:
- num_comments will be incremented for the author's topic visit when
they post a comment in that topic.
- num_comments will be decremented when a comment is deleted, for all
visits to the topic that were after it was posted.
- num_comments will be incremented for the author's topic visit when they post a
comment in that topic.
- num_comments will be decremented when a comment is deleted, for all visits to
the topic that were after it was posted.
"""
__tablename__ = 'topic_visits'
__tablename__ = "topic_visits"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
topic_id: int = Column(
Integer,
ForeignKey('topics.topic_id'),
nullable=False,
primary_key=True,
)
visit_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
Integer, ForeignKey("topics.topic_id"), nullable=False, primary_key=True
)
visit_time: datetime = Column(TIMESTAMP(timezone=True), nullable=False)
num_comments: int = Column(Integer, nullable=False)
user: User = relationship('User', innerjoin=True)
topic: Topic = relationship('Topic', innerjoin=True)
user: User = relationship("User", innerjoin=True)
topic: Topic = relationship("Topic", innerjoin=True)
@classmethod
def generate_insert_statement(cls, user: User, topic: Topic) -> Insert:
@ -65,9 +56,6 @@ class TopicVisit(DatabaseModel):
)
.on_conflict_do_update(
constraint=cls.__table__.primary_key,
set_={
'visit_time': visit_time,
'num_comments': topic.num_comments,
},
set_={"visit_time": visit_time, "num_comments": topic.num_comments},
)
)

24
tildes/tildes/models/topic/topic_vote.py

@ -17,37 +17,31 @@ class TopicVote(DatabaseModel):
Trigger behavior:
Outgoing:
- Inserting or deleting a row will increment or decrement the num_votes
column for the relevant topic.
- Inserting or deleting a row will increment or decrement the num_votes column
for the relevant topic.
"""
__tablename__ = 'topic_votes'
__tablename__ = "topic_votes"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
topic_id: int = Column(
Integer,
ForeignKey('topics.topic_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("topics.topic_id"), nullable=False, primary_key=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
user: User = relationship('User', innerjoin=True)
topic: Topic = relationship('Topic', innerjoin=True)
user: User = relationship("User", innerjoin=True)
topic: Topic = relationship("Topic", innerjoin=True)
def __init__(self, user: User, topic: Topic) -> None:
"""Create a new vote on a topic."""
self.user = user
self.topic = topic
incr_counter('votes', target_type='topic')
incr_counter("votes", target_type="topic")

87
tildes/tildes/models/user/user.py

@ -39,17 +39,15 @@ class User(DatabaseModel):
Trigger behavior:
Incoming:
- num_unread_notifications will be incremented and decremented by
insertions, deletions, and updates to is_unread in
comment_notifications.
- num_unread_messages will be incremented and decremented by
insertions, deletions, and updates to unread_user_ids in
message_conversations.
- num_unread_notifications will be incremented and decremented by insertions,
deletions, and updates to is_unread in comment_notifications.
- num_unread_messages will be incremented and decremented by insertions,
deletions, and updates to unread_user_ids in message_conversations.
"""
schema_class = UserSchema
__tablename__ = 'users'
__tablename__ = "users"
user_id: int = Column(Integer, primary_key=True)
username: str = Column(CIText, nullable=False, unique=True)
@ -59,9 +57,8 @@ class User(DatabaseModel):
Column(
Text,
CheckConstraint(
'LENGTH(email_address_note) <= '
f'{EMAIL_ADDRESS_NOTE_MAX_LENGTH}',
name='email_address_note_length',
f"LENGTH(email_address_note) <= {EMAIL_ADDRESS_NOTE_MAX_LENGTH}",
name="email_address_note_length",
),
)
)
@ -69,46 +66,36 @@ class User(DatabaseModel):
TIMESTAMP(timezone=True),
nullable=False,
index=True,
server_default=text('NOW()'),
server_default=text("NOW()"),
)
num_unread_messages: int = Column(
Integer, nullable=False, server_default='0')
num_unread_notifications: int = Column(
Integer, nullable=False, server_default='0')
inviter_id: int = Column(Integer, ForeignKey('users.user_id'))
invite_codes_remaining: int = Column(
Integer, nullable=False, server_default='0')
track_comment_visits: bool = Column(
Boolean, nullable=False, server_default='false')
num_unread_messages: int = Column(Integer, nullable=False, server_default="0")
num_unread_notifications: int = Column(Integer, nullable=False, server_default="0")
inviter_id: int = Column(Integer, ForeignKey("users.user_id"))
invite_codes_remaining: int = Column(Integer, nullable=False, server_default="0")
track_comment_visits: bool = Column(Boolean, nullable=False, server_default="false")
auto_mark_notifications_read: bool = Column(
Boolean, nullable=False, server_default='false')
Boolean, nullable=False, server_default="false"
)
open_new_tab_external: bool = Column(
Boolean, nullable=False, server_default='false')
Boolean, nullable=False, server_default="false"
)
open_new_tab_internal: bool = Column(
Boolean, nullable=False, server_default='false')
open_new_tab_text: bool = Column(
Boolean, nullable=False, server_default='false')
theme_account_default: str = Column(
Text, nullable=False, server_default='')
is_banned: bool = Column(Boolean, nullable=False, server_default='false')
is_admin: bool = Column(Boolean, nullable=False, server_default='false')
home_default_order: Optional[TopicSortOption] = Column(
ENUM(TopicSortOption))
Boolean, nullable=False, server_default="false"
)
open_new_tab_text: bool = Column(Boolean, nullable=False, server_default="false")
theme_account_default: str = Column(Text, nullable=False, server_default="")
is_banned: bool = Column(Boolean, nullable=False, server_default="false")
is_admin: bool = Column(Boolean, nullable=False, server_default="false")
home_default_order: Optional[TopicSortOption] = Column(ENUM(TopicSortOption))
home_default_period: Optional[str] = Column(Text)
_filtered_topic_tags: List[Ltree] = Column(
'filtered_topic_tags',
ArrayOfLtree,
nullable=False,
server_default='{}',
"filtered_topic_tags", ArrayOfLtree, nullable=False, server_default="{}"
)
@hybrid_property
def filtered_topic_tags(self) -> List[str]:
"""Return the user's list of filtered topic tags."""
return [
str(tag).replace('_', ' ')
for tag in self._filtered_topic_tags
]
return [str(tag).replace("_", " ") for tag in self._filtered_topic_tags]
@filtered_topic_tags.setter # type: ignore
def filtered_topic_tags(self, new_tags: List[str]) -> None:
@ -116,7 +103,7 @@ class User(DatabaseModel):
def __repr__(self) -> str:
"""Display the user's username and ID as its repr format."""
return f'<User {self.username} ({self.user_id})>'
return f"<User {self.username} ({self.user_id})>"
def __str__(self) -> str:
"""Use the username for the string representation."""
@ -133,12 +120,12 @@ class User(DatabaseModel):
# view:
# - everyone can view all users
acl.append((Allow, Everyone, 'view'))
acl.append((Allow, Everyone, "view"))
# message:
# - anyone can message a user except themself
acl.append((Deny, self.user_id, 'message'))
acl.append((Allow, Authenticated, 'message'))
acl.append((Deny, self.user_id, "message"))
acl.append((Allow, Authenticated, "message"))
# grant the user all other permissions on themself
acl.append((Allow, self.user_id, ALL_PERMISSIONS))
@ -150,13 +137,13 @@ class User(DatabaseModel):
@property
def password(self) -> NoReturn:
"""Return an error since reading the password isn't possible."""
raise AttributeError('Password is write-only')
raise AttributeError("Password is write-only")
@password.setter
def password(self, value: str) -> None:
# need to do manual validation since some password checks depend on
# checking the username at the same time (for similarity)
self.schema.validate({'username': self.username, 'password': value})
# need to do manual validation since some password checks depend on checking the
# username at the same time (for similarity)
self.schema.validate({"username": self.username, "password": value})
self.password_hash = hash_string(value)
@ -167,10 +154,10 @@ class User(DatabaseModel):
def change_password(self, old_password: str, new_password: str) -> None:
"""Change the user's password from the old one to a new one."""
if not self.is_correct_password(old_password):
raise ValueError('Old password was not correct')
raise ValueError("Old password was not correct")
if new_password == old_password:
raise ValueError('New password is the same as old password')
raise ValueError("New password is the same as old password")
# disable mypy on this line because it doesn't handle setters correctly
self.password = new_password # type: ignore
@ -178,7 +165,7 @@ class User(DatabaseModel):
@property
def email_address(self) -> NoReturn:
"""Return an error since reading the email address isn't possible."""
raise AttributeError('Email address is write-only')
raise AttributeError("Email address is write-only")
@email_address.setter
def email_address(self, value: Optional[str]) -> None:

16
tildes/tildes/models/user/user_group_settings.py

@ -15,22 +15,16 @@ from tildes.models.user import User
class UserGroupSettings(DatabaseModel):
"""Model for a user's settings related to a specific group."""
__tablename__ = 'user_group_settings'
__tablename__ = "user_group_settings"
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("users.user_id"), nullable=False, primary_key=True
)
group_id: int = Column(
Integer,
ForeignKey('groups.group_id'),
nullable=False,
primary_key=True,
Integer, ForeignKey("groups.group_id"), nullable=False, primary_key=True
)
default_order: Optional[TopicSortOption] = Column(ENUM(TopicSortOption))
default_period: Optional[str] = Column(Text)
user: User = relationship('User', innerjoin=True)
group: Group = relationship('Group', innerjoin=True)
user: User = relationship("User", innerjoin=True)
group: Group = relationship("Group", innerjoin=True)

45
tildes/tildes/models/user/user_invite_code.py

@ -4,14 +4,7 @@ from datetime import datetime
import random
import string
from sqlalchemy import (
CheckConstraint,
Column,
ForeignKey,
Integer,
Text,
TIMESTAMP,
)
from sqlalchemy import CheckConstraint, Column, ForeignKey, Integer, Text, TIMESTAMP
from sqlalchemy.sql.expression import text
from tildes.models import DatabaseModel
@ -21,7 +14,7 @@ from .user import User
class UserInviteCode(DatabaseModel):
"""Model for invite codes that allow new users to register."""
__tablename__ = 'user_invite_codes'
__tablename__ = "user_invite_codes"
# the character set to generate codes using
ALPHABET = string.ascii_uppercase + string.digits
@ -30,33 +23,25 @@ class UserInviteCode(DatabaseModel):
code: str = Column(
Text,
CheckConstraint(
f'LENGTH(code) <= {LENGTH}',
name='code_length',
),
CheckConstraint(f"LENGTH(code) <= {LENGTH}", name="code_length"),
primary_key=True,
)
user_id: int = Column(
Integer,
ForeignKey('users.user_id'),
nullable=False,
index=True,
Integer, ForeignKey("users.user_id"), nullable=False, index=True
)
created_time: datetime = Column(
TIMESTAMP(timezone=True),
nullable=False,
server_default=text('NOW()'),
TIMESTAMP(timezone=True), nullable=False, server_default=text("NOW()")
)
invitee_id: int = Column(Integer, ForeignKey('users.user_id'))
invitee_id: int = Column(Integer, ForeignKey("users.user_id"))
def __str__(self) -> str:
"""Format the code into a more easily readable version."""
formatted = ''
formatted = ""
for count, char in enumerate(self.code):
# add a dash every 5 chars
if count > 0 and count % 5 == 0:
formatted += '-'
formatted += "-"
formatted += char.upper()
@ -65,13 +50,13 @@ class UserInviteCode(DatabaseModel):
def __init__(self, user: User) -> None:
"""Create a new (random) invite code owned by the user.
Note that uniqueness is not confirmed here, so there is the potential
to create duplicate codes (which will fail to commit to the database).
Note that uniqueness is not confirmed here, so there is the potential to create
duplicate codes (which will fail to commit to the database).
"""
self.user_id = user.user_id
code_chars = random.choices(self.ALPHABET, k=self.LENGTH)
self.code = ''.join(code_chars)
self.code = "".join(code_chars)
@classmethod
def prepare_code_for_lookup(cls, code: str) -> str:
@ -79,11 +64,11 @@ class UserInviteCode(DatabaseModel):
# codes are stored in uppercase
code = code.upper()
# remove any characters that aren't in the code alphabet (allows
# dashes, spaces, etc. to be used to make the codes more readable)
code = ''.join(letter for letter in code if letter in cls.ALPHABET)
# remove any characters that aren't in the code alphabet (allows dashes, spaces,
# etc. to be used to make the codes more readable)
code = "".join(letter for letter in code if letter in cls.ALPHABET)
if len(code) > cls.LENGTH:
raise ValueError('Code is longer than the maximum length')
raise ValueError("Code is longer than the maximum length")
return code

12
tildes/tildes/resources/__init__.py

@ -8,17 +8,13 @@ from tildes.models import DatabaseModel, ModelQuery
def get_resource(request: Request, base_query: ModelQuery) -> DatabaseModel:
"""Prepare and execute base query from a root factory, returning result."""
# While the site is private, we don't want to leak information about which
# usernames or groups exist. So we should just always raise a 403 before
# doing a lookup and potentially raising a 404.
# While the site is private, we don't want to leak information about which usernames
# or groups exist. So we should just always raise a 403 before doing a lookup and
# potentially raising a 404.
if not request.user:
raise HTTPForbidden
query = (
base_query
.lock_based_on_request_method()
.join_all_relationships()
)
query = base_query.lock_based_on_request_method().join_all_relationships()
if not request.is_safe_method:
query = query.undefer_all_columns()

20
tildes/tildes/resources/comment.py

@ -10,10 +10,7 @@ from tildes.resources import get_resource
from tildes.schemas.comment import CommentSchema
@use_kwargs(
CommentSchema(only=('comment_id36',)),
locations=('matchdict',),
)
@use_kwargs(CommentSchema(only=("comment_id36",)), locations=("matchdict",))
def comment_by_id36(request: Request, comment_id36: str) -> Comment:
"""Get a comment specified by {comment_id36} in the route (or 404)."""
query = (
@ -25,26 +22,21 @@ def comment_by_id36(request: Request, comment_id36: str) -> Comment:
return get_resource(request, query)
@use_kwargs(
CommentSchema(only=('comment_id36',)),
locations=('matchdict',),
)
@use_kwargs(CommentSchema(only=("comment_id36",)), locations=("matchdict",))
def notification_by_comment_id36(
request: Request,
comment_id36: str,
request: Request, comment_id36: str
) -> CommentNotification:
"""Get a comment notification specified by {comment_id36} in the route.
Looks up a comment notification for the logged-in user with the
{comment_id36} specified in the route.
Looks up a comment notification for the logged-in user with the {comment_id36}
specified in the route.
"""
if not request.user:
raise HTTPForbidden
comment_id = id36_to_id(comment_id36)
query = request.query(CommentNotification).filter_by(
user=request.user,
comment_id=comment_id,
user=request.user, comment_id=comment_id
)
return get_resource(request, query)

19
tildes/tildes/resources/group.py

@ -11,19 +11,18 @@ from tildes.schemas.group import GroupSchema
@use_kwargs(
GroupSchema(only=('path',), context={'fix_path_capitalization': True}),
locations=('matchdict',),
GroupSchema(only=("path",), context={"fix_path_capitalization": True}),
locations=("matchdict",),
)
def group_by_path(request: Request, path: str) -> Group:
"""Get a group specified by {group_path} in the route (or 404)."""
# If loading the specified group path into the GroupSchema changed it, do a
# 301 redirect to the resulting group path. This will happen in cases like
# the original url including capital letters in the group path, where we
# want to redirect to the proper all-lowercase path instead.
if path != request.matchdict['group_path']:
request.matchdict['group_path'] = path
proper_url = request.route_url(
request.matched_route.name, **request.matchdict)
# If loading the specified group path into the GroupSchema changed it, do a 301
# redirect to the resulting group path. This will happen in cases like the original
# url including capital letters in the group path, where we want to redirect to the
# proper all-lowercase path instead.
if path != request.matchdict["group_path"]:
request.matchdict["group_path"] = path
proper_url = request.route_url(request.matched_route.name, **request.matchdict)
raise HTTPMovedPermanently(location=proper_url)

11
tildes/tildes/resources/message.py

@ -10,17 +10,14 @@ from tildes.schemas.message import MessageConversationSchema
@use_kwargs(
MessageConversationSchema(only=('conversation_id36',)),
locations=('matchdict',),
MessageConversationSchema(only=("conversation_id36",)), locations=("matchdict",)
)
def message_conversation_by_id36(
request: Request,
conversation_id36: str,
request: Request, conversation_id36: str
) -> MessageConversation:
"""Get a conversation specified by {conversation_id36} in the route."""
query = (
request.query(MessageConversation)
.filter_by(conversation_id=id36_to_id(conversation_id36))
query = request.query(MessageConversation).filter_by(
conversation_id=id36_to_id(conversation_id36)
)
return get_resource(request, query)

13
tildes/tildes/resources/topic.py

@ -10,10 +10,7 @@ from tildes.resources import get_resource
from tildes.schemas.topic import TopicSchema
@use_kwargs(
TopicSchema(only=('topic_id36',)),
locations=('matchdict',),
)
@use_kwargs(TopicSchema(only=("topic_id36",)), locations=("matchdict",))
def topic_by_id36(request: Request, topic_id36: str) -> Topic:
"""Get a topic specified by {topic_id36} in the route (or 404)."""
query = (
@ -25,10 +22,10 @@ def topic_by_id36(request: Request, topic_id36: str) -> Topic:
topic = get_resource(request, query)
# if there's also a group specified in the route, check that it's the same
# group as the topic was posted in, otherwise redirect to correct group
if 'group_path' in request.matchdict:
path_from_route = request.matchdict['group_path'].lower()
# if there's also a group specified in the route, check that it's the same group as
# the topic was posted in, otherwise redirect to correct group
if "group_path" in request.matchdict:
path_from_route = request.matchdict["group_path"].lower()
if path_from_route != topic.group.path:
raise HTTPFound(topic.permalink)

5
tildes/tildes/resources/user.py

@ -8,10 +8,7 @@ from tildes.resources import get_resource
from tildes.schemas.user import UserSchema
@use_kwargs(
UserSchema(only=('username',)),
locations=('matchdict',),
)
@use_kwargs(UserSchema(only=("username",)), locations=("matchdict",))
def user_by_username(request: Request, username: str) -> User:
"""Get a user specified by {username} in the route or 404 if not found."""
query = request.query(User).filter(User.username == username)

166
tildes/tildes/routes.py

@ -6,10 +6,7 @@ from pyramid.config import Configurator
from pyramid.request import Request
from pyramid.security import Allow, Authenticated
from tildes.resources.comment import (
comment_by_id36,
notification_by_comment_id36,
)
from tildes.resources.comment import comment_by_id36, notification_by_comment_id36
from tildes.resources.group import group_by_path
from tildes.resources.message import message_conversation_by_id36
from tildes.resources.topic import topic_by_id36
@ -18,172 +15,133 @@ from tildes.resources.user import user_by_username
def includeme(config: Configurator) -> None:
"""Set up application routes."""
config.add_route('home', '/')
config.add_route("home", "/")
config.add_route('groups', '/groups')
config.add_route("groups", "/groups")
config.add_route('login', '/login')
config.add_route('logout', '/logout', factory=LoggedInFactory)
config.add_route("login", "/login")
config.add_route("logout", "/logout", factory=LoggedInFactory)
config.add_route('register', '/register')
config.add_route("register", "/register")
config.add_route('group', '/~{group_path}', factory=group_by_path)
config.add_route(
'new_topic', '/~{group_path}/new_topic', factory=group_by_path)
config.add_route("group", "/~{group_path}", factory=group_by_path)
config.add_route("new_topic", "/~{group_path}/new_topic", factory=group_by_path)
config.add_route(
'group_topics', '/~{group_path}/topics', factory=group_by_path)
config.add_route("group_topics", "/~{group_path}/topics", factory=group_by_path)
config.add_route(
'topic', '/~{group_path}/{topic_id36}*title', factory=topic_by_id36)
"topic", "/~{group_path}/{topic_id36}*title", factory=topic_by_id36
)
config.add_route('user', '/user/{username}', factory=user_by_username)
config.add_route("user", "/user/{username}", factory=user_by_username)
config.add_route("notifications", "/notifications", factory=LoggedInFactory)
config.add_route(
'notifications', '/notifications', factory=LoggedInFactory)
config.add_route(
'notifications_unread',
'/notifications/unread',
factory=LoggedInFactory,
"notifications_unread", "/notifications/unread", factory=LoggedInFactory
)
config.add_route('messages', '/messages', factory=LoggedInFactory)
config.add_route(
'messages_sent', '/messages/sent', factory=LoggedInFactory)
config.add_route("messages", "/messages", factory=LoggedInFactory)
config.add_route("messages_sent", "/messages/sent", factory=LoggedInFactory)
config.add_route("messages_unread", "/messages/unread", factory=LoggedInFactory)
config.add_route(
'messages_unread', '/messages/unread', factory=LoggedInFactory)
config.add_route(
'message_conversation',
'/messages/conversations/{conversation_id36}',
"message_conversation",
"/messages/conversations/{conversation_id36}",
factory=message_conversation_by_id36,
)
config.add_route(
'new_message',
'/user/{username}/new_message',
factory=user_by_username,
"new_message", "/user/{username}/new_message", factory=user_by_username
)
config.add_route(
'user_messages',
'/user/{username}/messages',
factory=user_by_username,
"user_messages", "/user/{username}/messages", factory=user_by_username
)
config.add_route('settings', '/settings', factory=LoggedInFactory)
config.add_route("settings", "/settings", factory=LoggedInFactory)
config.add_route(
'settings_account_recovery',
'/settings/account_recovery',
"settings_account_recovery",
"/settings/account_recovery",
factory=LoggedInFactory,
)
config.add_route(
'settings_comment_visits',
'/settings/comment_visits',
factory=LoggedInFactory,
"settings_comment_visits", "/settings/comment_visits", factory=LoggedInFactory
)
config.add_route("settings_filters", "/settings/filters", factory=LoggedInFactory)
config.add_route(
'settings_filters', '/settings/filters', factory=LoggedInFactory)
config.add_route(
'settings_password_change',
'/settings/password_change',
factory=LoggedInFactory,
"settings_password_change", "/settings/password_change", factory=LoggedInFactory
)
config.add_route('invite', '/invite', factory=LoggedInFactory)
config.add_route("invite", "/invite", factory=LoggedInFactory)
# Route to expose metrics to Prometheus
config.add_route('metrics', '/metrics')
config.add_route("metrics", "/metrics")
# Route for Stripe donation processing page (POSTed to from docs site)
config.add_route('donate_stripe', '/donate_stripe')
config.add_route("donate_stripe", "/donate_stripe")
add_intercooler_routes(config)
def add_intercooler_routes(config: Configurator) -> None:
"""Set up all routes for the (internal-use) Intercooler API endpoints."""
def add_ic_route(name: str, path: str, **kwargs: Any) -> None:
"""Add route with intercooler name prefix, base path, header check."""
name = 'ic_' + name
path = '/api/web' + path
config.add_route(
name,
path,
header='X-IC-Request:true',
**kwargs)
name = "ic_" + name
path = "/api/web" + path
config.add_route(name, path, header="X-IC-Request:true", **kwargs)
add_ic_route(
'group_subscribe',
'/group/{group_path}/subscribe',
factory=group_by_path,
"group_subscribe", "/group/{group_path}/subscribe", factory=group_by_path
)
add_ic_route(
'group_user_settings',
'/group/{group_path}/user_settings',
"group_user_settings",
"/group/{group_path}/user_settings",
factory=group_by_path,
)
add_ic_route('topic', '/topics/{topic_id36}', factory=topic_by_id36)
add_ic_route("topic", "/topics/{topic_id36}", factory=topic_by_id36)
add_ic_route(
'topic_comments',
'/topics/{topic_id36}/comments',
factory=topic_by_id36,
)
add_ic_route(
'topic_group', '/topics/{topic_id36}/group', factory=topic_by_id36)
add_ic_route(
'topic_lock', '/topics/{topic_id36}/lock', factory=topic_by_id36)
add_ic_route(
'topic_title', '/topics/{topic_id36}/title', factory=topic_by_id36)
add_ic_route(
'topic_vote', '/topics/{topic_id36}/vote', factory=topic_by_id36)
add_ic_route(
'topic_tags',
'/topics/{topic_id36}/tags',
factory=topic_by_id36,
"topic_comments", "/topics/{topic_id36}/comments", factory=topic_by_id36
)
add_ic_route("topic_group", "/topics/{topic_id36}/group", factory=topic_by_id36)
add_ic_route("topic_lock", "/topics/{topic_id36}/lock", factory=topic_by_id36)
add_ic_route("topic_title", "/topics/{topic_id36}/title", factory=topic_by_id36)
add_ic_route("topic_vote", "/topics/{topic_id36}/vote", factory=topic_by_id36)
add_ic_route("topic_tags", "/topics/{topic_id36}/tags", factory=topic_by_id36)
add_ic_route("comment", "/comments/{comment_id36}", factory=comment_by_id36)
add_ic_route(
'comment', '/comments/{comment_id36}', factory=comment_by_id36)
add_ic_route(
'comment_replies',
'/comments/{comment_id36}/replies',
factory=comment_by_id36,
"comment_replies", "/comments/{comment_id36}/replies", factory=comment_by_id36
)
add_ic_route(
'comment_vote',
'/comments/{comment_id36}/vote',
factory=comment_by_id36,
"comment_vote", "/comments/{comment_id36}/vote", factory=comment_by_id36
)
add_ic_route(
'comment_tag',
'/comments/{comment_id36}/tags/{name}',
factory=comment_by_id36,
"comment_tag", "/comments/{comment_id36}/tags/{name}", factory=comment_by_id36
)
add_ic_route(
'comment_mark_read',
'/comments/{comment_id36}/mark_read',
"comment_mark_read",
"/comments/{comment_id36}/mark_read",
factory=notification_by_comment_id36,
)
add_ic_route(
'message_conversation_replies',
'/messages/conversations/{conversation_id36}/replies',
"message_conversation_replies",
"/messages/conversations/{conversation_id36}/replies",
factory=message_conversation_by_id36,
)
add_ic_route('user', '/user/{username}', factory=user_by_username)
add_ic_route("user", "/user/{username}", factory=user_by_username)
add_ic_route(
'user_filtered_topic_tags',
'/user/{username}/filtered_topic_tags',
"user_filtered_topic_tags",
"/user/{username}/filtered_topic_tags",
factory=user_by_username,
)
add_ic_route(
'user_invite_code',
'/user/{username}/invite_code',
factory=user_by_username,
"user_invite_code", "/user/{username}/invite_code", factory=user_by_username
)
add_ic_route(
'user_default_listing_options',
'/user/{username}/default_listing_options',
"user_default_listing_options",
"/user/{username}/default_listing_options",
factory=user_by_username,
)
@ -191,12 +149,12 @@ def add_intercooler_routes(config: Configurator) -> None:
class LoggedInFactory:
"""Simple class to use as `factory` to restrict routes to logged-in users.
This class can be used when a route should only be accessible to logged-in
users but doesn't already have another factory that would handle that by
checking access to a specific resource (such as a topic or message).
This class can be used when a route should only be accessible to logged-in users but
doesn't already have another factory that would handle that by checking access to a
specific resource (such as a topic or message).
"""
__acl__ = ((Allow, Authenticated, 'view'),)
__acl__ = ((Allow, Authenticated, "view"),)
def __init__(self, request: Request) -> None:
"""Initialize - no-op, but needs to take the request as an arg."""

16
tildes/tildes/schemas/__init__.py

@ -2,14 +2,14 @@
These schemas are currently being used for several purposes:
- Validation of data for models, such as checking the lengths of strings,
ensuring that they match a particular regex pattern, etc. Specific errors
can be generated for any data that is invalid.
- Validation of data for models, such as checking the lengths of strings, ensuring
that they match a particular regex pattern, etc. Specific errors can be generated
for any data that is invalid.
- Similarly, the webargs library uses the schemas to validate pieces of data
coming in via urls, POST data, etc. It can produce errors if the data is
not valid for the purpose it's intended for.
- Similarly, the webargs library uses the schemas to validate pieces of data coming in
via urls, POST data, etc. It can produce errors if the data is not valid for the
purpose it's intended for.
- Serialization of data, which the Pyramid JSON renderer uses to produce
data for the JSON API endpoints.
- Serialization of data, which the Pyramid JSON renderer uses to produce data for the
JSON API endpoints.
"""

49
tildes/tildes/schemas/fields.py

@ -16,12 +16,7 @@ from tildes.lib.string import simplify_string
class Enum(Field):
"""Field for a native Python Enum (or subclasses)."""
def __init__(
self,
enum_class: Type = None,
*args: Any,
**kwargs: Any,
) -> None:
def __init__(self, enum_class: Type = None, *args: Any, **kwargs: Any) -> None:
"""Initialize the field with an optional enum class."""
super().__init__(*args, **kwargs)
self._enum_class = enum_class
@ -33,12 +28,12 @@ class Enum(Field):
def _deserialize(self, value: str, attr: str, data: dict) -> enum.Enum:
"""Deserialize a string to the enum member with that name."""
if not self._enum_class:
raise ValidationError('Cannot deserialize with no enum class.')
raise ValidationError("Cannot deserialize with no enum class.")
try:
return self._enum_class[value.upper()]
except KeyError:
raise ValidationError('Invalid enum member')
raise ValidationError("Invalid enum member")
class ID36(String):
@ -56,25 +51,19 @@ class ShortTimePeriod(Field):
"""
def _deserialize(
self,
value: str,
attr: str,
data: dict,
self, value: str, attr: str, data: dict
) -> Optional[SimpleHoursPeriod]:
"""Deserialize to a SimpleHoursPeriod object."""
if value == 'all':
if value == "all":
return None
try:
return SimpleHoursPeriod.from_short_form(value)
except ValueError:
raise ValidationError('Invalid time period')
raise ValidationError("Invalid time period")
def _serialize(
self,
value: Optional[SimpleHoursPeriod],
attr: str,
obj: object,
self, value: Optional[SimpleHoursPeriod], attr: str, obj: object
) -> Optional[str]:
"""Serialize the value to the "short form" string."""
if not value:
@ -100,11 +89,11 @@ class Markdown(Field):
super()._validate(value)
if value.isspace():
raise ValidationError('Cannot be entirely whitespace.')
raise ValidationError("Cannot be entirely whitespace.")
def _deserialize(self, value: str, attr: str, data: dict) -> str:
"""Deserialize the string, removing carriage returns in the process."""
value = value.replace('\r', '')
value = value.replace("\r", "")
return value
@ -119,8 +108,8 @@ class SimpleString(Field):
These strings should generally not contain any special formatting (such as
markdown), and have problematic whitespace/unicode/etc. removed.
See the simplify_string() function for full details of how these strings
are processed and sanitized.
See the simplify_string() function for full details of how these strings are
processed and sanitized.
"""
@ -145,23 +134,13 @@ class SimpleString(Field):
class Ltree(Field):
"""Field for postgresql ltree type."""
def _serialize(
self,
value: sqlalchemy_utils.Ltree,
attr: str,
obj: object,
) -> str:
def _serialize(self, value: sqlalchemy_utils.Ltree, attr: str, obj: object) -> str:
"""Serialize the Ltree value - use the (string) path."""
return value.path
def _deserialize(
self,
value: str,
attr: str,
data: dict,
) -> sqlalchemy_utils.Ltree:
def _deserialize(self, value: str, attr: str, data: dict) -> sqlalchemy_utils.Ltree:
"""Deserialize a string path to an Ltree object."""
try:
return sqlalchemy_utils.Ltree(value)
except (TypeError, ValueError):
raise ValidationError('Invalid path')
raise ValidationError("Invalid path")

27
tildes/tildes/schemas/group.py

@ -15,11 +15,13 @@ from tildes.schemas.fields import Ltree, SimpleString
# - must end with a number or lowercase letter
# - the middle can contain numbers, lowercase letters, and underscores
# Note: this regex does not contain any length checks, must be done separately
# fmt: off
GROUP_PATH_ELEMENT_VALID_REGEX = re.compile(
'^[a-z0-9]' # start
'([a-z0-9_]*' # middle
'[a-z0-9])?$', # end
"^[a-z0-9]" # start
"([a-z0-9_]*" # middle
"[a-z0-9])?$" # end
)
# fmt: on
SHORT_DESCRIPTION_MAX_LENGTH = 200
@ -27,21 +29,20 @@ SHORT_DESCRIPTION_MAX_LENGTH = 200
class GroupSchema(Schema):
"""Marshmallow schema for groups."""
path = Ltree(required=True, load_from='group_path')
path = Ltree(required=True, load_from="group_path")
created_time = DateTime(dump_only=True)
short_description = SimpleString(
max_length=SHORT_DESCRIPTION_MAX_LENGTH,
allow_none=True,
max_length=SHORT_DESCRIPTION_MAX_LENGTH, allow_none=True
)
@pre_load
def prepare_path(self, data: dict) -> dict:
"""Prepare the path value before it's validated."""
if not self.context.get('fix_path_capitalization'):
if not self.context.get("fix_path_capitalization"):
return data
# path can also be loaded from group_path, so we need to check both
keys = ('path', 'group_path')
keys = ("path", "group_path")
for key in keys:
if key in data and isinstance(data[key], str):
@ -49,17 +50,17 @@ class GroupSchema(Schema):
return data
@validates('path')
@validates("path")
def validate_path(self, value: sqlalchemy_utils.Ltree) -> None:
"""Validate the path field, raising an error if an issue exists."""
# check each element for length and against validity regex
path_elements = value.path.split('.')
path_elements = value.path.split(".")
for element in path_elements:
if len(element) > 256:
raise ValidationError('Path element %s is too long' % element)
raise ValidationError("Path element %s is too long" % element)
if not GROUP_PATH_ELEMENT_VALID_REGEX.match(element):
raise ValidationError('Path element %s is invalid' % element)
raise ValidationError("Path element %s is invalid" % element)
class Meta:
"""Always use strict checking so error handlers are invoked."""
@ -71,7 +72,7 @@ def is_valid_group_path(path: str) -> bool:
"""Return whether the group path is valid or not."""
schema = GroupSchema(partial=True)
try:
schema.validate({'path': path})
schema.validate({"path": path})
except ValidationError:
return False

68
tildes/tildes/schemas/topic.py

@ -4,13 +4,7 @@ import re
import typing
from urllib.parse import urlparse
from marshmallow import (
pre_load,
Schema,
validates,
validates_schema,
ValidationError,
)
from marshmallow import pre_load, Schema, validates, validates_schema, ValidationError
from marshmallow.fields import DateTime, List, Nested, String, URL
import sqlalchemy_utils
@ -29,7 +23,7 @@ class TopicSchema(Schema):
topic_type = Enum(dump_only=True)
markdown = Markdown(allow_none=True)
rendered_html = String(dump_only=True)
link = URL(schemes={'http', 'https'}, allow_none=True)
link = URL(schemes={"http", "https"}, allow_none=True)
created_time = DateTime(dump_only=True)
tags = List(Ltree())
@ -39,22 +33,22 @@ class TopicSchema(Schema):
@pre_load
def prepare_tags(self, data: dict) -> dict:
"""Prepare the tags before they're validated."""
if 'tags' not in data:
if "tags" not in data:
return data
tags: typing.List[str] = []
for tag in data['tags']:
for tag in data["tags"]:
tag = tag.lower()
# replace spaces with underscores
tag = tag.replace(' ', '_')
tag = tag.replace(" ", "_")
# remove any consecutive underscores
tag = re.sub('_{2,}', '_', tag)
tag = re.sub("_{2,}", "_", tag)
# remove any leading/trailing underscores
tag = tag.strip('_')
tag = tag.strip("_")
# drop any empty tags
if not tag or tag.isspace():
@@ -66,72 +60,68 @@ class TopicSchema(Schema):
tags.append(tag)
data['tags'] = tags
data["tags"] = tags
return data
@validates('tags')
def validate_tags(
self,
value: typing.List[sqlalchemy_utils.Ltree],
) -> None:
@validates("tags")
def validate_tags(self, value: typing.List[sqlalchemy_utils.Ltree]) -> None:
"""Validate the tags field, raising an error if an issue exists.
Note that tags are validated by ensuring that each tag would be a valid
group path. This is definitely mixing concerns, but it's deliberate in
this case. It will allow for some interesting possibilities by ensuring
naming "compatibility" between groups and tags. For example, a popular
tag in a group could be converted into a sub-group easily.
Note that tags are validated by ensuring that each tag would be a valid group
path. This is definitely mixing concerns, but it's deliberate in this case. It
will allow for some interesting possibilities by ensuring naming "compatibility"
between groups and tags. For example, a popular tag in a group could be
converted into a sub-group easily.
"""
group_schema = GroupSchema(partial=True)
for tag in value:
try:
group_schema.validate({'path': tag})
group_schema.validate({"path": tag})
except ValidationError:
raise ValidationError('Tag %s is invalid' % tag)
raise ValidationError("Tag %s is invalid" % tag)
@pre_load
def prepare_markdown(self, data: dict) -> dict:
"""Prepare the markdown value before it's validated."""
if 'markdown' not in data:
if "markdown" not in data:
return data
# if the value is empty, convert it to None
if not data['markdown'] or data['markdown'].isspace():
data['markdown'] = None
if not data["markdown"] or data["markdown"].isspace():
data["markdown"] = None
return data
@pre_load
def prepare_link(self, data: dict) -> dict:
"""Prepare the link value before it's validated."""
if 'link' not in data:
if "link" not in data:
return data
# if the value is empty, convert it to None
if not data['link'] or data['link'].isspace():
data['link'] = None
if not data["link"] or data["link"].isspace():
data["link"] = None
return data
# prepend http:// to the link if it doesn't have a scheme
parsed = urlparse(data['link'])
parsed = urlparse(data["link"])
if not parsed.scheme:
data['link'] = 'http://' + data['link']
data["link"] = "http://" + data["link"]
return data
@validates_schema
def link_or_markdown(self, data: dict) -> None:
"""Fail validation unless at least one of link or markdown were set."""
if 'link' not in data and 'markdown' not in data:
if "link" not in data and "markdown" not in data:
return
link = data.get('link')
markdown = data.get('markdown')
link = data.get("link")
markdown = data.get("markdown")
if not (markdown or link):
raise ValidationError(
'Topics must have either markdown or a link.')
raise ValidationError("Topics must have either markdown or a link.")
class Meta:
"""Always use strict checking so error handlers are invoked."""

13
tildes/tildes/schemas/topic_listing.py

@@ -17,25 +17,22 @@ class TopicListingSchema(Schema):
period = ShortTimePeriod(allow_none=True)
after = ID36()
before = ID36()
per_page = Integer(
validate=Range(min=1, max=100),
missing=DEFAULT_TOPICS_PER_PAGE,
)
rank_start = Integer(load_from='n', validate=Range(min=1), missing=None)
per_page = Integer(validate=Range(min=1, max=100), missing=DEFAULT_TOPICS_PER_PAGE)
rank_start = Integer(load_from="n", validate=Range(min=1), missing=None)
tag = Ltree(missing=None)
unfiltered = Boolean(missing=False)
@validates_schema
def either_after_or_before(self, data: dict) -> None:
"""Fail validation if both after and before were specified."""
if data.get('after') and data.get('before'):
if data.get("after") and data.get("before"):
raise ValidationError("Can't specify both after and before.")
@pre_load
def reset_rank_start_on_first_page(self, data: dict) -> dict:
"""Reset rank_start to 1 if this is a first page (no before/after)."""
if not (data.get('before') or data.get('after')):
data['rank_start'] = 1
if not (data.get("before") or data.get("after")):
data["rank_start"] = 1
return data

58
tildes/tildes/schemas/user.py

@@ -2,13 +2,7 @@
import re
from marshmallow import (
post_dump,
pre_load,
Schema,
validates,
validates_schema,
)
from marshmallow import post_dump, pre_load, Schema, validates, validates_schema
from marshmallow.exceptions import ValidationError
from marshmallow.fields import Boolean, DateTime, Email, String
from marshmallow.validate import Length, Regexp
@@ -22,15 +16,18 @@ USERNAME_MAX_LENGTH = 20
# Valid username regex, encodes the following:
# - must start with a number or letter
# - must end with a number or letter
# - the middle can contain numbers, letters, underscores and dashes, but no
# more than one underscore/dash consecutively (this includes both "_-" and
# "-_" sequences being invalid)
# - the middle can contain numbers, letters, underscores and dashes, but no more than
# one underscore/dash consecutively (this includes both "_-" and "-_" sequences
# being invalid)
# Note: this regex does not contain any length checks, must be done separately
# fmt: off
USERNAME_VALID_REGEX = re.compile(
"^[a-z0-9]" # start
"([a-z0-9]|[_-](?![_-]))*" # middle
"[a-z0-9]$", # end
re.IGNORECASE)
re.IGNORECASE,
)
# fmt: on
PASSWORD_MIN_LENGTH = 8
@@ -48,29 +45,26 @@ class UserSchema(Schema):
required=True,
)
password = String(
validate=Length(min=PASSWORD_MIN_LENGTH),
required=True,
load_only=True,
validate=Length(min=PASSWORD_MIN_LENGTH), required=True, load_only=True
)
email_address = Email(allow_none=True, load_only=True)
email_address_note = String(
validate=Length(max=EMAIL_ADDRESS_NOTE_MAX_LENGTH))
email_address_note = String(validate=Length(max=EMAIL_ADDRESS_NOTE_MAX_LENGTH))
created_time = DateTime(dump_only=True)
track_comment_visits = Boolean()
@post_dump
def anonymize_username(self, data: dict) -> dict:
"""Hide the username if the dumping context specifies to do so."""
if 'username' in data and self.context.get('hide_username'):
data['username'] = '<unknown>'
if "username" in data and self.context.get("hide_username"):
data["username"] = "<unknown>"
return data
@validates_schema
def username_pass_not_substrings(self, data: dict) -> None:
"""Ensure the username isn't in the password and vice versa."""
username = data.get('username')
password = data.get('password')
username = data.get("username")
password = data.get("password")
if not (username and password):
return
@@ -78,32 +72,32 @@ class UserSchema(Schema):
password = password.lower()
if username in password:
raise ValidationError('Password cannot contain username')
raise ValidationError("Password cannot contain username")
if password in username:
raise ValidationError('Username cannot contain password')
raise ValidationError("Username cannot contain password")
@validates('password')
@validates("password")
def password_not_breached(self, value: str) -> None:
"""Validate that the password is not in the breached-passwords list.
Requires check_breached_passwords be True in the schema's context.
"""
if not self.context.get('check_breached_passwords'):
if not self.context.get("check_breached_passwords"):
return
if is_breached_password(value):
raise ValidationError('That password exists in a data breach')
raise ValidationError("That password exists in a data breach")
@pre_load
def prepare_email_address(self, data: dict) -> dict:
"""Prepare the email address value before it's validated."""
if 'email_address' not in data:
if "email_address" not in data:
return data
# if the value is empty, convert it to None
if not data['email_address'] or data['email_address'].isspace():
data['email_address'] = None
if not data["email_address"] or data["email_address"].isspace():
data["email_address"] = None
return data
@@ -116,13 +110,13 @@
def is_valid_username(username: str) -> bool:
"""Return whether the username is valid or not.
Simple convenience wrapper that uses the schema to validate a username,
useful in cases where a simple valid/invalid result is needed without
worrying about the specific reason for invalidity.
Simple convenience wrapper that uses the schema to validate a username, useful in
cases where a simple valid/invalid result is needed without worrying about the
specific reason for invalidity.
"""
schema = UserSchema(partial=True)
try:
schema.validate({'username': username})
schema.validate({"username": username})
except ValidationError:
return False

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save