"""Service to handle user operations."""
import logging
import random
import string
from datetime import datetime
from typing import Optional, Dict, Callable, Any

from atheneum import errors
from atheneum.db import db
from atheneum.model import User
from atheneum.service.role_service import Role
from atheneum.service.transformation_service import (
    BaseTransformer,
    register_transformer
)
from atheneum.utility import authentication_utility

LOGGER = logging.getLogger(__name__)

# Alphabet and length used when a password has to be generated for a user
# that was registered without one.
_GENERATED_PASSWORD_ALPHABET = string.ascii_letters + string.digits
_GENERATED_PASSWORD_LENGTH = 32


@register_transformer
class UserTransformer(BaseTransformer):
    """
    Serialize and deserialize the User model.

    Each external field name (``name``, ``creationTime``, ``lastLoginTime``,
    ``version``, ``role``) is mapped to one serializer that reads from
    ``self.model`` and one static deserializer that writes onto a supplied
    model instance.
    """

    # Model class handled by this transformer. The attribute name is kept
    # as-is because it is part of the transformer interface.
    type = User

    def _deserializers(
            self) -> Dict[str, Callable[[User, Any], None]]:
        """Define the fields and the accompanying deserializer factory."""
        return {
            'name': self.deserialize_name,
            'creationTime': self.deserialize_creation_time,
            'lastLoginTime': self.deserialize_last_login_time,
            'version': self.deserialize_version,
            'role': self.deserialize_role,
        }

    def _serializers(self) -> Dict[str, Callable[[], Any]]:
        """Define the fields and the accompanying serializer factory."""
        return {
            'name': self.serialize_name,
            'creationTime': self.serialize_creation_time,
            'lastLoginTime': self.serialize_last_login_time,
            'version': self.serialize_version,
            'role': self.serialize_role,
        }

    def serialize_name(self) -> str:
        """Return the user name of the wrapped model."""
        return self.model.name

    @staticmethod
    def deserialize_name(model: User, name: str) -> None:
        """Set the user name on the given model."""
        model.name = name

    def serialize_creation_time(self) -> datetime:
        """Return the creation time of the wrapped model."""
        return self.model.creation_time

    @staticmethod
    def deserialize_creation_time(
            model: User, creation_time: datetime) -> None:
        """Set the creation time on the given model."""
        model.creation_time = creation_time

    def serialize_last_login_time(self) -> datetime:
        """Return the last login time of the wrapped model."""
        return self.model.last_login_time

    @staticmethod
    def deserialize_last_login_time(
            model: User, last_login_time: datetime) -> None:
        """Set the last login time on the given model."""
        model.last_login_time = last_login_time

    def serialize_version(self) -> int:
        """Return the version of the wrapped model."""
        return self.model.version

    @staticmethod
    def deserialize_version(model: User, version: int) -> None:
        """Set the version on the given model."""
        model.version = version

    def serialize_role(self) -> str:
        """Return the raw value of the wrapped model's role."""
        return self.model.role.value

    @staticmethod
    def deserialize_role(model: User, role_value: str) -> None:
        """Set the role on the given model, built from its raw value."""
        model.role = Role(role_value)


def find_by_name(name: str) -> Optional[User]:
    """
    Find a user by name.

    :param name: User name to look up
    :return: The first matching user, or None when no user has that name
    """
    return User.query.filter_by(name=name).first()


def _generate_password() -> str:
    """Build a random alphanumeric password for users registered without one."""
    # SystemRandom draws from the operating system's entropy source, unlike
    # the module-level functions in `random`, which are not suitable for
    # generating secrets.
    return ''.join(random.SystemRandom().choices(
        _GENERATED_PASSWORD_ALPHABET, k=_GENERATED_PASSWORD_LENGTH))


def register(name: str, password: Optional[str], role: Optional[str]) -> User:
    """
    Register a new user.

    :param name: Desired user name. Must be unique and not already registered
    :param password: Password to be hashed and stored for the user. When
        None, a random 32 character alphanumeric password is generated
    :param role: Role to assign the user [ROLE_USER, ROLE_ADMIN]. When None,
        ``User.ROLE_USER`` is used
    :return: The newly created and committed user
    :raises errors.ValidationError: If the user name is already taken
    """
    if password is None:
        password = _generate_password()
    if role is None:
        role = User.ROLE_USER

    if find_by_name(name=name) is not None:
        raise errors.ValidationError('User name is already taken.')

    pw_hash, pw_revision = authentication_utility.get_password_hash(password)

    new_user = User(
        name=name,
        role=role,
        password_hash=pw_hash,
        password_revision=pw_revision,
        creation_time=datetime.now(),
        version=0)
    db.session.add(new_user)
    db.session.commit()

    LOGGER.info('Registered new user: %s with role: %s', name, role)
    return new_user


def delete(user: Optional[User]) -> bool:
    """
    Delete a user.

    :param user: User to delete. None is tolerated and results in no change
    :return: True when the user was deleted and the change committed,
        False when there was no user to delete
    """
    if user is None:
        return False
    # The result of session.delete() is not inspected: it carries no
    # success information (NOTE(review): assumes `db.session` is a
    # SQLAlchemy session, whose delete() returns None); failures surface as
    # exceptions instead.
    db.session.delete(user)
    db.session.commit()
    return True


def update_last_login_time(user: Optional[User]) -> None:
    """
    Bump the last login time for the user.

    :param user: User whose last login time is set to the current time.
        Nothing happens when None is given
    :return:
    """
    if user is None:
        return
    user.last_login_time = datetime.now()
    db.session.commit()


def update_password(user: User, password: str) -> None:
    """
    Change the user password.

    :param user: User whose password hash and revision are replaced
    :param password: New plain text password to be hashed and stored
    :return:
    """
    pw_hash, pw_revision = authentication_utility.get_password_hash(
        password)
    user.password_hash = pw_hash
    user.password_revision = pw_revision
    db.session.commit()