You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
3.8 KiB
128 lines
3.8 KiB
import base64
import os
import random
import string
import tempfile
from typing import Any, Iterator, Tuple

import pytest
from flask import Flask
from flask.testing import FlaskClient, FlaskCliRunner
from werkzeug.test import Client

from atheneum import create_app, register_blueprints
from atheneum.db import init_db
from atheneum.model import User
from atheneum.service import user_service
|
|
|
|
|
|
def add_test_user() -> Tuple[str, str]:
    """Register a throwaway user with ``User.ROLE_ADMIN`` and return its credentials.

    The username is ``'test_'`` followed by 17 random ASCII letters/digits;
    the password is 32 random ASCII letters/digits.

    Returns:
        A ``(username, password)`` tuple for the newly registered user.
    """
    alphabet = string.ascii_letters + string.digits

    def random_text(length: int) -> str:
        # Alphanumeric output only, so no whitespace trimming is needed.
        return ''.join(random.choices(alphabet, k=length))

    # Draw the username before the password to keep the RNG call order.
    username = 'test_' + random_text(17)
    password = random_text(32)

    user_service.register(username, password, User.ROLE_ADMIN)
    return username, password
|
|
|
|
|
|
@pytest.fixture
def app() -> Iterator[Flask]:
    """Create and configure a new atheneum_app instance for each test.

    Each test gets a fresh database stored in its own temporary file. The
    database is initialised with ``init_db()`` and seeded with one test user
    from ``add_test_user()``. That user's credentials are stored in
    ``app.config['test_username']`` and ``app.config['test_password']``.

    Yields:
        The configured application with its blueprints registered.
    """
    # Create a temporary file to isolate the database for each test.
    db_fd, db_path = tempfile.mkstemp()
    try:
        # Create the app with the common test configuration.
        atheneum_app = create_app({
            'TESTING': True,
            'DATABASE': db_path,
        })
        register_blueprints(atheneum_app)

        # Create the database and seed the test user; both need an app
        # context.
        with atheneum_app.app_context():
            init_db()
            test_username, test_password = add_test_user()
            atheneum_app.config['test_username'] = test_username
            atheneum_app.config['test_password'] = test_password

        yield atheneum_app
    finally:
        # Close and remove the temporary database even when app creation or
        # seeding failed before the test ran; otherwise the fd and file leak.
        os.close(db_fd)
        os.unlink(db_path)
|
|
|
|
|
|
@pytest.fixture
def client(app: Flask) -> FlaskClient:
    """Provide a test client bound to the per-test ``app`` fixture."""
    test_client = app.test_client()
    return test_client
|
|
|
|
|
|
@pytest.fixture
def runner(app: Flask) -> FlaskCliRunner:
    """Provide a runner for invoking the app's Click commands in tests."""
    cli_runner = app.test_cli_runner()
    return cli_runner
|
|
|
|
|
|
class AuthActions:
    """Drive the ``/auth`` endpoints on behalf of a single test user.

    ``login()`` authenticates with ``Basic base64(username:password)``. The
    token it receives is kept on the instance, and ``bump()`` / ``logout()``
    then send ``Token base64(username:token)``.
    """

    def __init__(self,
                 client: Client,
                 username: str = "",
                 password: str = "") -> None:
        self._client = client
        self.username: str = username
        self.password: str = password
        # Filled in by a successful ``login()``; empty until then.
        self.token: str = ""

    def configure(self, username: str, password: str) -> 'AuthActions':
        """Replace the stored credentials and return ``self`` for chaining."""
        self.username = username
        self.password = password
        return self

    def login(self) -> Any:
        """POST to ``/auth/login`` using Basic credentials.

        If the JSON response contains a ``token`` key, it is stored in
        ``self.token``. If the response has no JSON body or no ``token``
        (for example, rejected credentials), ``self.token`` is left unchanged.
        The response is still returned so the caller can assert on it.

        Returns:
            The test client's response object.
        """
        result = self._post('/auth/login',
                            self.get_authorization_header_basic())
        # ``.json`` is None for non-JSON responses, and error payloads may
        # lack a token; indexing blindly would crash negative-login tests.
        payload = result.json
        if isinstance(payload, dict) and 'token' in payload:
            self.token = payload['token']
        return result

    def bump(self) -> Any:
        """POST to ``/auth/bump`` with the stored token credentials."""
        return self._post('/auth/bump',
                          self.get_authorization_header_token())

    def logout(self) -> Any:
        """POST to ``/auth/logout`` with the stored token credentials."""
        return self._post('/auth/logout',
                          self.get_authorization_header_token())

    def get_authorization_header_basic(self) -> Tuple[str, str]:
        """Return ``('Authorization', 'Basic <base64(username:password)>')``."""
        return 'Authorization', 'Basic {}'.format(
            self._encode_credentials(self.password))

    def get_authorization_header_token(self) -> Tuple[str, str]:
        """Return ``('Authorization', 'Token <base64(username:token)>')``."""
        return 'Authorization', 'Token {}'.format(
            self._encode_credentials(self.token))

    def _encode_credentials(self, secret: str) -> str:
        """Base64-encode ``'<username>:<secret>'`` (UTF-8) as a text string."""
        raw = '{}:{}'.format(self.username, secret).encode('utf8')
        # Base64 output is pure ASCII with no whitespace, so no strip needed.
        return base64.b64encode(raw).decode('utf8')

    def _post(self, path: str, auth_header: Tuple[str, str]) -> Any:
        """POST to ``path`` sending the single ``(name, value)`` header."""
        name, value = auth_header
        return self._client.post(path, headers={name: value})
|
|
|
|
|
|
@pytest.fixture
def auth(client: Client) -> AuthActions:
    """Provide an AuthActions helper preloaded with the seeded test user.

    Credentials come from the application config keys ``'test_username'``
    and ``'test_password'``. ``config.get`` yields ``None`` for a missing key.
    """
    config = client.application.config
    username = config.get('test_username')
    password = config.get('test_password')
    return AuthActions(client, username, password)
|