You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							97 lines
						
					
					
						
							3.2 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							97 lines
						
					
					
						
							3.2 KiB
						
					
					
				| """Connection test module.""" | |
| 
 | |
| from inspect import iscoroutinefunction, signature | |
| 
 | |
| import pytest | |
| 
 | |
| from keycloak.connection import ConnectionManager | |
| from keycloak.exceptions import KeycloakConnectionError | |
| 
 | |
| 
 | |
def test_connection_proxy() -> None:
    """Proxies given to the constructor must end up on the underlying session."""
    expected_proxies = {"http://test.test": "http://localhost:8080"}
    # Hand the manager its own copy so the comparison below is against an
    # independent object, exactly as with two separate literals.
    manager = ConnectionManager(
        base_url="http://test.test",
        proxies=dict(expected_proxies),
    )
    assert manager._s.proxies == expected_proxies
| 
 | |
| 
 | |
def test_headers() -> None:
    """Exercise the header helpers: lookup, clearing, adding and removing."""
    manager = ConnectionManager(base_url="http://test.test", headers={"H": "A"})

    # A known key returns its value; an unknown key yields None.
    assert manager.param_headers(key="H") == "A"
    assert manager.param_headers(key="A") is None

    # Clearing leaves an empty header mapping.
    manager.clean_headers()
    assert manager.headers == {}

    # Adding and then deleting a header is reflected by exist_param_headers.
    manager.add_param_headers(key="H", value="B")
    assert manager.exist_param_headers(key="H")
    assert not manager.exist_param_headers(key="B")
    manager.del_param_headers(key="H")
    assert not manager.exist_param_headers(key="H")
| 
 | |
| 
 | |
def test_bad_connection() -> None:
    """Each raw HTTP verb must report an unreachable host as KeycloakConnectionError."""
    manager = ConnectionManager(base_url="http://not.real.domain")
    # (method name, extra keyword arguments) in the order the verbs are tried.
    attempts = [
        ("raw_get", {}),
        ("raw_delete", {}),
        ("raw_post", {"data": {}}),
        ("raw_put", {"data": {}}),
    ]
    for method_name, extra_kwargs in attempts:
        with pytest.raises(KeycloakConnectionError):
            getattr(manager, method_name)(path="bad", **extra_kwargs)
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_a_bad_connection() -> None:
    """Each async raw HTTP verb must report an unreachable host as KeycloakConnectionError.

    This test was previously named ``a_test_bad_connection``. That name does
    not match pytest's default ``test*`` function-collection pattern, so it
    was never run. The old name is kept below as an alias.
    """
    cm = ConnectionManager(base_url="http://not.real.domain")
    try:
        with pytest.raises(KeycloakConnectionError):
            await cm.a_raw_get(path="bad")
        with pytest.raises(KeycloakConnectionError):
            await cm.a_raw_delete(path="bad")
        with pytest.raises(KeycloakConnectionError):
            await cm.a_raw_post(path="bad", data={})
        with pytest.raises(KeycloakConnectionError):
            await cm.a_raw_put(path="bad", data={})
    finally:
        # aclose is ConnectionManager's async close coroutine; call it so the
        # test does not leave async resources open.
        # NOTE(review): confirm it is safe on a manager whose requests all failed.
        await cm.aclose()


# Backward-compatible alias for the old name. Without a ``test`` prefix it is
# not collected by pytest's default pattern, so the test is not run twice.
a_test_bad_connection = test_a_bad_connection
| 
 | |
| 
 | |
def test_counter_part() -> None:
    """Test that each function has its async counter part.

    Every public synchronous method ``foo`` of ``ConnectionManager`` must have
    an ``a_foo`` sibling with identical parameters. Every public ``a_``-prefixed
    coroutine must have a synchronous ``foo`` sibling. The header-manipulation
    helpers and ``aclose`` are exempt.
    """
    # Public methods that intentionally have no ``a_`` counterpart.
    sync_only = {
        "aclose",
        "add_param_headers",
        "del_param_headers",
        "clean_headers",
        "exist_param_headers",
        "param_headers",
    }
    con_methods = [
        name for name in dir(ConnectionManager) if callable(getattr(ConnectionManager, name))
    ]
    sync_methods = [name for name in con_methods if not name.startswith(("a_", "_"))]
    async_methods = [
        name for name in con_methods if iscoroutinefunction(getattr(ConnectionManager, name))
    ]

    for method in sync_methods:
        if method in sync_only:
            continue
        async_method = f"a_{method}"
        assert async_method in con_methods, f"{method} has no async counterpart {async_method}"
        sync_sign = signature(getattr(ConnectionManager, method))
        async_sign = signature(getattr(ConnectionManager, async_method))
        assert sync_sign.parameters == async_sign.parameters, (
            f"{method} and {async_method} have different parameters"
        )

    for async_method in async_methods:
        # Only ``a_``-prefixed coroutines are counterparts. Stripping two
        # characters from any other coroutine name (e.g. ``aclose`` or a dunder
        # such as ``__aenter__``) would produce a meaningless name. A public
        # coroutine without the prefix is still covered by the loop above.
        if not async_method.startswith("a_"):
            continue
        sync_method = async_method[2:]
        if sync_method.startswith("_"):
            continue
        assert sync_method in sync_methods, (
            f"{async_method} has no sync counterpart {sync_method}"
        )
 |