An ebook/comic library service and web client
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

156 lines
4.8 KiB

  1. import base64
  2. import os
  3. import random
  4. import string
  5. import tempfile
  6. from collections import namedtuple
  7. from datetime import datetime
  8. from typing import Tuple, Any, Optional
  9. import pytest
  10. import rfc3339
  11. from flask import Flask, json
  12. from flask.testing import FlaskClient, FlaskCliRunner
  13. from werkzeug.test import Client
  14. from atheneum import create_app, register_blueprints, register_error_handlers
  15. from atheneum.db import init_db
  16. from atheneum.model import User
  17. from atheneum.service import user_service
  18. def add_test_user() -> Tuple[str, str]:
  19. test_username = 'test_' + ''.join(
  20. random.choices(string.ascii_letters + string.digits, k=17)).strip()
  21. test_password = ''.join(
  22. random.choices(string.ascii_letters + string.digits, k=32)).strip()
  23. user_service.register(test_username, test_password, User.ROLE_ADMIN, False)
  24. return test_username, test_password
  25. @pytest.fixture
  26. def app() -> Flask:
  27. """Create and configure a new atheneum_app instance for each test."""
  28. # create a temporary file to isolate the database for each test
  29. db_fd, db_path = tempfile.mkstemp(suffix='.db')
  30. test_database_uri = 'sqlite:///{}'.format(db_path)
  31. # create the atheneum_app with common test config
  32. atheneum_app = create_app({
  33. 'TESTING': True,
  34. 'SQLALCHEMY_DATABASE_URI': test_database_uri,
  35. })
  36. register_blueprints(atheneum_app)
  37. register_error_handlers(atheneum_app)
  38. # create the database and load test data
  39. with atheneum_app.app_context():
  40. init_db()
  41. test_username, test_password = add_test_user()
  42. atheneum_app.config['test_username'] = test_username
  43. atheneum_app.config['test_password'] = test_password
  44. # get_db().executescript(_data_sql)
  45. yield atheneum_app
  46. # close and remove the temporary database
  47. os.close(db_fd)
  48. os.unlink(db_path)
  49. @pytest.fixture
  50. def client(app: Flask) -> FlaskClient:
  51. """A test client for the app."""
  52. return app.test_client()
  53. @pytest.fixture
  54. def runner(app: Flask) -> FlaskCliRunner:
  55. """A test runner for the app's Click commands."""
  56. return app.test_cli_runner()
  57. AuthorizationHeader = namedtuple('AuthorizationHeader', ['header', 'data'])
  58. class AuthActions(object):
  59. def __init__(self,
  60. client: Client,
  61. username: str = "",
  62. password: str = "") -> None:
  63. self._client = client
  64. self.username: str = username
  65. self.password: str = password
  66. self.token: str = ""
  67. def configure(self, username: str, password: str) -> Any:
  68. self.username = username
  69. self.password = password
  70. return self
  71. def login(
  72. self,
  73. note: Optional[str] = None,
  74. expiration_time: Optional[datetime] = None) -> Any:
  75. auth_header = self.get_authorization_header_basic()
  76. auth_json = None
  77. request_headers = {
  78. auth_header.header: auth_header.data
  79. }
  80. if note is not None or expiration_time is not None:
  81. token_options = {}
  82. if note is not None:
  83. token_options['note'] = note
  84. if expiration_time is not None:
  85. token_options['expirationTime'] = rfc3339.format(
  86. expiration_time)
  87. auth_json = json.dumps(token_options)
  88. request_headers['Content-Type'] = 'application/json'
  89. result = self._client.post(
  90. '/auth/login',
  91. headers=request_headers,
  92. data=auth_json
  93. )
  94. self.token = result.json['token']
  95. return result
  96. def bump(self) -> Any:
  97. auth_header = self.get_authorization_header_token()
  98. return self._client.post(
  99. '/auth/bump',
  100. headers={
  101. auth_header[0]: auth_header[1]
  102. }
  103. )
  104. def logout(self) -> Any:
  105. auth_header = self.get_authorization_header_token()
  106. return self._client.post(
  107. '/auth/logout',
  108. headers={
  109. auth_header[0]: auth_header[1]
  110. }
  111. )
  112. def get_authorization_header_basic(self) -> AuthorizationHeader:
  113. credentials = base64.b64encode(
  114. '{}:{}'.format(self.username, self.password).encode('utf8')) \
  115. .decode('utf8').strip()
  116. return AuthorizationHeader(
  117. 'Authorization', 'Basic {}'.format(credentials))
  118. def get_authorization_header_token(self) -> AuthorizationHeader:
  119. credentials = base64.b64encode(
  120. '{}:{}'.format(self.username, self.token).encode('utf8')) \
  121. .decode('utf8').strip()
  122. return AuthorizationHeader(
  123. 'Authorization', 'Bearer {}'.format(credentials))
  124. @pytest.fixture
  125. def auth(client: Client) -> AuthActions:
  126. return AuthActions(client,
  127. client.application.config.get('test_username'),
  128. client.application.config.get('test_password'))