You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

246 lines
7.4 KiB

# Copyright (c) 2018 Tildes contributors <code@tildes.net>
# SPDX-License-Identifier: AGPL-3.0-or-later
from datetime import timedelta
from itertools import permutations
from random import randint
from pytest import raises
from tildes.lib.ratelimit import (
RateLimitedAction,
RateLimitError,
RateLimitResult,
RATE_LIMITED_ACTIONS,
)
def test_all_rate_limited_action_names_unique():
    """Check that no two entries in RATE_LIMITED_ACTIONS share a name."""
    action_names = [action.name for action in RATE_LIMITED_ACTIONS.values()]

    # a set collapses duplicates, so any repeated name makes it smaller than the list
    assert len(set(action_names)) == len(action_names)
def test_action_with_all_types_disabled():
    """Check that disabling both by_user and by_ip is rejected at construction."""
    one_hour = timedelta(hours=1)

    with raises(ValueError):
        RateLimitedAction(
            "test",
            one_hour,
            5,
            by_user=False,
            by_ip=False,
        )
def test_check_by_user_id_disabled():
    """Check that an action created with by_user=False refuses user_id checks."""
    non_user_action = RateLimitedAction(
        "test",
        timedelta(hours=1),
        5,
        by_user=False,
    )

    with raises(RateLimitError):
        non_user_action.check_for_user_id(1)
def test_check_by_ip_disabled():
    """Check that an action created with by_ip=False refuses IP-based checks."""
    non_ip_action = RateLimitedAction(
        "test",
        timedelta(hours=1),
        5,
        by_ip=False,
    )

    with raises(RateLimitError):
        non_ip_action.check_for_ip("123.123.123.123")
def test_simple_rate_limiting_by_user_id(redis):
    """Check that a user_id is refused once its whole burst has been used up."""
    max_uses = 5
    uid = 1

    # max_burst equals the limit, so every permitted use can happen back-to-back
    action = RateLimitedAction(
        "testaction", timedelta(hours=1), max_uses, max_burst=max_uses, redis=redis
    )

    for _attempt in range(max_uses):
        assert action.check_for_user_id(uid).is_allowed

    # the burst is exhausted, so one more attempt must be refused
    assert not action.check_for_user_id(uid).is_allowed
def test_different_user_ids_limited_separately(redis):
    """Ensure one user being rate-limited doesn't affect a different one."""
    limit = 5
    user_id = 1

    action = RateLimitedAction("test", timedelta(hours=1), limit, redis=redis)

    # use the action as the first user_id until it's blocked; the loop is bounded so
    # a limiter that never blocks makes the test fail instead of hanging forever
    result = action.check_for_user_id(user_id)
    for _ in range(limit):
        if not result.is_allowed:
            break
        result = action.check_for_user_id(user_id)
    assert not result.is_allowed

    # it should still be allowed for a different user_id; check is_allowed
    # explicitly, since asserting on the result object only tests its truthiness
    assert action.check_for_user_id(user_id + 1).is_allowed
def test_max_burst_defaults_to_half(redis):
    """Ensure that unspecified max_burst on a RateLimitedAction allows half."""
    limit = 10
    user_id = 1

    action = RateLimitedAction("test", timedelta(days=1), limit, redis=redis)

    # count how many consecutive uses are allowed before the first rejection; at
    # most limit + 1 attempts are made, so a limiter that never blocks produces a
    # failing count instead of an infinite loop
    count = 0
    for _ in range(limit + 1):
        if not action.check_for_user_id(user_id).is_allowed:
            break
        count += 1

    assert count == limit // 2
def test_time_until_retry(redis):
    """Check the retry delay reported when an unbursted limit is exceeded."""
    user_id = 1
    period = timedelta(seconds=60)
    limit = 2

    # max_burst=1 allows no burst, so consecutive uses have to be spread "evenly"
    # across the period rather than happening back-to-back
    action = RateLimitedAction(
        "test", period=period, limit=limit, max_burst=1, redis=redis
    )

    first_attempt = action.check_for_user_id(user_id)
    assert first_attempt.is_allowed

    # an immediate second use is refused, with a wait of (period / limit) - 1 sec
    second_attempt = action.check_for_user_id(user_id)
    assert not second_attempt.is_allowed

    expected_wait = period / limit - timedelta(seconds=1)
    assert second_attempt.time_until_retry == expected_wait
def test_remaining_limit(redis):
    """Check that each use lowers the reported remaining limit by exactly one."""
    user_id = 1
    limit = 10

    # the whole limit is available as a single burst
    action = RateLimitedAction(
        "test", timedelta(days=1), limit, max_burst=limit, redis=redis
    )

    # expected values run limit - 1, limit - 2, ..., 0
    for expected_remaining in reversed(range(limit)):
        outcome = action.check_for_user_id(user_id)
        assert outcome.remaining_limit == expected_remaining
def test_simple_rate_limiting_by_ip(redis):
    """Check that an IP address is refused once its whole burst has been used up."""
    max_uses = 5
    address = "123.123.123.123"

    # max_burst equals the limit, so every permitted use can happen back-to-back
    action = RateLimitedAction(
        "testaction", timedelta(hours=1), max_uses, max_burst=max_uses, redis=redis
    )

    for _attempt in range(max_uses):
        assert action.check_for_ip(address).is_allowed

    # the burst is exhausted, so one more attempt must be refused
    assert not action.check_for_ip(address).is_allowed
def test_check_for_ip_invalid_address():
    """Check that check_for_ip raises ValueError for a malformed IP address."""
    action = RateLimitedAction("testaction", timedelta(hours=1), 10)

    # octets above 255 make this dotted-quad string invalid
    malformed_ip = "123.456.789.123"

    with raises(ValueError):
        action.check_for_ip(malformed_ip)
def test_reset_for_ip_invalid_address():
    """Check that reset_for_ip raises ValueError for a malformed IP address."""
    action = RateLimitedAction("testaction", timedelta(hours=1), 10)

    # octets above 255 make this dotted-quad string invalid
    malformed_ip = "123.456.789.123"

    with raises(ValueError):
        action.reset_for_ip(malformed_ip)
def test_merged_results_single():
    """Check that merging a lone result yields an equal result back."""
    lone_result = RateLimitResult(
        is_allowed=True,
        total_limit=50,
        remaining_limit=22,
        time_until_max=timedelta(seconds=256),
    )

    merged = RateLimitResult.merged_result([lone_result])

    assert merged == lone_result
def test_merged_results():
    """Check that merging several RateLimitResults gives the expected result.

    With these inputs the merge is expected to be disallowed, to report the
    10-action limit with 0 remaining and the 10-second retry wait, and to take
    the longest time_until_max (90 seconds) among the inputs.
    """
    allowed_twenty = RateLimitResult(
        is_allowed=True,
        total_limit=20,
        remaining_limit=15,
        time_until_max=timedelta(seconds=90),
    )
    blocked_ten = RateLimitResult(
        is_allowed=False,
        total_limit=10,
        remaining_limit=0,
        time_until_max=timedelta(seconds=30),
        time_until_retry=timedelta(seconds=10),
    )
    allowed_thirty = RateLimitResult(
        is_allowed=True,
        total_limit=30,
        remaining_limit=20,
        time_until_max=timedelta(seconds=60),
    )
    inputs = [allowed_twenty, blocked_ten, allowed_thirty]

    expected = RateLimitResult(
        is_allowed=False,
        total_limit=10,
        remaining_limit=0,
        time_until_max=timedelta(seconds=90),
        time_until_retry=timedelta(seconds=10),
    )

    # every ordering of the inputs must merge to the same result
    for ordering in permutations(inputs):
        assert RateLimitResult.merged_result(ordering) == expected
def test_merged_all_allowed():
    """Ensure a merged result from all allowed results is also allowed."""

    def random_allowed_result():
        """Return a RateLimitResult with is_allowed=True, otherwise random.

        The random values are kept internally consistent: remaining_limit never
        exceeds total_limit, and time_until_max is given in seconds like the other
        tests (a bare positional timedelta argument would be interpreted as days).
        """
        total_limit = randint(1, 100)
        return RateLimitResult(
            is_allowed=True,
            total_limit=total_limit,
            remaining_limit=randint(1, total_limit),
            time_until_max=timedelta(seconds=randint(1, 100)),
        )

    # try merging a few different sets of different sizes
    for num_results in range(2, 6):
        results = [random_allowed_result() for _ in range(num_results)]
        assert RateLimitResult.merged_result(results).is_allowed