You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
3.1 KiB
96 lines
3.1 KiB
"""Connection test module."""
|
|
|
|
from inspect import iscoroutinefunction, signature
|
|
|
|
import pytest
|
|
|
|
from keycloak.connection import ConnectionManager
|
|
from keycloak.exceptions import KeycloakConnectionError
|
|
|
|
|
|
def test_connection_proxy():
    """Check that proxies passed to the manager show up on ``_s.proxies``."""
    expected_proxies = {"http://test.test": "http://localhost:8080"}
    manager = ConnectionManager(
        base_url="http://test.test",
        proxies={"http://test.test": "http://localhost:8080"},
    )
    # The mapping handed to the constructor must be exposed unchanged.
    assert manager._s.proxies == expected_proxies
|
|
|
|
|
|
def test_headers():
    """Exercise the header helper methods of the connection manager."""
    manager = ConnectionManager(base_url="http://test.test", headers={"H": "A"})

    # Lookup: a header given at construction is returned, an unknown key gives None.
    assert manager.param_headers(key="H") == "A"
    assert manager.param_headers(key="A") is None

    # Clearing leaves an empty header mapping.
    manager.clean_headers()
    assert manager.headers == {}

    # Adding registers the key only (the value "B" is not itself a key).
    manager.add_param_headers(key="H", value="B")
    assert manager.exist_param_headers(key="H")
    assert not manager.exist_param_headers(key="B")

    # Deleting removes the key again.
    manager.del_param_headers(key="H")
    assert not manager.exist_param_headers(key="H")
|
|
|
|
|
|
def test_bad_connection():
    """Check that each raw request helper raises ``KeycloakConnectionError``.

    The manager is pointed at ``http://not.real.domain``, which is expected to
    be unreachable, so every request below should fail.
    """
    manager = ConnectionManager(base_url="http://not.real.domain")

    # Each entry pairs a request helper with the keyword arguments to call it with;
    # the order matches GET, DELETE, POST, PUT.
    attempts = (
        (manager.raw_get, {"path": "bad"}),
        (manager.raw_delete, {"path": "bad"}),
        (manager.raw_post, {"path": "bad", "data": {}}),
        (manager.raw_put, {"path": "bad", "data": {}}),
    )
    for request, kwargs in attempts:
        with pytest.raises(KeycloakConnectionError):
            request(**kwargs)
|
|
|
|
|
|
@pytest.mark.asyncio
async def a_test_bad_connection():
    """Check that each async raw request helper raises ``KeycloakConnectionError``.

    The manager is pointed at ``http://not.real.domain``, which is expected to
    be unreachable, so every awaited request below should fail.
    """
    # NOTE(review): this function's name does not start with ``test``, so under
    # pytest's default discovery settings it would presumably not be collected.
    # Confirm against the project's pytest configuration whether this is intended.
    manager = ConnectionManager(base_url="http://not.real.domain")

    # Each entry pairs an async request helper with its keyword arguments;
    # the order matches GET, DELETE, POST, PUT.
    attempts = (
        (manager.a_raw_get, {"path": "bad"}),
        (manager.a_raw_delete, {"path": "bad"}),
        (manager.a_raw_post, {"path": "bad", "data": {}}),
        (manager.a_raw_put, {"path": "bad", "data": {}}),
    )
    for request, kwargs in attempts:
        with pytest.raises(KeycloakConnectionError):
            await request(**kwargs)
|
|
|
|
|
|
def test_counter_part():
    """Check that sync and async methods of ``ConnectionManager`` come in pairs.

    Every public sync method (apart from a fixed exemption list) must have an
    ``a_``-prefixed sibling with identical signature parameters, and every
    coroutine method (apart from ``aclose`` and those whose name after the
    two-character prefix starts with ``_``) must have a sync sibling.
    """
    # Names that are deliberately exempt from the sync -> async pairing check.
    unpaired = {
        "aclose",
        "add_param_headers",
        "del_param_headers",
        "clean_headers",
        "exist_param_headers",
        "param_headers",
    }

    callables = [
        name for name in dir(ConnectionManager) if callable(getattr(ConnectionManager, name))
    ]
    # Public sync methods: neither private nor carrying the async prefix.
    public_sync = [name for name in callables if not name.startswith(("a_", "_"))]
    coroutines = [
        name for name in callables if iscoroutinefunction(getattr(ConnectionManager, name))
    ]

    # Direction 1: each sync method has an async twin with the same parameters.
    for name in public_sync:
        if name in unpaired:
            continue
        twin = f"a_{name}"
        assert twin in callables
        sync_params = signature(getattr(ConnectionManager, name)).parameters
        async_params = signature(getattr(ConnectionManager, twin)).parameters
        assert sync_params == async_params

    # Direction 2: each coroutine maps back to a sync method once the prefix is dropped.
    for name in coroutines:
        stem = name[2:]
        if name == "aclose" or stem.startswith("_"):
            continue
        assert stem in public_sync
|