97 lines
3.2 KiB

  1. """Connection test module."""
  2. from inspect import iscoroutinefunction, signature
  3. import pytest
  4. from keycloak.connection import ConnectionManager
  5. from keycloak.exceptions import KeycloakConnectionError
  6. def test_connection_proxy() -> None:
  7. """Test proxies of connection manager."""
  8. cm = ConnectionManager(
  9. base_url="http://test.test",
  10. proxies={"http://test.test": "http://localhost:8080"},
  11. )
  12. assert cm._s.proxies == {"http://test.test": "http://localhost:8080"}
  13. def test_headers() -> None:
  14. """Test headers manipulation."""
  15. cm = ConnectionManager(base_url="http://test.test", headers={"H": "A"})
  16. assert cm.param_headers(key="H") == "A"
  17. assert cm.param_headers(key="A") is None
  18. cm.clean_headers()
  19. assert cm.headers == {}
  20. cm.add_param_headers(key="H", value="B")
  21. assert cm.exist_param_headers(key="H")
  22. assert not cm.exist_param_headers(key="B")
  23. cm.del_param_headers(key="H")
  24. assert not cm.exist_param_headers(key="H")
  25. def test_bad_connection() -> None:
  26. """Test bad connection."""
  27. cm = ConnectionManager(base_url="http://not.real.domain")
  28. with pytest.raises(KeycloakConnectionError):
  29. cm.raw_get(path="bad")
  30. with pytest.raises(KeycloakConnectionError):
  31. cm.raw_delete(path="bad")
  32. with pytest.raises(KeycloakConnectionError):
  33. cm.raw_post(path="bad", data={})
  34. with pytest.raises(KeycloakConnectionError):
  35. cm.raw_put(path="bad", data={})
  36. @pytest.mark.asyncio
  37. async def a_test_bad_connection() -> None:
  38. """Test bad connection."""
  39. cm = ConnectionManager(base_url="http://not.real.domain")
  40. with pytest.raises(KeycloakConnectionError):
  41. await cm.a_raw_get(path="bad")
  42. with pytest.raises(KeycloakConnectionError):
  43. await cm.a_raw_delete(path="bad")
  44. with pytest.raises(KeycloakConnectionError):
  45. await cm.a_raw_post(path="bad", data={})
  46. with pytest.raises(KeycloakConnectionError):
  47. await cm.a_raw_put(path="bad", data={})
  48. def test_counter_part() -> None:
  49. """Test that each function has its async counter part."""
  50. con_methods = [
  51. func for func in dir(ConnectionManager) if callable(getattr(ConnectionManager, func))
  52. ]
  53. sync_methods = [
  54. method
  55. for method in con_methods
  56. if not method.startswith("a_") and not method.startswith("_")
  57. ]
  58. async_methods = [
  59. method for method in con_methods if iscoroutinefunction(getattr(ConnectionManager, method))
  60. ]
  61. for method in sync_methods:
  62. if method in [
  63. "aclose",
  64. "add_param_headers",
  65. "del_param_headers",
  66. "clean_headers",
  67. "exist_param_headers",
  68. "param_headers",
  69. ]:
  70. continue
  71. async_method = f"a_{method}"
  72. assert (async_method in con_methods) is True
  73. sync_sign = signature(getattr(ConnectionManager, method))
  74. async_sign = signature(getattr(ConnectionManager, async_method))
  75. assert sync_sign.parameters == async_sign.parameters
  76. for async_method in async_methods:
  77. if async_method in ["aclose"]:
  78. continue
  79. if async_method[2:].startswith("_"):
  80. continue
  81. assert async_method[2:] in sync_methods