You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

246 lines
7.4 KiB

  1. # Copyright (c) 2018 Tildes contributors <code@tildes.net>
  2. # SPDX-License-Identifier: AGPL-3.0-or-later
  3. from datetime import timedelta
  4. from itertools import permutations
  5. from random import randint
  6. from pytest import raises
  7. from tildes.lib.ratelimit import (
  8. RATE_LIMITED_ACTIONS,
  9. RateLimitedAction,
  10. RateLimitError,
  11. RateLimitResult,
  12. )
  13. def test_all_rate_limited_action_names_unique():
  14. """Ensure all the RATE_LIMITED_ACTIONS defined have unique names."""
  15. seen_names = set()
  16. for action in RATE_LIMITED_ACTIONS.values():
  17. assert action.name not in seen_names
  18. seen_names.add(action.name)
  19. def test_action_with_all_types_disabled():
  20. """Ensure RateLimitedAction can't have both by_user and by_ip disabled."""
  21. with raises(ValueError):
  22. RateLimitedAction("test", timedelta(hours=1), 5, by_user=False, by_ip=False)
  23. def test_check_by_user_id_disabled():
  24. """Ensure non-by_user RateLimitedAction can't be checked by user_id."""
  25. action = RateLimitedAction("test", timedelta(hours=1), 5, by_user=False)
  26. with raises(RateLimitError):
  27. action.check_for_user_id(1)
  28. def test_check_by_ip_disabled():
  29. """Ensure non-by_ip RateLimitedAction can't be checked by ip."""
  30. action = RateLimitedAction("test", timedelta(hours=1), 5, by_ip=False)
  31. with raises(RateLimitError):
  32. action.check_for_ip("123.123.123.123")
  33. def test_simple_rate_limiting_by_user_id(redis):
  34. """Ensure simple rate-limiting by user_id is working."""
  35. limit = 5
  36. user_id = 1
  37. # define an action with max_burst equal to the full limit
  38. action = RateLimitedAction(
  39. "testaction", timedelta(hours=1), limit, max_burst=limit, redis=redis
  40. )
  41. # run the action the full number of times, should all be allowed
  42. for _ in range(limit):
  43. result = action.check_for_user_id(user_id)
  44. assert result.is_allowed
  45. # try one more time, should be rejected
  46. result = action.check_for_user_id(user_id)
  47. assert not result.is_allowed
  48. def test_different_user_ids_limited_separately(redis):
  49. """Ensure one user being rate-limited doesn't affect a different one."""
  50. limit = 5
  51. user_id = 1
  52. action = RateLimitedAction("test", timedelta(hours=1), limit, redis=redis)
  53. # check the action for the first user_id until it's blocked
  54. result = action.check_for_user_id(user_id)
  55. while result.is_allowed:
  56. result = action.check_for_user_id(user_id)
  57. # it should still be allowed for a different user_id
  58. assert action.check_for_user_id(user_id + 1)
  59. def test_max_burst_defaults_to_half(redis):
  60. """Ensure that unspecified max_burst on a RateLimitedAction allows half."""
  61. limit = 10
  62. user_id = 1
  63. action = RateLimitedAction("test", timedelta(days=1), limit, redis=redis)
  64. # see how many times we can do the action until it gets blocked
  65. count = 0
  66. while True:
  67. result = action.check_for_user_id(user_id)
  68. if result.is_allowed:
  69. count += 1
  70. else:
  71. break
  72. assert count == limit // 2
  73. def test_time_until_retry(redis):
  74. """Ensure an unbursted limit's time_until_retry is the expected value."""
  75. user_id = 1
  76. period = timedelta(seconds=60)
  77. limit = 2
  78. # create an action with no burst allowed, which will force the actions to be spaced
  79. # "evenly" across the limit
  80. action = RateLimitedAction(
  81. "test", period=period, limit=limit, max_burst=1, redis=redis
  82. )
  83. # first usage should be fine
  84. result = action.check_for_user_id(user_id)
  85. assert result.is_allowed
  86. # second should fail, and require a wait of (period / limit) - 1 sec
  87. result = action.check_for_user_id(user_id)
  88. assert not result.is_allowed
  89. assert result.time_until_retry == (period / limit) - timedelta(seconds=1)
  90. def test_remaining_limit(redis):
  91. """Ensure a limit's "remaining limit" decreases as expected."""
  92. user_id = 1
  93. limit = 10
  94. # create an action allowing the full limit as a burst
  95. action = RateLimitedAction(
  96. "test", timedelta(days=1), limit, max_burst=limit, redis=redis
  97. )
  98. for count in range(1, limit + 1):
  99. result = action.check_for_user_id(user_id)
  100. assert result.remaining_limit == limit - count
  101. def test_simple_rate_limiting_by_ip(redis):
  102. """Ensure simple rate-limiting by IP address is working."""
  103. limit = 5
  104. ip = "123.123.123.123"
  105. # define an action with max_burst equal to the full limit
  106. action = RateLimitedAction(
  107. "testaction", timedelta(hours=1), limit, max_burst=limit, redis=redis
  108. )
  109. # run the action the full number of times, should all be allowed
  110. for _ in range(limit):
  111. result = action.check_for_ip(ip)
  112. assert result.is_allowed
  113. # try one more time, should be rejected
  114. result = action.check_for_ip(ip)
  115. assert not result.is_allowed
  116. def test_check_for_ip_invalid_address():
  117. """Ensure RateLimitedAction.check_for_ip can't take an invalid IP."""
  118. ip = "123.456.789.123"
  119. action = RateLimitedAction("testaction", timedelta(hours=1), 10)
  120. with raises(ValueError):
  121. action.check_for_ip(ip)
  122. def test_reset_for_ip_invalid_address():
  123. """Ensure RateLimitedAction.reset_for_ip can't take an invalid IP."""
  124. ip = "123.456.789.123"
  125. action = RateLimitedAction("testaction", timedelta(hours=1), 10)
  126. with raises(ValueError):
  127. action.reset_for_ip(ip)
  128. def test_merged_results_single():
  129. """Ensure "merging" a single result just returns the same one."""
  130. result = RateLimitResult(
  131. is_allowed=True,
  132. total_limit=50,
  133. remaining_limit=22,
  134. time_until_max=timedelta(seconds=256),
  135. )
  136. assert RateLimitResult.merged_result([result]) == result
  137. def test_merged_results():
  138. """Ensure merging RateLimitResults gives the expected result."""
  139. results = [
  140. RateLimitResult(
  141. is_allowed=True,
  142. total_limit=20,
  143. remaining_limit=15,
  144. time_until_max=timedelta(seconds=90),
  145. ),
  146. RateLimitResult(
  147. is_allowed=False,
  148. total_limit=10,
  149. remaining_limit=0,
  150. time_until_max=timedelta(seconds=30),
  151. time_until_retry=timedelta(seconds=10),
  152. ),
  153. RateLimitResult(
  154. is_allowed=True,
  155. total_limit=30,
  156. remaining_limit=20,
  157. time_until_max=timedelta(seconds=60),
  158. ),
  159. ]
  160. expected_merged_result = RateLimitResult(
  161. is_allowed=False,
  162. total_limit=10,
  163. remaining_limit=0,
  164. time_until_max=timedelta(seconds=90),
  165. time_until_retry=timedelta(seconds=10),
  166. )
  167. # try merging all permutations to ensure ordering isn't a factor
  168. for permutation in permutations(results):
  169. merged_result = RateLimitResult.merged_result(permutation)
  170. assert merged_result == expected_merged_result
  171. def test_merged_all_allowed():
  172. """Ensure a merged result from all allowed results is also allowed."""
  173. def random_allowed_result():
  174. """Return a RateLimitResult with is_allowed=True, otherwise random."""
  175. return RateLimitResult(
  176. is_allowed=True,
  177. total_limit=randint(1, 100),
  178. remaining_limit=randint(1, 100),
  179. time_until_max=timedelta(randint(1, 100)),
  180. )
  181. # try merging a few different sets of different sizes
  182. for num_results in range(2, 6):
  183. results = [random_allowed_result() for _ in range(num_results)]
  184. assert RateLimitResult.merged_result(results).is_allowed