97 lines
3.2 KiB

"""Connection test module."""
from inspect import iscoroutinefunction, signature
import pytest
from keycloak.connection import ConnectionManager
from keycloak.exceptions import KeycloakConnectionError
def test_connection_proxy() -> None:
"""Test proxies of connection manager."""
cm = ConnectionManager(
base_url="http://test.test",
proxies={"http://test.test": "http://localhost:8080"},
)
assert cm._s.proxies == {"http://test.test": "http://localhost:8080"}
def test_headers() -> None:
"""Test headers manipulation."""
cm = ConnectionManager(base_url="http://test.test", headers={"H": "A"})
assert cm.param_headers(key="H") == "A"
assert cm.param_headers(key="A") is None
cm.clean_headers()
assert cm.headers == {}
cm.add_param_headers(key="H", value="B")
assert cm.exist_param_headers(key="H")
assert not cm.exist_param_headers(key="B")
cm.del_param_headers(key="H")
assert not cm.exist_param_headers(key="H")
def test_bad_connection() -> None:
"""Test bad connection."""
cm = ConnectionManager(base_url="http://not.real.domain")
with pytest.raises(KeycloakConnectionError):
cm.raw_get(path="bad")
with pytest.raises(KeycloakConnectionError):
cm.raw_delete(path="bad")
with pytest.raises(KeycloakConnectionError):
cm.raw_post(path="bad", data={})
with pytest.raises(KeycloakConnectionError):
cm.raw_put(path="bad", data={})
@pytest.mark.asyncio
async def a_test_bad_connection() -> None:
"""Test bad connection."""
cm = ConnectionManager(base_url="http://not.real.domain")
with pytest.raises(KeycloakConnectionError):
await cm.a_raw_get(path="bad")
with pytest.raises(KeycloakConnectionError):
await cm.a_raw_delete(path="bad")
with pytest.raises(KeycloakConnectionError):
await cm.a_raw_post(path="bad", data={})
with pytest.raises(KeycloakConnectionError):
await cm.a_raw_put(path="bad", data={})
def test_counter_part() -> None:
"""Test that each function has its async counter part."""
con_methods = [
func for func in dir(ConnectionManager) if callable(getattr(ConnectionManager, func))
]
sync_methods = [
method
for method in con_methods
if not method.startswith("a_") and not method.startswith("_")
]
async_methods = [
method for method in con_methods if iscoroutinefunction(getattr(ConnectionManager, method))
]
for method in sync_methods:
if method in [
"aclose",
"add_param_headers",
"del_param_headers",
"clean_headers",
"exist_param_headers",
"param_headers",
]:
continue
async_method = f"a_{method}"
assert (async_method in con_methods) is True
sync_sign = signature(getattr(ConnectionManager, method))
async_sign = signature(getattr(ConnectionManager, async_method))
assert sync_sign.parameters == async_sign.parameters
for async_method in async_methods:
if async_method in ["aclose"]:
continue
if async_method[2:].startswith("_"):
continue
assert async_method[2:] in sync_methods