diff --git a/src/keycloak/keycloak_uma.py b/src/keycloak/keycloak_uma.py index 0075689..f4683b3 100644 --- a/src/keycloak/keycloak_uma.py +++ b/src/keycloak/keycloak_uma.py @@ -30,6 +30,8 @@ import json from typing import Iterable from urllib.parse import quote_plus +from async_property import async_property + from .connection import ConnectionManager from .exceptions import ( KeycloakDeleteError, @@ -56,9 +58,6 @@ class KeycloakUMA: :type connection: KeycloakOpenIDConnection """ self.connection = connection - custom_headers = self.connection.custom_headers or {} - custom_headers.update({"Content-Type": "application/json"}) - self.connection.custom_headers = custom_headers self._well_known = None def _fetch_well_known(self): @@ -96,6 +95,17 @@ class KeycloakUMA: self._well_known = self._fetch_well_known() return self._well_known + @async_property + async def a_uma_well_known(self): + """Get the well_known UMA2 config async. + + :returns: It lists endpoints and other configuration options relevant + :rtype: dict + """ + if not self._well_known: + self._well_known = await self.a__fetch_well_known() + return self._well_known + def resource_set_create(self, payload): """Create a resource set. @@ -441,7 +451,8 @@ class KeycloakUMA: :rtype: dict """ data_raw = await self.connection.a_raw_post( - self.uma_well_known["resource_registration_endpoint"], data=json.dumps(payload) + (await self.a_uma_well_known)["resource_registration_endpoint"], + data=json.dumps(payload), ) return raise_error_from_response(data_raw, KeycloakPostError, expected_codes=[201]) @@ -462,7 +473,8 @@ class KeycloakUMA: :rtype: dict """ url = self.format_url( - self.uma_well_known["resource_registration_endpoint"] + "/{id}", id=resource_id + (await self.a_uma_well_known)["resource_registration_endpoint"] + "/{id}", + id=resource_id, ) data_raw = await self.connection.a_raw_put(url, data=json.dumps(payload)) return raise_error_from_response(data_raw, KeycloakPutError, expected_codes=[204]) @@ -482,7 +494,8 @@ class KeycloakUMA: :rtype: dict """ url = self.format_url( - self.uma_well_known["resource_registration_endpoint"] + "/{id}", id=resource_id + (await self.a_uma_well_known)["resource_registration_endpoint"] + "/{id}", + id=resource_id, ) data_raw = await self.connection.a_raw_get(url) return raise_error_from_response(data_raw, KeycloakGetError, expected_codes=[200]) @@ -499,7 +512,8 @@ class KeycloakUMA: :rtype: dict """ url = self.format_url( - self.uma_well_known["resource_registration_endpoint"] + "/{id}", id=resource_id + (await self.a_uma_well_known)["resource_registration_endpoint"] + "/{id}", + id=resource_id, ) data_raw = await self.connection.a_raw_delete(url) return raise_error_from_response(data_raw, KeycloakDeleteError, expected_codes=[204]) @@ -558,7 +572,7 @@ class KeycloakUMA: query["max"] = maximum data_raw = await self.connection.a_raw_get( - self.uma_well_known["resource_registration_endpoint"], **query + (await self.a_uma_well_known)["resource_registration_endpoint"], **query ) return raise_error_from_response(data_raw, KeycloakGetError, expected_codes=[200]) @@ -611,7 +625,7 @@ class KeycloakUMA: ] data_raw = await self.connection.a_raw_post( - self.uma_well_known["permission_endpoint"], data=json.dumps(payload) + (await self.a_uma_well_known)["permission_endpoint"], data=json.dumps(payload) ) return raise_error_from_response(data_raw, KeycloakPostError) @@ -645,7 +659,9 @@ class KeycloakUMA: connection = ConnectionManager(self.connection.base_url) connection.add_param_headers("Authorization", "Bearer " + token) connection.add_param_headers("Content-Type", "application/x-www-form-urlencoded") - data_raw = await connection.a_raw_post(self.uma_well_known["token_endpoint"], data=payload) + data_raw = await connection.a_raw_post( + (await self.a_uma_well_known)["token_endpoint"], data=payload + ) try: data = raise_error_from_response(data_raw, KeycloakPostError) except KeycloakPostError: @@ -667,7 +683,8 @@ class KeycloakUMA: :rtype: dict """ data_raw = await self.connection.a_raw_post( - self.uma_well_known["policy_endpoint"] + f"/{resource_id}", data=json.dumps(payload) + (await self.a_uma_well_known)["policy_endpoint"] + f"/{resource_id}", + data=json.dumps(payload), ) return raise_error_from_response(data_raw, KeycloakPostError) @@ -685,7 +702,8 @@ class KeycloakUMA: :rtype: dict """ data_raw = await self.connection.a_raw_put( - self.uma_well_known["policy_endpoint"] + f"/{policy_id}", data=json.dumps(payload) + (await self.a_uma_well_known)["policy_endpoint"] + f"/{policy_id}", + data=json.dumps(payload), ) return raise_error_from_response(data_raw, KeycloakPutError) @@ -701,7 +719,7 @@ class KeycloakUMA: :rtype: dict """ data_raw = await self.connection.a_raw_delete( - self.uma_well_known["policy_endpoint"] + f"/{policy_id}" + (await self.a_uma_well_known)["policy_endpoint"] + f"/{policy_id}" ) return raise_error_from_response(data_raw, KeycloakDeleteError) @@ -743,5 +761,7 @@ class KeycloakUMA: if maximum >= 0: query["max"] = maximum - data_raw = await self.connection.a_raw_get(self.uma_well_known["policy_endpoint"], **query) + data_raw = await self.connection.a_raw_get( + (await self.a_uma_well_known)["policy_endpoint"], **query + ) return raise_error_from_response(data_raw, KeycloakGetError) diff --git a/src/keycloak/openid_connection.py b/src/keycloak/openid_connection.py index b484e44..30ff4d9 100644 --- a/src/keycloak/openid_connection.py +++ b/src/keycloak/openid_connection.py @@ -103,6 +103,7 @@ class KeycloakOpenIDConnection(ConnectionManager): # token is renewed when it hits 90% of its lifetime. This is to account for any possible # clock skew. self.token_lifetime_fraction = 0.9 + self.headers = {} self.server_url = server_url self.username = username self.password = password @@ -114,18 +115,8 @@ class KeycloakOpenIDConnection(ConnectionManager): self.client_secret_key = client_secret_key self.user_realm_name = user_realm_name self.timeout = timeout - self.headers = {} self.custom_headers = custom_headers - - if self.token is None: - self.get_token() - - if self.token is not None: - self.headers = { - **self.headers, - "Authorization": "Bearer " + self.token.get("access_token"), - "Content-Type": "application/json", - } + self.headers = {**self.headers, "Content-Type": "application/json"} super().__init__( base_url=self.server_url, headers=self.headers, timeout=60, verify=self.verify @@ -237,6 +228,8 @@ class KeycloakOpenIDConnection(ConnectionManager): self._expires_at = datetime.now() + timedelta( seconds=int(self.token_lifetime_fraction * self.token["expires_in"] if value else 0) ) + if value is not None: + self.add_param_headers("Authorization", "Bearer " + value.get("access_token")) @property def expires_at(self): @@ -345,8 +338,6 @@ class KeycloakOpenIDConnection(ConnectionManager): else: raise - self.add_param_headers("Authorization", "Bearer " + self.token.get("access_token")) - def _refresh_if_required(self): if datetime.now() >= self.expires_at: self.refresh_token() @@ -419,6 +410,53 @@ class KeycloakOpenIDConnection(ConnectionManager): r = super().raw_delete(*args, **kwargs) return r + async def a_get_token(self): + """Get admin token. + + The admin token is then set in the `token` attribute. + """ + grant_type = [] + if self.username and self.password: + grant_type.append("password") + elif self.client_secret_key: + grant_type.append("client_credentials") + + if grant_type: + self.token = await self.keycloak_openid.a_token( + self.username, self.password, grant_type=grant_type, totp=self.totp + ) + else: + self.token = None + + async def a_refresh_token(self): + """Refresh the token. + + :raises KeycloakPostError: In case the refresh token request failed. + """ + refresh_token = self.token.get("refresh_token", None) if self.token else None + if refresh_token is None: + await self.a_get_token() + else: + try: + self.token = await self.keycloak_openid.a_refresh_token(refresh_token) + except KeycloakPostError as e: + list_errors = [ + b"Refresh token expired", + b"Token is not active", + b"Session not active", + ] + if e.response_code == 400 and any(err in e.response_body for err in list_errors): + await self.a_get_token() + else: + raise + + self.add_param_headers("Authorization", "Bearer " + self.token.get("access_token")) + + async def a__refresh_if_required(self): + """Refresh the token if it is expired.""" + if datetime.now() >= self.expires_at: + await self.a_refresh_token() + async def a_raw_get(self, *args, **kwargs): """Call connection.raw_get. @@ -432,7 +470,7 @@ class KeycloakOpenIDConnection(ConnectionManager): :returns: Response :rtype: Response """ - self._refresh_if_required() + await self.a__refresh_if_required() r = await super().a_raw_get(*args, **kwargs) return r @@ -449,7 +487,7 @@ class KeycloakOpenIDConnection(ConnectionManager): :returns: Response :rtype: Response """ - self._refresh_if_required() + await self.a__refresh_if_required() r = await super().a_raw_post(*args, **kwargs) return r @@ -466,7 +504,7 @@ class KeycloakOpenIDConnection(ConnectionManager): :returns: Response :rtype: Response """ - self._refresh_if_required() + await self.a__refresh_if_required() r = await super().a_raw_put(*args, **kwargs) return r @@ -483,6 +521,6 @@ class KeycloakOpenIDConnection(ConnectionManager): :returns: Response :rtype: Response """ - self._refresh_if_required() + await self.a__refresh_if_required() r = await super().a_raw_delete(*args, **kwargs) return r