diff --git a/tildes/tests/test_user.py b/tildes/tests/test_user.py index dafd1d7..639b425 100644 --- a/tildes/tests/test_user.py +++ b/tildes/tests/test_user.py @@ -118,6 +118,24 @@ def test_change_password_to_username(session_user): session_user.change_password("session user password", session_user.username) +def test_user_email_check(): + """Ensure checking a user's email address works correctly.""" + user = User("Some_User", "Some_Password") + user.email_address = "some_user@example.com" + + assert user.is_correct_email_address("some_user@example.com") + + assert not user.is_correct_email_address("someuser@example.com") + + +def test_user_email_check_case_insensitive(): + """Ensure the user email address check isn't case-sensitive.""" + user = User("Some_User", "Some_Password") + user.email_address = "Some_User@example.com" + + assert user.is_correct_email_address("some_user@example.com") + + def test_deleted_user_no_message_permission(): """Ensure nobody can message a deleted user.""" deleted_user = User("Deleted_User", "password") diff --git a/tildes/tildes/models/user/user.py b/tildes/tildes/models/user/user.py index 81c16b9..c6db4b4 100644 --- a/tildes/tildes/models/user/user.py +++ b/tildes/tildes/models/user/user.py @@ -282,6 +282,13 @@ class User(DatabaseModel): value = value.lower() self.email_address_hash = hash_string(value) + def is_correct_email_address(self, email_address: str) -> bool: + """Check if the email address is correct for this user.""" + if not self.email_address_hash: + raise ValueError("User has not set an email address") + + return is_match_for_hash(email_address.lower(), self.email_address_hash) + @property def num_unread_total(self) -> int: """Return total number of unread items (notifications + messages)."""