You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
156 lines
4.8 KiB
156 lines
4.8 KiB
import base64
import os
import random
import string
import tempfile
from collections import namedtuple
from datetime import datetime
from typing import Any, Iterator, Optional, Tuple

import pytest
import rfc3339
from flask import Flask, json
from flask.testing import FlaskClient, FlaskCliRunner
from werkzeug.test import Client

from atheneum import create_app, register_blueprints, register_error_handlers
from atheneum.db import init_db
from atheneum.model import User
from atheneum.service import user_service
|
|
|
|
|
|
def add_test_user() -> Tuple[str, str]:
    """Register a randomly named admin user and return its credentials.

    The username is ``'test_'`` followed by 17 random alphanumeric
    characters; the password is 32 random alphanumeric characters.
    The ``app`` fixture calls this inside an application context.

    :return: ``(username, password)`` of the newly registered user.
    """
    alphabet = string.ascii_letters + string.digits
    username = 'test_' + ''.join(random.choices(alphabet, k=17))
    password = ''.join(random.choices(alphabet, k=32))
    # NOTE(review): the meaning of the trailing ``False`` argument is defined
    # by user_service.register -- confirm against that signature.
    user_service.register(username, password, User.ROLE_ADMIN, False)
    return username, password
|
|
|
|
|
|
@pytest.fixture
def app() -> Iterator[Flask]:
    """Create and configure a new atheneum app instance for each test.

    Each test gets its own temporary SQLite database file.

    Setup steps:

    - Initialise the database with ``init_db()``.
    - Seed one user via ``add_test_user()``.
    - Store that user's credentials in ``app.config['test_username']`` and
      ``app.config['test_password']`` so other fixtures can log in.

    The temporary database file is closed and removed once the test finishes,
    and also if setup itself fails.

    :return: an iterator yielding the configured Flask application once.
    """
    # Create a temporary file to isolate the database for each test.
    db_fd, db_path = tempfile.mkstemp(suffix='.db')
    try:
        test_database_uri = 'sqlite:///{}'.format(db_path)

        # Create the app with the common test config.
        atheneum_app = create_app({
            'TESTING': True,
            'SQLALCHEMY_DATABASE_URI': test_database_uri,
        })
        register_blueprints(atheneum_app)
        register_error_handlers(atheneum_app)

        # Create the database schema and load test data.
        with atheneum_app.app_context():
            init_db()
            test_username, test_password = add_test_user()
            atheneum_app.config['test_username'] = test_username
            atheneum_app.config['test_password'] = test_password

        yield atheneum_app
    finally:
        # Runs after the test (pytest resumes the generator for teardown)
        # and also when any setup step above raises, so the descriptor and
        # file never leak.
        os.close(db_fd)
        os.unlink(db_path)
|
|
|
|
|
|
@pytest.fixture
def client(app: Flask) -> FlaskClient:
    """Return a test client bound to the per-test ``app`` fixture."""
    test_client = app.test_client()
    return test_client
|
|
|
|
|
|
@pytest.fixture
def runner(app: Flask) -> FlaskCliRunner:
    """Return a CLI runner for invoking the per-test app's Click commands."""
    cli_runner = app.test_cli_runner()
    return cli_runner
|
|
|
|
|
|
# An HTTP header name/value pair, e.g. ('Authorization', 'Basic <creds>').
AuthorizationHeader = namedtuple('AuthorizationHeader', 'header data')
|
|
|
|
|
|
class AuthActions:
    """Test helper that drives the ``/auth`` endpoints through a test client.

    Holds a username/password pair and, after a successful :meth:`login`,
    the token returned by the server, so that :meth:`bump` and
    :meth:`logout` can authenticate with it.
    """

    def __init__(self,
                 client: Client,
                 username: str = "",
                 password: str = "") -> None:
        self._client = client
        self.username: str = username
        self.password: str = password
        # Populated by login(); empty until a login response carries a token.
        self.token: str = ""

    def configure(self, username: str, password: str) -> 'AuthActions':
        """Replace the stored credentials and return ``self`` for chaining."""
        self.username = username
        self.password = password
        return self

    def login(
            self,
            note: Optional[str] = None,
            expiration_time: Optional[datetime] = None) -> Any:
        """POST to ``/auth/login`` using HTTP Basic credentials.

        :param note: optional note, sent as ``note`` in a JSON body.
        :param expiration_time: optional expiry, sent RFC 3339-formatted as
            ``expirationTime`` in a JSON body.
        :return: the test-client response. If its JSON body contains a
            ``token``, that value is stored on ``self.token``. Otherwise
            (e.g. a rejected login) the stored token is left unchanged, so
            the caller can still inspect the response.
        """
        auth_header = self.get_authorization_header_basic()
        request_headers = {auth_header.header: auth_header.data}
        auth_json = None

        token_options = {}
        if note is not None:
            token_options['note'] = note
        if expiration_time is not None:
            token_options['expirationTime'] = rfc3339.format(expiration_time)
        # Only send a JSON body (and its content type) when it has content.
        if token_options:
            auth_json = json.dumps(token_options)
            request_headers['Content-Type'] = 'application/json'

        result = self._client.post(
            '/auth/login',
            headers=request_headers,
            data=auth_json,
        )
        payload = result.json
        if isinstance(payload, dict) and 'token' in payload:
            self.token = payload['token']
        return result

    def bump(self) -> Any:
        """POST to ``/auth/bump`` authenticated with the stored token."""
        return self._post_with_token('/auth/bump')

    def logout(self) -> Any:
        """POST to ``/auth/logout`` authenticated with the stored token."""
        return self._post_with_token('/auth/logout')

    def get_authorization_header_basic(self) -> AuthorizationHeader:
        """Return a Basic ``Authorization`` header for username/password."""
        credentials = self._encode_credentials(self.password)
        return AuthorizationHeader(
            'Authorization', 'Basic {}'.format(credentials))

    def get_authorization_header_token(self) -> AuthorizationHeader:
        """Return a Bearer ``Authorization`` header for the stored token.

        The Bearer value is the base64 encoding of ``'<username>:<token>'``,
        not the raw token.
        """
        credentials = self._encode_credentials(self.token)
        return AuthorizationHeader(
            'Authorization', 'Bearer {}'.format(credentials))

    def _post_with_token(self, path: str) -> Any:
        """POST to *path* with the Bearer header from the stored token."""
        auth_header = self.get_authorization_header_token()
        return self._client.post(
            path,
            headers={auth_header.header: auth_header.data},
        )

    def _encode_credentials(self, secret: str) -> str:
        """Base64-encode ``'<username>:<secret>'`` (UTF-8) as text."""
        raw = '{}:{}'.format(self.username, secret).encode('utf8')
        # b64encode output is pure ASCII with no whitespace, so no strip.
        return base64.b64encode(raw).decode('utf8')
|
|
|
|
|
|
@pytest.fixture
def auth(client: Client) -> AuthActions:
    """Return an AuthActions helper preloaded with the seeded test user.

    The credentials are read from the config keys that the ``app`` fixture
    sets (``test_username`` / ``test_password``).
    """
    config = client.application.config
    return AuthActions(
        client,
        username=config.get('test_username'),
        password=config.get('test_password'),
    )
|