mirror of https://gitlab.com/tildes/tildes.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
246 lines
7.4 KiB
246 lines
7.4 KiB
# Copyright (c) 2018 Tildes contributors <code@tildes.net>
|
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
from datetime import timedelta
|
|
from itertools import permutations
|
|
from random import randint
|
|
|
|
from pytest import raises
|
|
|
|
from tildes.lib.ratelimit import (
|
|
RATE_LIMITED_ACTIONS,
|
|
RateLimitedAction,
|
|
RateLimitError,
|
|
RateLimitResult,
|
|
)
|
|
|
|
|
|
def test_all_rate_limited_action_names_unique():
    """Check that no two entries in RATE_LIMITED_ACTIONS share a name."""
    names = [action.name for action in RATE_LIMITED_ACTIONS.values()]

    # a set collapses duplicates, so any repeated name makes it smaller
    assert len(set(names)) == len(names)
|
|
|
|
|
|
def test_action_with_all_types_disabled():
    """Ensure a RateLimitedAction rejects having both by_user and by_ip off."""
    one_hour = timedelta(hours=1)

    with raises(ValueError):
        RateLimitedAction("test", one_hour, 5, by_ip=False, by_user=False)
|
|
|
|
|
|
def test_check_by_user_id_disabled():
    """Ensure an action created with by_user=False refuses user_id checks."""
    # only by_user is disabled here, so construction itself succeeds
    ip_only_action = RateLimitedAction(
        "test", period=timedelta(hours=1), limit=5, by_user=False
    )

    with raises(RateLimitError):
        ip_only_action.check_for_user_id(1)
|
|
|
|
|
|
def test_check_by_ip_disabled():
    """Ensure an action created with by_ip=False refuses IP address checks."""
    # only by_ip is disabled here, so construction itself succeeds
    user_only_action = RateLimitedAction(
        "test", period=timedelta(hours=1), limit=5, by_ip=False
    )

    with raises(RateLimitError):
        user_only_action.check_for_ip("123.123.123.123")
|
|
|
|
|
|
def test_simple_rate_limiting_by_user_id(redis):
    """Ensure simple rate-limiting by user_id is working.

    With max_burst equal to the full limit, `limit` consecutive checks for the
    same user must be allowed and the very next one rejected.
    """
    limit = 5
    user_id = 1
    action = RateLimitedAction(
        "testaction",
        period=timedelta(hours=1),
        limit=limit,
        max_burst=limit,
        redis=redis,
    )

    # every check within the burst allowance must pass
    for _ in range(limit):
        assert action.check_for_user_id(user_id).is_allowed

    # the first check beyond the allowance must be refused
    assert not action.check_for_user_id(user_id).is_allowed
|
|
|
|
|
|
def test_different_user_ids_limited_separately(redis):
    """Ensure one user being rate-limited doesn't affect a different one.

    The first user_id is checked until it is blocked, and the test verifies that
    the block really happened. A check for a second user_id must then still be
    allowed.
    """
    limit = 5
    user_id = 1

    action = RateLimitedAction("test", timedelta(hours=1), limit, redis=redis)

    # check the action for the first user_id until it's blocked; a user can't be
    # allowed more than `limit` times in a burst, so limit + 1 checks must reach
    # a rejection (bounded so a broken limiter fails instead of hanging forever)
    for _ in range(limit + 1):
        result = action.check_for_user_id(user_id)
        if not result.is_allowed:
            break

    assert not result.is_allowed

    # it should still be allowed for a different user_id; check .is_allowed
    # explicitly, since asserting on the result object alone would only test
    # its truthiness rather than whether the action was permitted
    assert action.check_for_user_id(user_id + 1).is_allowed
|
|
|
|
|
|
def test_max_burst_defaults_to_half(redis):
    """Ensure that unspecified max_burst on a RateLimitedAction allows half.

    The action is checked repeatedly for one user until the first rejection; the
    number of checks allowed before that must equal limit // 2.
    """
    limit = 10
    user_id = 1
    action = RateLimitedAction(
        "test", period=timedelta(days=1), limit=limit, redis=redis
    )

    # count consecutive allowed checks until one is rejected
    allowed_count = 0
    while action.check_for_user_id(user_id).is_allowed:
        allowed_count += 1

    assert allowed_count == limit // 2
|
|
|
|
|
|
def test_time_until_retry(redis):
    """Ensure an unbursted limit's time_until_retry is the expected value."""
    user_id = 1
    period = timedelta(seconds=60)
    limit = 2

    # max_burst=1 allows no bursting, which forces actions to be spaced
    # "evenly" across the limit
    action = RateLimitedAction(
        "test", period=period, limit=limit, max_burst=1, redis=redis
    )

    first_attempt = action.check_for_user_id(user_id)
    assert first_attempt.is_allowed

    # an immediate second attempt is refused, with a required wait of
    # (period / limit) minus one second
    second_attempt = action.check_for_user_id(user_id)
    assert not second_attempt.is_allowed

    expected_wait = period / limit - timedelta(seconds=1)
    assert second_attempt.time_until_retry == expected_wait
|
|
|
|
|
|
def test_remaining_limit(redis):
    """Ensure a limit's "remaining limit" decreases as expected.

    Each consecutive check must report one fewer remaining use, ending at 0
    after `limit` checks.
    """
    user_id = 1
    limit = 10

    # the full limit is allowed as a single burst
    action = RateLimitedAction(
        "test", period=timedelta(days=1), limit=limit, max_burst=limit, redis=redis
    )

    for expected_remaining in range(limit - 1, -1, -1):
        result = action.check_for_user_id(user_id)
        assert result.remaining_limit == expected_remaining
|
|
|
|
|
|
def test_simple_rate_limiting_by_ip(redis):
    """Ensure simple rate-limiting by IP address is working.

    With max_burst equal to the full limit, `limit` consecutive checks from the
    same address must be allowed and the very next one rejected.
    """
    limit = 5
    ip = "123.123.123.123"
    action = RateLimitedAction(
        "testaction",
        period=timedelta(hours=1),
        limit=limit,
        max_burst=limit,
        redis=redis,
    )

    # every check within the burst allowance must pass
    for _ in range(limit):
        assert action.check_for_ip(ip).is_allowed

    # the first check beyond the allowance must be refused
    assert not action.check_for_ip(ip).is_allowed
|
|
|
|
|
|
def test_check_for_ip_invalid_address():
    """Ensure RateLimitedAction.check_for_ip can't take an invalid IP."""
    # octets above 255 make this an invalid IPv4 address
    bad_ip = "123.456.789.123"

    # no redis is passed; the address is expected to be rejected before any use
    action = RateLimitedAction("testaction", period=timedelta(hours=1), limit=10)

    with raises(ValueError):
        action.check_for_ip(bad_ip)
|
|
|
|
|
|
def test_reset_for_ip_invalid_address():
    """Ensure RateLimitedAction.reset_for_ip can't take an invalid IP."""
    # octets above 255 make this an invalid IPv4 address
    bad_ip = "123.456.789.123"

    # no redis is passed; the address is expected to be rejected before any use
    action = RateLimitedAction("testaction", period=timedelta(hours=1), limit=10)

    with raises(ValueError):
        action.reset_for_ip(bad_ip)
|
|
|
|
|
|
def test_merged_results_single():
    """Ensure "merging" a single result just returns the same one."""
    only_result = RateLimitResult(
        is_allowed=True,
        total_limit=50,
        remaining_limit=22,
        time_until_max=timedelta(seconds=256),
    )

    merged = RateLimitResult.merged_result([only_result])

    assert merged == only_result
|
|
|
|
|
|
def test_merged_results():
    """Ensure merging RateLimitResults gives the expected result.

    The expected value encodes the merge rule being tested:
    - the merge is disallowed because one input is disallowed;
    - total_limit and remaining_limit are the smallest of the inputs;
    - time_until_max is the largest;
    - time_until_retry comes from the blocked input.
    """
    allowed_low_usage = RateLimitResult(
        is_allowed=True,
        total_limit=20,
        remaining_limit=15,
        time_until_max=timedelta(seconds=90),
    )
    blocked = RateLimitResult(
        is_allowed=False,
        total_limit=10,
        remaining_limit=0,
        time_until_max=timedelta(seconds=30),
        time_until_retry=timedelta(seconds=10),
    )
    allowed_high_limit = RateLimitResult(
        is_allowed=True,
        total_limit=30,
        remaining_limit=20,
        time_until_max=timedelta(seconds=60),
    )

    expected = RateLimitResult(
        is_allowed=False,
        total_limit=10,
        remaining_limit=0,
        time_until_max=timedelta(seconds=90),
        time_until_retry=timedelta(seconds=10),
    )

    # merging must not depend on input order, so try every ordering
    inputs = [allowed_low_usage, blocked, allowed_high_limit]
    for ordering in permutations(inputs):
        assert RateLimitResult.merged_result(ordering) == expected
|
|
|
|
|
|
def test_merged_all_allowed():
    """Ensure a merged result from all allowed results is also allowed."""

    def random_allowed_result():
        """Return a RateLimitResult with is_allowed=True, otherwise random.

        The random values are kept internally consistent: remaining_limit is
        always below total_limit, because an allowed action has consumed at
        least one use. time_until_max is given explicitly in seconds, since a
        bare positional timedelta argument would mean days.
        """
        total_limit = randint(1, 100)
        return RateLimitResult(
            is_allowed=True,
            total_limit=total_limit,
            remaining_limit=randint(0, total_limit - 1),
            time_until_max=timedelta(seconds=randint(1, 100)),
        )

    # try merging a few different sets of different sizes
    for num_results in range(2, 6):
        results = [random_allowed_result() for _ in range(num_results)]
        assert RateLimitResult.merged_result(results).is_allowed
|