You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

497 lines
13 KiB

8 years ago
8 years ago
  1. package database
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "fmt"
  6. "time"
  7. "github.com/matrix-org/go-neb/api"
  8. "github.com/matrix-org/go-neb/types"
  9. "maunium.net/go/mautrix/id"
  10. )
  11. const schemaSQL = `
  12. CREATE TABLE IF NOT EXISTS services (
  13. service_id TEXT NOT NULL,
  14. service_type TEXT NOT NULL,
  15. service_user_id TEXT NOT NULL,
  16. service_json TEXT NOT NULL,
  17. time_added_ms BIGINT NOT NULL,
  18. time_updated_ms BIGINT NOT NULL,
  19. UNIQUE(service_id)
  20. );
  21. CREATE UNIQUE INDEX IF NOT EXISTS service_id_and_user_idx ON services(service_user_id, service_id);
  22. CREATE TABLE IF NOT EXISTS matrix_clients (
  23. user_id TEXT NOT NULL,
  24. client_json TEXT NOT NULL,
  25. next_batch TEXT NOT NULL,
  26. time_added_ms BIGINT NOT NULL,
  27. time_updated_ms BIGINT NOT NULL,
  28. UNIQUE(user_id)
  29. );
  30. CREATE TABLE IF NOT EXISTS auth_realms (
  31. realm_id TEXT NOT NULL,
  32. realm_type TEXT NOT NULL,
  33. realm_json TEXT NOT NULL,
  34. time_added_ms BIGINT NOT NULL,
  35. time_updated_ms BIGINT NOT NULL,
  36. UNIQUE(realm_id)
  37. );
  38. CREATE TABLE IF NOT EXISTS auth_sessions (
  39. session_id TEXT NOT NULL,
  40. realm_id TEXT NOT NULL,
  41. user_id TEXT NOT NULL,
  42. session_json TEXT NOT NULL,
  43. time_added_ms BIGINT NOT NULL,
  44. time_updated_ms BIGINT NOT NULL,
  45. UNIQUE(realm_id, user_id),
  46. UNIQUE(realm_id, session_id)
  47. );
  48. CREATE TABLE IF NOT EXISTS bot_options (
  49. user_id TEXT NOT NULL,
  50. room_id TEXT NOT NULL,
  51. set_by_user_id TEXT NOT NULL,
  52. bot_options_json TEXT NOT NULL,
  53. time_added_ms BIGINT NOT NULL,
  54. time_updated_ms BIGINT NOT NULL,
  55. UNIQUE(user_id, room_id)
  56. );
  57. `
  58. const selectMatrixClientConfigSQL = `
  59. SELECT client_json FROM matrix_clients WHERE user_id = $1
  60. `
  61. func selectMatrixClientConfigTxn(txn *sql.Tx, userID id.UserID) (config api.ClientConfig, err error) {
  62. var configJSON []byte
  63. err = txn.QueryRow(selectMatrixClientConfigSQL, userID).Scan(&configJSON)
  64. if err != nil {
  65. return
  66. }
  67. err = json.Unmarshal(configJSON, &config)
  68. return
  69. }
  70. const selectMatrixClientConfigsSQL = `
  71. SELECT client_json FROM matrix_clients
  72. `
  73. func selectMatrixClientConfigsTxn(txn *sql.Tx) (configs []api.ClientConfig, err error) {
  74. rows, err := txn.Query(selectMatrixClientConfigsSQL)
  75. if err != nil {
  76. return
  77. }
  78. defer rows.Close()
  79. for rows.Next() {
  80. var config api.ClientConfig
  81. var configJSON []byte
  82. if err = rows.Scan(&configJSON); err != nil {
  83. return
  84. }
  85. if err = json.Unmarshal(configJSON, &config); err != nil {
  86. return
  87. }
  88. configs = append(configs, config)
  89. }
  90. return
  91. }
  92. const insertMatrixClientConfigSQL = `
  93. INSERT INTO matrix_clients(
  94. user_id, client_json, next_batch, time_added_ms, time_updated_ms
  95. ) VALUES ($1, $2, '', $3, $4)
  96. `
  97. func insertMatrixClientConfigTxn(txn *sql.Tx, now time.Time, config api.ClientConfig) error {
  98. t := now.UnixNano() / 1000000
  99. configJSON, err := json.Marshal(&config)
  100. if err != nil {
  101. return err
  102. }
  103. _, err = txn.Exec(insertMatrixClientConfigSQL, config.UserID, configJSON, t, t)
  104. return err
  105. }
  106. const updateMatrixClientConfigSQL = `
  107. UPDATE matrix_clients SET client_json = $1, time_updated_ms = $2
  108. WHERE user_id = $3
  109. `
  110. func updateMatrixClientConfigTxn(txn *sql.Tx, now time.Time, config api.ClientConfig) error {
  111. t := now.UnixNano() / 1000000
  112. configJSON, err := json.Marshal(&config)
  113. if err != nil {
  114. return err
  115. }
  116. _, err = txn.Exec(updateMatrixClientConfigSQL, configJSON, t, config.UserID)
  117. return err
  118. }
  119. const updateNextBatchSQL = `
  120. UPDATE matrix_clients SET next_batch = $1 WHERE user_id = $2
  121. `
  122. func updateNextBatchTxn(txn *sql.Tx, userID id.UserID, nextBatch string) error {
  123. _, err := txn.Exec(updateNextBatchSQL, nextBatch, userID)
  124. return err
  125. }
  126. const selectNextBatchSQL = `
  127. SELECT next_batch FROM matrix_clients WHERE user_id = $1
  128. `
  129. func selectNextBatchTxn(txn *sql.Tx, userID id.UserID) (string, error) {
  130. var nextBatch string
  131. row := txn.QueryRow(selectNextBatchSQL, userID)
  132. if err := row.Scan(&nextBatch); err != nil {
  133. return "", err
  134. }
  135. return nextBatch, nil
  136. }
  137. const selectServiceSQL = `
  138. SELECT service_type, service_user_id, service_json FROM services
  139. WHERE service_id = $1
  140. `
  141. func selectServiceTxn(txn *sql.Tx, serviceID string) (types.Service, error) {
  142. var serviceType string
  143. var serviceUserID id.UserID
  144. var serviceJSON []byte
  145. row := txn.QueryRow(selectServiceSQL, serviceID)
  146. if err := row.Scan(&serviceType, &serviceUserID, &serviceJSON); err != nil {
  147. return nil, err
  148. }
  149. return types.CreateService(serviceID, serviceType, serviceUserID, serviceJSON)
  150. }
  151. const updateServiceSQL = `
  152. UPDATE services SET service_type=$1, service_user_id=$2, service_json=$3, time_updated_ms=$4
  153. WHERE service_id=$5
  154. `
  155. func updateServiceTxn(txn *sql.Tx, now time.Time, service types.Service) error {
  156. serviceJSON, err := json.Marshal(service)
  157. if err != nil {
  158. return err
  159. }
  160. t := now.UnixNano() / 1000000
  161. _, err = txn.Exec(
  162. updateServiceSQL, service.ServiceType(), service.ServiceUserID(), serviceJSON, t,
  163. service.ServiceID(),
  164. )
  165. return err
  166. }
  167. const insertServiceSQL = `
  168. INSERT INTO services(
  169. service_id, service_type, service_user_id, service_json, time_added_ms, time_updated_ms
  170. ) VALUES ($1, $2, $3, $4, $5, $6)
  171. `
  172. func insertServiceTxn(txn *sql.Tx, now time.Time, service types.Service) error {
  173. serviceJSON, err := json.Marshal(service)
  174. if err != nil {
  175. return err
  176. }
  177. t := now.UnixNano() / 1000000
  178. _, err = txn.Exec(
  179. insertServiceSQL,
  180. service.ServiceID(), service.ServiceType(), service.ServiceUserID(), serviceJSON, t, t,
  181. )
  182. return err
  183. }
  184. const selectServicesForUserSQL = `
  185. SELECT service_id, service_type, service_json FROM services WHERE service_user_id=$1 ORDER BY service_id
  186. `
  187. func selectServicesForUserTxn(txn *sql.Tx, userID id.UserID) (srvs []types.Service, err error) {
  188. rows, err := txn.Query(selectServicesForUserSQL, userID)
  189. if err != nil {
  190. return
  191. }
  192. defer rows.Close()
  193. for rows.Next() {
  194. var s types.Service
  195. var serviceID string
  196. var serviceType string
  197. var serviceJSON []byte
  198. if err = rows.Scan(&serviceID, &serviceType, &serviceJSON); err != nil {
  199. return
  200. }
  201. s, err = types.CreateService(serviceID, serviceType, userID, serviceJSON)
  202. if err != nil {
  203. return
  204. }
  205. srvs = append(srvs, s)
  206. }
  207. return
  208. }
  209. const selectServicesByTypeSQL = `
  210. SELECT service_id, service_user_id, service_json FROM services WHERE service_type=$1 ORDER BY service_id
  211. `
  212. func selectServicesByTypeTxn(txn *sql.Tx, serviceType string) (srvs []types.Service, err error) {
  213. rows, err := txn.Query(selectServicesByTypeSQL, serviceType)
  214. if err != nil {
  215. return
  216. }
  217. defer rows.Close()
  218. for rows.Next() {
  219. var s types.Service
  220. var serviceID string
  221. var serviceUserID id.UserID
  222. var serviceJSON []byte
  223. if err = rows.Scan(&serviceID, &serviceUserID, &serviceJSON); err != nil {
  224. return
  225. }
  226. s, err = types.CreateService(serviceID, serviceType, serviceUserID, serviceJSON)
  227. if err != nil {
  228. return
  229. }
  230. srvs = append(srvs, s)
  231. }
  232. return
  233. }
  234. const deleteServiceSQL = `
  235. DELETE FROM services WHERE service_id = $1
  236. `
  237. func deleteServiceTxn(txn *sql.Tx, serviceID string) error {
  238. _, err := txn.Exec(deleteServiceSQL, serviceID)
  239. return err
  240. }
  241. const insertRealmSQL = `
  242. INSERT INTO auth_realms(
  243. realm_id, realm_type, realm_json, time_added_ms, time_updated_ms
  244. ) VALUES ($1, $2, $3, $4, $5)
  245. `
  246. func insertRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error {
  247. realmJSON, err := json.Marshal(realm)
  248. if err != nil {
  249. return err
  250. }
  251. t := now.UnixNano() / 1000000
  252. _, err = txn.Exec(
  253. insertRealmSQL,
  254. realm.ID(), realm.Type(), realmJSON, t, t,
  255. )
  256. return err
  257. }
  258. const selectRealmSQL = `
  259. SELECT realm_type, realm_json FROM auth_realms WHERE realm_id = $1
  260. `
  261. func selectRealmTxn(txn *sql.Tx, realmID string) (types.AuthRealm, error) {
  262. var realmType string
  263. var realmJSON []byte
  264. row := txn.QueryRow(selectRealmSQL, realmID)
  265. if err := row.Scan(&realmType, &realmJSON); err != nil {
  266. return nil, err
  267. }
  268. return types.CreateAuthRealm(realmID, realmType, realmJSON)
  269. }
  270. const selectRealmsByTypeSQL = `
  271. SELECT realm_id, realm_json FROM auth_realms WHERE realm_type = $1 ORDER BY realm_id
  272. `
  273. func selectRealmsByTypeTxn(txn *sql.Tx, realmType string) (realms []types.AuthRealm, err error) {
  274. rows, err := txn.Query(selectRealmsByTypeSQL, realmType)
  275. if err != nil {
  276. return
  277. }
  278. defer rows.Close()
  279. for rows.Next() {
  280. var realm types.AuthRealm
  281. var realmID string
  282. var realmJSON []byte
  283. if err = rows.Scan(&realmID, &realmJSON); err != nil {
  284. return
  285. }
  286. realm, err = types.CreateAuthRealm(realmID, realmType, realmJSON)
  287. if err != nil {
  288. return
  289. }
  290. realms = append(realms, realm)
  291. }
  292. return
  293. }
  294. const updateRealmSQL = `
  295. UPDATE auth_realms SET realm_type=$1, realm_json=$2, time_updated_ms=$3
  296. WHERE realm_id=$4
  297. `
  298. func updateRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error {
  299. realmJSON, err := json.Marshal(realm)
  300. if err != nil {
  301. return err
  302. }
  303. t := now.UnixNano() / 1000000
  304. _, err = txn.Exec(
  305. updateRealmSQL, realm.Type(), realmJSON, t,
  306. realm.ID(),
  307. )
  308. return err
  309. }
  310. const insertAuthSessionSQL = `
  311. INSERT INTO auth_sessions(
  312. session_id, realm_id, user_id, session_json, time_added_ms, time_updated_ms
  313. ) VALUES ($1, $2, $3, $4, $5, $6)
  314. `
  315. func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error {
  316. sessionJSON, err := json.Marshal(session)
  317. if err != nil {
  318. return err
  319. }
  320. t := now.UnixNano() / 1000000
  321. _, err = txn.Exec(
  322. insertAuthSessionSQL,
  323. session.ID(), session.RealmID(), session.UserID(), sessionJSON, t, t,
  324. )
  325. return err
  326. }
  327. const deleteAuthSessionSQL = `
  328. DELETE FROM auth_sessions WHERE realm_id=$1 AND user_id=$2
  329. `
  330. func deleteAuthSessionTxn(txn *sql.Tx, realmID string, userID id.UserID) error {
  331. _, err := txn.Exec(deleteAuthSessionSQL, realmID, userID)
  332. return err
  333. }
  334. const selectAuthSessionByUserSQL = `
  335. SELECT session_id, realm_type, realm_json, session_json FROM auth_sessions
  336. JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id
  337. WHERE auth_sessions.realm_id = $1 AND auth_sessions.user_id = $2
  338. `
  339. func selectAuthSessionByUserTxn(txn *sql.Tx, realmID string, userID id.UserID) (types.AuthSession, error) {
  340. var id string
  341. var realmType string
  342. var realmJSON []byte
  343. var sessionJSON []byte
  344. row := txn.QueryRow(selectAuthSessionByUserSQL, realmID, userID)
  345. if err := row.Scan(&id, &realmType, &realmJSON, &sessionJSON); err != nil {
  346. return nil, err
  347. }
  348. realm, err := types.CreateAuthRealm(realmID, realmType, realmJSON)
  349. if err != nil {
  350. return nil, err
  351. }
  352. session := realm.AuthSession(id, userID, realmID)
  353. if session == nil {
  354. return nil, fmt.Errorf("Cannot create session for given realm")
  355. }
  356. if err := json.Unmarshal(sessionJSON, session); err != nil {
  357. return nil, err
  358. }
  359. return session, nil
  360. }
  361. const selectAuthSessionByIDSQL = `
  362. SELECT user_id, realm_type, realm_json, session_json FROM auth_sessions
  363. JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id
  364. WHERE auth_sessions.realm_id = $1 AND auth_sessions.session_id = $2
  365. `
  366. func selectAuthSessionByIDTxn(txn *sql.Tx, realmID, sid string) (types.AuthSession, error) {
  367. var userID id.UserID
  368. var realmType string
  369. var realmJSON []byte
  370. var sessionJSON []byte
  371. row := txn.QueryRow(selectAuthSessionByIDSQL, realmID, sid)
  372. if err := row.Scan(&userID, &realmType, &realmJSON, &sessionJSON); err != nil {
  373. return nil, err
  374. }
  375. realm, err := types.CreateAuthRealm(realmID, realmType, realmJSON)
  376. if err != nil {
  377. return nil, err
  378. }
  379. session := realm.AuthSession(sid, userID, realmID)
  380. if session == nil {
  381. return nil, fmt.Errorf("Cannot create session for given realm")
  382. }
  383. if err := json.Unmarshal(sessionJSON, session); err != nil {
  384. return nil, err
  385. }
  386. return session, nil
  387. }
  388. const updateAuthSessionSQL = `
  389. UPDATE auth_sessions SET session_id=$1, session_json=$2, time_updated_ms=$3
  390. WHERE realm_id=$4 AND user_id=$5
  391. `
  392. func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error {
  393. sessionJSON, err := json.Marshal(session)
  394. if err != nil {
  395. return err
  396. }
  397. t := now.UnixNano() / 1000000
  398. _, err = txn.Exec(
  399. updateAuthSessionSQL, session.ID(), sessionJSON, t,
  400. session.RealmID(), session.UserID(),
  401. )
  402. return err
  403. }
  404. const selectBotOptionsSQL = `
  405. SELECT bot_options_json, set_by_user_id FROM bot_options WHERE user_id = $1 AND room_id = $2
  406. `
  407. func selectBotOptionsTxn(txn *sql.Tx, userID id.UserID, roomID id.RoomID) (opts types.BotOptions, err error) {
  408. var optionsJSON []byte
  409. err = txn.QueryRow(selectBotOptionsSQL, userID, roomID).Scan(&optionsJSON, &opts.SetByUserID)
  410. if err != nil {
  411. return
  412. }
  413. err = json.Unmarshal(optionsJSON, &opts.Options)
  414. return
  415. }
  416. const insertBotOptionsSQL = `
  417. INSERT INTO bot_options(
  418. user_id, room_id, bot_options_json, set_by_user_id, time_added_ms, time_updated_ms
  419. ) VALUES ($1, $2, $3, $4, $5, $6)
  420. `
  421. func insertBotOptionsTxn(txn *sql.Tx, now time.Time, opts types.BotOptions) error {
  422. t := now.UnixNano() / 1000000
  423. optsJSON, err := json.Marshal(&opts.Options)
  424. if err != nil {
  425. return err
  426. }
  427. _, err = txn.Exec(insertBotOptionsSQL, opts.UserID, opts.RoomID, optsJSON, opts.SetByUserID, t, t)
  428. return err
  429. }
  430. const updateBotOptionsSQL = `
  431. UPDATE bot_options SET bot_options_json = $1, set_by_user_id = $2, time_updated_ms = $3
  432. WHERE user_id = $4 AND room_id = $5
  433. `
  434. func updateBotOptionsTxn(txn *sql.Tx, now time.Time, opts types.BotOptions) error {
  435. t := now.UnixNano() / 1000000
  436. optsJSON, err := json.Marshal(&opts.Options)
  437. if err != nil {
  438. return err
  439. }
  440. _, err = txn.Exec(updateBotOptionsSQL, optsJSON, opts.SetByUserID, t, opts.UserID, opts.RoomID)
  441. return err
  442. }