96 lines
3.1 KiB

  1. """Connection test module."""
  2. from inspect import iscoroutinefunction, signature
  3. import pytest
  4. from keycloak.connection import ConnectionManager
  5. from keycloak.exceptions import KeycloakConnectionError
  6. def test_connection_proxy():
  7. """Test proxies of connection manager."""
  8. cm = ConnectionManager(
  9. base_url="http://test.test", proxies={"http://test.test": "http://localhost:8080"}
  10. )
  11. assert cm._s.proxies == {"http://test.test": "http://localhost:8080"}
  12. def test_headers():
  13. """Test headers manipulation."""
  14. cm = ConnectionManager(base_url="http://test.test", headers={"H": "A"})
  15. assert cm.param_headers(key="H") == "A"
  16. assert cm.param_headers(key="A") is None
  17. cm.clean_headers()
  18. assert cm.headers == dict()
  19. cm.add_param_headers(key="H", value="B")
  20. assert cm.exist_param_headers(key="H")
  21. assert not cm.exist_param_headers(key="B")
  22. cm.del_param_headers(key="H")
  23. assert not cm.exist_param_headers(key="H")
  24. def test_bad_connection():
  25. """Test bad connection."""
  26. cm = ConnectionManager(base_url="http://not.real.domain")
  27. with pytest.raises(KeycloakConnectionError):
  28. cm.raw_get(path="bad")
  29. with pytest.raises(KeycloakConnectionError):
  30. cm.raw_delete(path="bad")
  31. with pytest.raises(KeycloakConnectionError):
  32. cm.raw_post(path="bad", data={})
  33. with pytest.raises(KeycloakConnectionError):
  34. cm.raw_put(path="bad", data={})
  35. @pytest.mark.asyncio
  36. async def a_test_bad_connection():
  37. """Test bad connection."""
  38. cm = ConnectionManager(base_url="http://not.real.domain")
  39. with pytest.raises(KeycloakConnectionError):
  40. await cm.a_raw_get(path="bad")
  41. with pytest.raises(KeycloakConnectionError):
  42. await cm.a_raw_delete(path="bad")
  43. with pytest.raises(KeycloakConnectionError):
  44. await cm.a_raw_post(path="bad", data={})
  45. with pytest.raises(KeycloakConnectionError):
  46. await cm.a_raw_put(path="bad", data={})
  47. def test_counter_part():
  48. """Test that each function has its async counter part."""
  49. con_methods = [
  50. func for func in dir(ConnectionManager) if callable(getattr(ConnectionManager, func))
  51. ]
  52. sync_methods = [
  53. method
  54. for method in con_methods
  55. if not method.startswith("a_") and not method.startswith("_")
  56. ]
  57. async_methods = [
  58. method for method in con_methods if iscoroutinefunction(getattr(ConnectionManager, method))
  59. ]
  60. for method in sync_methods:
  61. if method in [
  62. "aclose",
  63. "add_param_headers",
  64. "del_param_headers",
  65. "clean_headers",
  66. "exist_param_headers",
  67. "param_headers",
  68. ]:
  69. continue
  70. async_method = f"a_{method}"
  71. assert (async_method in con_methods) is True
  72. sync_sign = signature(getattr(ConnectionManager, method))
  73. async_sign = signature(getattr(ConnectionManager, async_method))
  74. assert sync_sign.parameters == async_sign.parameters
  75. for async_method in async_methods:
  76. if async_method in ["aclose"]:
  77. continue
  78. if async_method[2:].startswith("_"):
  79. continue
  80. assert async_method[2:] in sync_methods