mirror of https://github.com/matrix-org/go-neb.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
347 lines
8.9 KiB
347 lines
8.9 KiB
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/matrix-org/go-neb/types"
|
|
"time"
|
|
)
|
|
|
|
const schemaSQL = `
|
|
CREATE TABLE IF NOT EXISTS services (
|
|
service_id TEXT NOT NULL,
|
|
service_type TEXT NOT NULL,
|
|
service_json TEXT NOT NULL,
|
|
time_added_ms BIGINT NOT NULL,
|
|
time_updated_ms BIGINT NOT NULL,
|
|
UNIQUE(service_id)
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS rooms_to_services (
|
|
service_user_id TEXT NOT NULL,
|
|
room_id TEXT NOT NULL,
|
|
service_id TEXT NOT NULL,
|
|
time_added_ms BIGINT NOT NULL,
|
|
UNIQUE(service_user_id, room_id, service_id)
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS matrix_clients (
|
|
user_id TEXT NOT NULL,
|
|
client_json TEXT NOT NULL,
|
|
next_batch TEXT NOT NULL,
|
|
time_added_ms BIGINT NOT NULL,
|
|
time_updated_ms BIGINT NOT NULL,
|
|
UNIQUE(user_id)
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS auth_realms (
|
|
realm_id TEXT NOT NULL,
|
|
realm_type TEXT NOT NULL,
|
|
realm_json TEXT NOT NULL,
|
|
time_added_ms BIGINT NOT NULL,
|
|
time_updated_ms BIGINT NOT NULL,
|
|
UNIQUE(realm_id)
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS auth_sessions (
|
|
realm_id TEXT NOT NULL,
|
|
user_id TEXT NOT NULL,
|
|
session_json TEXT NOT NULL,
|
|
time_added_ms BIGINT NOT NULL,
|
|
time_updated_ms BIGINT NOT NULL,
|
|
UNIQUE(realm_id, user_id)
|
|
);
|
|
`
|
|
|
|
const selectServiceUserIDsSQL = `
|
|
SELECT service_user_id, room_id FROM rooms_to_services
|
|
GROUP BY service_user_id, room_id
|
|
`
|
|
|
|
// selectServiceUserIDsTxn returns a map from userIDs to lists of roomIDs.
|
|
func selectServiceUserIDsTxn(txn *sql.Tx) (map[string][]string, error) {
|
|
rows, err := txn.Query(selectServiceUserIDsSQL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result := make(map[string][]string)
|
|
for rows.Next() {
|
|
var uID, rID string
|
|
if err = rows.Scan(&uID, &rID); err != nil {
|
|
return nil, err
|
|
}
|
|
result[uID] = append(result[uID], rID)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
const selectMatrixClientConfigSQL = `
|
|
SELECT client_json FROM matrix_clients WHERE user_id = $1
|
|
`
|
|
|
|
func selectMatrixClientConfigTxn(txn *sql.Tx, userID string) (config types.ClientConfig, err error) {
|
|
var configJSON []byte
|
|
err = txn.QueryRow(selectMatrixClientConfigSQL, userID).Scan(&configJSON)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = json.Unmarshal(configJSON, &config)
|
|
return
|
|
}
|
|
|
|
const insertMatrixClientConfigSQL = `
|
|
INSERT INTO matrix_clients(
|
|
user_id, client_json, next_batch, time_added_ms, time_updated_ms
|
|
) VALUES ($1, $2, '', $3, $4)
|
|
`
|
|
|
|
func insertMatrixClientConfigTxn(txn *sql.Tx, now time.Time, config types.ClientConfig) error {
|
|
t := now.UnixNano() / 1000000
|
|
configJSON, err := json.Marshal(&config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = txn.Exec(insertMatrixClientConfigSQL, config.UserID, configJSON, t, t)
|
|
return err
|
|
}
|
|
|
|
const updateMatrixClientConfigSQL = `
|
|
UPDATE matrix_clients SET client_json = $1, time_updated_ms = $2
|
|
WHERE user_id = $3
|
|
`
|
|
|
|
func updateMatrixClientConfigTxn(txn *sql.Tx, now time.Time, config types.ClientConfig) error {
|
|
t := now.UnixNano() / 1000000
|
|
configJSON, err := json.Marshal(&config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = txn.Exec(updateMatrixClientConfigSQL, configJSON, t, config.UserID)
|
|
return err
|
|
}
|
|
|
|
const selectServiceSQL = `
|
|
SELECT service_type, service_json FROM services
|
|
WHERE service_id = $1
|
|
`
|
|
|
|
func selectServiceTxn(txn *sql.Tx, serviceID string) (types.Service, error) {
|
|
var serviceType string
|
|
var serviceJSON []byte
|
|
row := txn.QueryRow(selectServiceSQL, serviceID)
|
|
if err := row.Scan(&serviceType, &serviceJSON); err != nil {
|
|
return nil, err
|
|
}
|
|
service := types.CreateService(serviceID, serviceType)
|
|
if service == nil {
|
|
return nil, fmt.Errorf("Cannot create services of type %s", serviceType)
|
|
}
|
|
if err := json.Unmarshal(serviceJSON, service); err != nil {
|
|
return nil, err
|
|
}
|
|
return service, nil
|
|
}
|
|
|
|
const updateServiceSQL = `
|
|
UPDATE services SET service_type=$1, service_json=$2, time_updated_ms=$3
|
|
WHERE service_id=$4
|
|
`
|
|
|
|
func updateServiceTxn(txn *sql.Tx, now time.Time, service types.Service) error {
|
|
serviceJSON, err := json.Marshal(service)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t := now.UnixNano() / 1000000
|
|
_, err = txn.Exec(
|
|
updateServiceSQL, service.ServiceType(), serviceJSON, t,
|
|
service.ServiceID(),
|
|
)
|
|
return err
|
|
}
|
|
|
|
const insertServiceSQL = `
|
|
INSERT INTO services(
|
|
service_id, service_type, service_json, time_added_ms, time_updated_ms
|
|
) VALUES ($1, $2, $3, $4, $5)
|
|
`
|
|
|
|
func insertServiceTxn(txn *sql.Tx, now time.Time, service types.Service) error {
|
|
serviceJSON, err := json.Marshal(service)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t := now.UnixNano() / 1000000
|
|
_, err = txn.Exec(
|
|
insertServiceSQL,
|
|
service.ServiceID(), service.ServiceType(), serviceJSON, t, t,
|
|
)
|
|
return err
|
|
}
|
|
|
|
const insertRoomServiceSQL = `
|
|
INSERT INTO rooms_to_services(service_user_id, room_id, service_id, time_added_ms)
|
|
VALUES ($1, $2, $3, $4)
|
|
`
|
|
|
|
func insertRoomServiceTxn(txn *sql.Tx, now time.Time, serviceUserID, roomID, serviceID string) error {
|
|
t := now.UnixNano() / 1000000
|
|
_, err := txn.Exec(insertRoomServiceSQL, serviceUserID, roomID, serviceID, t)
|
|
return err
|
|
}
|
|
|
|
const deleteRoomServiceSQL = `
|
|
DELETE FROM rooms_to_services WHERE service_user_id=$1 AND room_id = $2 AND service_id=$3
|
|
`
|
|
|
|
func deleteRoomServiceTxn(txn *sql.Tx, serviceUserID, roomID, serviceID string) error {
|
|
_, err := txn.Exec(deleteRoomServiceSQL, serviceUserID, roomID, serviceID)
|
|
return err
|
|
}
|
|
|
|
const selectRoomServicesSQL = `
|
|
SELECT service_id FROM rooms_to_services WHERE service_user_id=$1 AND room_id=$2
|
|
`
|
|
|
|
func selectRoomServicesTxn(txn *sql.Tx, serviceUserID, roomID string) (serviceIDs []string, err error) {
|
|
rows, err := txn.Query(selectRoomServicesSQL, serviceUserID, roomID)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var serviceID string
|
|
if err = rows.Scan(&serviceID); err != nil {
|
|
return
|
|
}
|
|
serviceIDs = append(serviceIDs, serviceID)
|
|
}
|
|
return
|
|
}
|
|
|
|
const insertRealmSQL = `
|
|
INSERT INTO auth_realms(
|
|
realm_id, realm_type, realm_json, time_added_ms, time_updated_ms
|
|
) VALUES ($1, $2, $3, $4, $5)
|
|
`
|
|
|
|
func insertRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error {
|
|
realmJSON, err := json.Marshal(realm)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t := now.UnixNano() / 1000000
|
|
_, err = txn.Exec(
|
|
insertRealmSQL,
|
|
realm.ID(), realm.Type(), realmJSON, t, t,
|
|
)
|
|
return err
|
|
}
|
|
|
|
const selectRealmSQL = `
|
|
SELECT realm_type, realm_json FROM auth_realms WHERE realm_id = $1
|
|
`
|
|
|
|
func selectRealmTxn(txn *sql.Tx, realmID string) (types.AuthRealm, error) {
|
|
var realmType string
|
|
var realmJSON []byte
|
|
row := txn.QueryRow(selectRealmSQL, realmID)
|
|
if err := row.Scan(&realmType, &realmJSON); err != nil {
|
|
return nil, err
|
|
}
|
|
realm := types.CreateAuthRealm(realmID, realmType)
|
|
if realm == nil {
|
|
return nil, fmt.Errorf("Cannot create realm of type %s", realmType)
|
|
}
|
|
if err := json.Unmarshal(realmJSON, realm); err != nil {
|
|
return nil, err
|
|
}
|
|
return realm, nil
|
|
}
|
|
|
|
const updateRealmSQL = `
|
|
UPDATE auth_realms SET realm_type=$1, realm_json=$2, time_updated_ms=$3
|
|
WHERE realm_id=$4
|
|
`
|
|
|
|
func updateRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error {
|
|
realmJSON, err := json.Marshal(realm)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t := now.UnixNano() / 1000000
|
|
_, err = txn.Exec(
|
|
updateRealmSQL, realm.Type(), realmJSON, t,
|
|
realm.ID(),
|
|
)
|
|
return err
|
|
}
|
|
|
|
const insertAuthSessionSQL = `
|
|
INSERT INTO auth_sessions(
|
|
realm_id, user_id, session_json, time_added_ms, time_updated_ms
|
|
) VALUES ($1, $2, $3, $4, $5)
|
|
`
|
|
|
|
func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error {
|
|
sessionJSON, err := json.Marshal(session)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t := now.UnixNano() / 1000000
|
|
_, err = txn.Exec(
|
|
insertAuthSessionSQL,
|
|
session.RealmID(), session.UserID(), sessionJSON, t, t,
|
|
)
|
|
return err
|
|
}
|
|
|
|
const selectAuthSessionSQL = `
|
|
SELECT realm_type, realm_json, session_json FROM auth_sessions
|
|
JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id
|
|
WHERE auth_sessions.realm_id = $1 AND auth_sessions.user_id = $2
|
|
`
|
|
|
|
func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) {
|
|
var realmType string
|
|
var realmJSON []byte
|
|
var sessionJSON []byte
|
|
row := txn.QueryRow(selectAuthSessionSQL, realmID, userID)
|
|
if err := row.Scan(&realmType, &realmJSON, &sessionJSON); err != nil {
|
|
return nil, err
|
|
}
|
|
realm := types.CreateAuthRealm(realmID, realmType)
|
|
if realm == nil {
|
|
return nil, fmt.Errorf("Cannot create realm of type %s", realmType)
|
|
}
|
|
if err := json.Unmarshal(realmJSON, realm); err != nil {
|
|
return nil, err
|
|
}
|
|
session := realm.AuthSession(userID, realmID)
|
|
if session == nil {
|
|
return nil, fmt.Errorf("Cannot create session for given realm")
|
|
}
|
|
if err := json.Unmarshal(sessionJSON, session); err != nil {
|
|
return nil, err
|
|
}
|
|
return session, nil
|
|
}
|
|
|
|
const updateAuthSessionSQL = `
|
|
UPDATE auth_sessions SET session_json=$1, time_updated_ms=$2
|
|
WHERE realm_id=$3 AND user_id=$4
|
|
`
|
|
|
|
func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error {
|
|
sessionJSON, err := json.Marshal(session)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
t := now.UnixNano() / 1000000
|
|
_, err = txn.Exec(
|
|
updateAuthSessionSQL, sessionJSON, t,
|
|
session.RealmID(), session.UserID(),
|
|
)
|
|
return err
|
|
}
|