You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

496 lines
13 KiB

8 years ago
8 years ago
  1. package database
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "fmt"
  6. "time"
  7. "github.com/matrix-org/go-neb/api"
  8. "github.com/matrix-org/go-neb/types"
  9. )
  // schemaSQL is the DDL for every table the helpers in this file read and
  // write: services, matrix_clients, auth_realms, auth_sessions and bot_options
  // (plus a unique index on services(service_user_id, service_id)).
  // Every statement uses IF NOT EXISTS, so running the script against a
  // database that already has these tables/index does not fail on them.
  // The *_json columns hold JSON text (the helpers fill them via json.Marshal),
  // and time_added_ms / time_updated_ms hold milliseconds since the Unix epoch
  // (computed as now.UnixNano() / 1000000 by the helpers).
  // Uniqueness: one row per service_id, per matrix_clients user_id, per
  // realm_id, per (realm_id, user_id) and (realm_id, session_id) session, and
  // per (user_id, room_id) bot_options entry.
  // NOTE(review): where/when this script is executed is not visible here.
  10. const schemaSQL = `
  11. CREATE TABLE IF NOT EXISTS services (
  12. service_id TEXT NOT NULL,
  13. service_type TEXT NOT NULL,
  14. service_user_id TEXT NOT NULL,
  15. service_json TEXT NOT NULL,
  16. time_added_ms BIGINT NOT NULL,
  17. time_updated_ms BIGINT NOT NULL,
  18. UNIQUE(service_id)
  19. );
  20. CREATE UNIQUE INDEX IF NOT EXISTS service_id_and_user_idx ON services(service_user_id, service_id);
  21. CREATE TABLE IF NOT EXISTS matrix_clients (
  22. user_id TEXT NOT NULL,
  23. client_json TEXT NOT NULL,
  24. next_batch TEXT NOT NULL,
  25. time_added_ms BIGINT NOT NULL,
  26. time_updated_ms BIGINT NOT NULL,
  27. UNIQUE(user_id)
  28. );
  29. CREATE TABLE IF NOT EXISTS auth_realms (
  30. realm_id TEXT NOT NULL,
  31. realm_type TEXT NOT NULL,
  32. realm_json TEXT NOT NULL,
  33. time_added_ms BIGINT NOT NULL,
  34. time_updated_ms BIGINT NOT NULL,
  35. UNIQUE(realm_id)
  36. );
  37. CREATE TABLE IF NOT EXISTS auth_sessions (
  38. session_id TEXT NOT NULL,
  39. realm_id TEXT NOT NULL,
  40. user_id TEXT NOT NULL,
  41. session_json TEXT NOT NULL,
  42. time_added_ms BIGINT NOT NULL,
  43. time_updated_ms BIGINT NOT NULL,
  44. UNIQUE(realm_id, user_id),
  45. UNIQUE(realm_id, session_id)
  46. );
  47. CREATE TABLE IF NOT EXISTS bot_options (
  48. user_id TEXT NOT NULL,
  49. room_id TEXT NOT NULL,
  50. set_by_user_id TEXT NOT NULL,
  51. bot_options_json TEXT NOT NULL,
  52. time_added_ms BIGINT NOT NULL,
  53. time_updated_ms BIGINT NOT NULL,
  54. UNIQUE(user_id, room_id)
  55. );
  56. `
  57. const selectMatrixClientConfigSQL = `
  58. SELECT client_json FROM matrix_clients WHERE user_id = $1
  59. `
  60. func selectMatrixClientConfigTxn(txn *sql.Tx, userID string) (config api.ClientConfig, err error) {
  61. var configJSON []byte
  62. err = txn.QueryRow(selectMatrixClientConfigSQL, userID).Scan(&configJSON)
  63. if err != nil {
  64. return
  65. }
  66. err = json.Unmarshal(configJSON, &config)
  67. return
  68. }
  69. const selectMatrixClientConfigsSQL = `
  70. SELECT client_json FROM matrix_clients
  71. `
  72. func selectMatrixClientConfigsTxn(txn *sql.Tx) (configs []api.ClientConfig, err error) {
  73. rows, err := txn.Query(selectMatrixClientConfigsSQL)
  74. if err != nil {
  75. return
  76. }
  77. defer rows.Close()
  78. for rows.Next() {
  79. var config api.ClientConfig
  80. var configJSON []byte
  81. if err = rows.Scan(&configJSON); err != nil {
  82. return
  83. }
  84. if err = json.Unmarshal(configJSON, &config); err != nil {
  85. return
  86. }
  87. configs = append(configs, config)
  88. }
  89. return
  90. }
  91. const insertMatrixClientConfigSQL = `
  92. INSERT INTO matrix_clients(
  93. user_id, client_json, next_batch, time_added_ms, time_updated_ms
  94. ) VALUES ($1, $2, '', $3, $4)
  95. `
  96. func insertMatrixClientConfigTxn(txn *sql.Tx, now time.Time, config api.ClientConfig) error {
  97. t := now.UnixNano() / 1000000
  98. configJSON, err := json.Marshal(&config)
  99. if err != nil {
  100. return err
  101. }
  102. _, err = txn.Exec(insertMatrixClientConfigSQL, config.UserID, configJSON, t, t)
  103. return err
  104. }
  105. const updateMatrixClientConfigSQL = `
  106. UPDATE matrix_clients SET client_json = $1, time_updated_ms = $2
  107. WHERE user_id = $3
  108. `
  109. func updateMatrixClientConfigTxn(txn *sql.Tx, now time.Time, config api.ClientConfig) error {
  110. t := now.UnixNano() / 1000000
  111. configJSON, err := json.Marshal(&config)
  112. if err != nil {
  113. return err
  114. }
  115. _, err = txn.Exec(updateMatrixClientConfigSQL, configJSON, t, config.UserID)
  116. return err
  117. }
  // updateNextBatchSQL overwrites the next_batch column for one client.
const updateNextBatchSQL = `
UPDATE matrix_clients SET next_batch = $1 WHERE user_id = $2
`

// updateNextBatchTxn stores nextBatch in the next_batch column of userID's
// matrix_clients row. Only the Exec error is reported: an update that
// matches no row succeeds silently.
func updateNextBatchTxn(txn *sql.Tx, userID, nextBatch string) error {
	if _, err := txn.Exec(updateNextBatchSQL, nextBatch, userID); err != nil {
		return err
	}
	return nil
}
  // selectNextBatchSQL reads the next_batch column for one client.
const selectNextBatchSQL = `
SELECT next_batch FROM matrix_clients WHERE user_id = $1
`

// selectNextBatchTxn returns the next_batch value stored for userID.
// On failure it returns "" and the Scan error, which is sql.ErrNoRows when
// userID has no matrix_clients row.
func selectNextBatchTxn(txn *sql.Tx, userID string) (string, error) {
	var token string
	err := txn.QueryRow(selectNextBatchSQL, userID).Scan(&token)
	if err != nil {
		return "", err
	}
	return token, nil
}
  136. const selectServiceSQL = `
  137. SELECT service_type, service_user_id, service_json FROM services
  138. WHERE service_id = $1
  139. `
  140. func selectServiceTxn(txn *sql.Tx, serviceID string) (types.Service, error) {
  141. var serviceType string
  142. var serviceUserID string
  143. var serviceJSON []byte
  144. row := txn.QueryRow(selectServiceSQL, serviceID)
  145. if err := row.Scan(&serviceType, &serviceUserID, &serviceJSON); err != nil {
  146. return nil, err
  147. }
  148. return types.CreateService(serviceID, serviceType, serviceUserID, serviceJSON)
  149. }
  150. const updateServiceSQL = `
  151. UPDATE services SET service_type=$1, service_user_id=$2, service_json=$3, time_updated_ms=$4
  152. WHERE service_id=$5
  153. `
  154. func updateServiceTxn(txn *sql.Tx, now time.Time, service types.Service) error {
  155. serviceJSON, err := json.Marshal(service)
  156. if err != nil {
  157. return err
  158. }
  159. t := now.UnixNano() / 1000000
  160. _, err = txn.Exec(
  161. updateServiceSQL, service.ServiceType(), service.ServiceUserID(), serviceJSON, t,
  162. service.ServiceID(),
  163. )
  164. return err
  165. }
  166. const insertServiceSQL = `
  167. INSERT INTO services(
  168. service_id, service_type, service_user_id, service_json, time_added_ms, time_updated_ms
  169. ) VALUES ($1, $2, $3, $4, $5, $6)
  170. `
  171. func insertServiceTxn(txn *sql.Tx, now time.Time, service types.Service) error {
  172. serviceJSON, err := json.Marshal(service)
  173. if err != nil {
  174. return err
  175. }
  176. t := now.UnixNano() / 1000000
  177. _, err = txn.Exec(
  178. insertServiceSQL,
  179. service.ServiceID(), service.ServiceType(), service.ServiceUserID(), serviceJSON, t, t,
  180. )
  181. return err
  182. }
  183. const selectServicesForUserSQL = `
  184. SELECT service_id, service_type, service_json FROM services WHERE service_user_id=$1 ORDER BY service_id
  185. `
  186. func selectServicesForUserTxn(txn *sql.Tx, userID string) (srvs []types.Service, err error) {
  187. rows, err := txn.Query(selectServicesForUserSQL, userID)
  188. if err != nil {
  189. return
  190. }
  191. defer rows.Close()
  192. for rows.Next() {
  193. var s types.Service
  194. var serviceID string
  195. var serviceType string
  196. var serviceJSON []byte
  197. if err = rows.Scan(&serviceID, &serviceType, &serviceJSON); err != nil {
  198. return
  199. }
  200. s, err = types.CreateService(serviceID, serviceType, userID, serviceJSON)
  201. if err != nil {
  202. return
  203. }
  204. srvs = append(srvs, s)
  205. }
  206. return
  207. }
  208. const selectServicesByTypeSQL = `
  209. SELECT service_id, service_user_id, service_json FROM services WHERE service_type=$1 ORDER BY service_id
  210. `
  211. func selectServicesByTypeTxn(txn *sql.Tx, serviceType string) (srvs []types.Service, err error) {
  212. rows, err := txn.Query(selectServicesByTypeSQL, serviceType)
  213. if err != nil {
  214. return
  215. }
  216. defer rows.Close()
  217. for rows.Next() {
  218. var s types.Service
  219. var serviceID string
  220. var serviceUserID string
  221. var serviceJSON []byte
  222. if err = rows.Scan(&serviceID, &serviceUserID, &serviceJSON); err != nil {
  223. return
  224. }
  225. s, err = types.CreateService(serviceID, serviceType, serviceUserID, serviceJSON)
  226. if err != nil {
  227. return
  228. }
  229. srvs = append(srvs, s)
  230. }
  231. return
  232. }
  // deleteServiceSQL removes one services row by service_id.
const deleteServiceSQL = `
DELETE FROM services WHERE service_id = $1
`

// deleteServiceTxn deletes the services row for serviceID. Only the Exec
// error is reported; deleting an id that does not exist is not an error.
func deleteServiceTxn(txn *sql.Tx, serviceID string) error {
	if _, err := txn.Exec(deleteServiceSQL, serviceID); err != nil {
		return err
	}
	return nil
}
  240. const insertRealmSQL = `
  241. INSERT INTO auth_realms(
  242. realm_id, realm_type, realm_json, time_added_ms, time_updated_ms
  243. ) VALUES ($1, $2, $3, $4, $5)
  244. `
  245. func insertRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error {
  246. realmJSON, err := json.Marshal(realm)
  247. if err != nil {
  248. return err
  249. }
  250. t := now.UnixNano() / 1000000
  251. _, err = txn.Exec(
  252. insertRealmSQL,
  253. realm.ID(), realm.Type(), realmJSON, t, t,
  254. )
  255. return err
  256. }
  257. const selectRealmSQL = `
  258. SELECT realm_type, realm_json FROM auth_realms WHERE realm_id = $1
  259. `
  260. func selectRealmTxn(txn *sql.Tx, realmID string) (types.AuthRealm, error) {
  261. var realmType string
  262. var realmJSON []byte
  263. row := txn.QueryRow(selectRealmSQL, realmID)
  264. if err := row.Scan(&realmType, &realmJSON); err != nil {
  265. return nil, err
  266. }
  267. return types.CreateAuthRealm(realmID, realmType, realmJSON)
  268. }
  269. const selectRealmsByTypeSQL = `
  270. SELECT realm_id, realm_json FROM auth_realms WHERE realm_type = $1 ORDER BY realm_id
  271. `
  272. func selectRealmsByTypeTxn(txn *sql.Tx, realmType string) (realms []types.AuthRealm, err error) {
  273. rows, err := txn.Query(selectRealmsByTypeSQL, realmType)
  274. if err != nil {
  275. return
  276. }
  277. defer rows.Close()
  278. for rows.Next() {
  279. var realm types.AuthRealm
  280. var realmID string
  281. var realmJSON []byte
  282. if err = rows.Scan(&realmID, &realmJSON); err != nil {
  283. return
  284. }
  285. realm, err = types.CreateAuthRealm(realmID, realmType, realmJSON)
  286. if err != nil {
  287. return
  288. }
  289. realms = append(realms, realm)
  290. }
  291. return
  292. }
  293. const updateRealmSQL = `
  294. UPDATE auth_realms SET realm_type=$1, realm_json=$2, time_updated_ms=$3
  295. WHERE realm_id=$4
  296. `
  297. func updateRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error {
  298. realmJSON, err := json.Marshal(realm)
  299. if err != nil {
  300. return err
  301. }
  302. t := now.UnixNano() / 1000000
  303. _, err = txn.Exec(
  304. updateRealmSQL, realm.Type(), realmJSON, t,
  305. realm.ID(),
  306. )
  307. return err
  308. }
  309. const insertAuthSessionSQL = `
  310. INSERT INTO auth_sessions(
  311. session_id, realm_id, user_id, session_json, time_added_ms, time_updated_ms
  312. ) VALUES ($1, $2, $3, $4, $5, $6)
  313. `
  314. func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error {
  315. sessionJSON, err := json.Marshal(session)
  316. if err != nil {
  317. return err
  318. }
  319. t := now.UnixNano() / 1000000
  320. _, err = txn.Exec(
  321. insertAuthSessionSQL,
  322. session.ID(), session.RealmID(), session.UserID(), sessionJSON, t, t,
  323. )
  324. return err
  325. }
  // deleteAuthSessionSQL removes the session a user holds in one realm.
const deleteAuthSessionSQL = `
DELETE FROM auth_sessions WHERE realm_id=$1 AND user_id=$2
`

// deleteAuthSessionTxn deletes the auth_sessions row keyed by
// (realmID, userID). Only the Exec error is reported; deleting a session that
// does not exist is not an error.
func deleteAuthSessionTxn(txn *sql.Tx, realmID, userID string) error {
	if _, err := txn.Exec(deleteAuthSessionSQL, realmID, userID); err != nil {
		return err
	}
	return nil
}
  333. const selectAuthSessionByUserSQL = `
  334. SELECT session_id, realm_type, realm_json, session_json FROM auth_sessions
  335. JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id
  336. WHERE auth_sessions.realm_id = $1 AND auth_sessions.user_id = $2
  337. `
  338. func selectAuthSessionByUserTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) {
  339. var id string
  340. var realmType string
  341. var realmJSON []byte
  342. var sessionJSON []byte
  343. row := txn.QueryRow(selectAuthSessionByUserSQL, realmID, userID)
  344. if err := row.Scan(&id, &realmType, &realmJSON, &sessionJSON); err != nil {
  345. return nil, err
  346. }
  347. realm, err := types.CreateAuthRealm(realmID, realmType, realmJSON)
  348. if err != nil {
  349. return nil, err
  350. }
  351. session := realm.AuthSession(id, userID, realmID)
  352. if session == nil {
  353. return nil, fmt.Errorf("Cannot create session for given realm")
  354. }
  355. if err := json.Unmarshal(sessionJSON, session); err != nil {
  356. return nil, err
  357. }
  358. return session, nil
  359. }
  360. const selectAuthSessionByIDSQL = `
  361. SELECT user_id, realm_type, realm_json, session_json FROM auth_sessions
  362. JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id
  363. WHERE auth_sessions.realm_id = $1 AND auth_sessions.session_id = $2
  364. `
  365. func selectAuthSessionByIDTxn(txn *sql.Tx, realmID, id string) (types.AuthSession, error) {
  366. var userID string
  367. var realmType string
  368. var realmJSON []byte
  369. var sessionJSON []byte
  370. row := txn.QueryRow(selectAuthSessionByIDSQL, realmID, id)
  371. if err := row.Scan(&userID, &realmType, &realmJSON, &sessionJSON); err != nil {
  372. return nil, err
  373. }
  374. realm, err := types.CreateAuthRealm(realmID, realmType, realmJSON)
  375. if err != nil {
  376. return nil, err
  377. }
  378. session := realm.AuthSession(id, userID, realmID)
  379. if session == nil {
  380. return nil, fmt.Errorf("Cannot create session for given realm")
  381. }
  382. if err := json.Unmarshal(sessionJSON, session); err != nil {
  383. return nil, err
  384. }
  385. return session, nil
  386. }
  387. const updateAuthSessionSQL = `
  388. UPDATE auth_sessions SET session_id=$1, session_json=$2, time_updated_ms=$3
  389. WHERE realm_id=$4 AND user_id=$5
  390. `
  391. func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error {
  392. sessionJSON, err := json.Marshal(session)
  393. if err != nil {
  394. return err
  395. }
  396. t := now.UnixNano() / 1000000
  397. _, err = txn.Exec(
  398. updateAuthSessionSQL, session.ID(), sessionJSON, t,
  399. session.RealmID(), session.UserID(),
  400. )
  401. return err
  402. }
  403. const selectBotOptionsSQL = `
  404. SELECT bot_options_json, set_by_user_id FROM bot_options WHERE user_id = $1 AND room_id = $2
  405. `
  406. func selectBotOptionsTxn(txn *sql.Tx, userID, roomID string) (opts types.BotOptions, err error) {
  407. var optionsJSON []byte
  408. err = txn.QueryRow(selectBotOptionsSQL, userID, roomID).Scan(&optionsJSON, &opts.SetByUserID)
  409. if err != nil {
  410. return
  411. }
  412. err = json.Unmarshal(optionsJSON, &opts.Options)
  413. return
  414. }
  415. const insertBotOptionsSQL = `
  416. INSERT INTO bot_options(
  417. user_id, room_id, bot_options_json, set_by_user_id, time_added_ms, time_updated_ms
  418. ) VALUES ($1, $2, $3, $4, $5, $6)
  419. `
  420. func insertBotOptionsTxn(txn *sql.Tx, now time.Time, opts types.BotOptions) error {
  421. t := now.UnixNano() / 1000000
  422. optsJSON, err := json.Marshal(&opts.Options)
  423. if err != nil {
  424. return err
  425. }
  426. _, err = txn.Exec(insertBotOptionsSQL, opts.UserID, opts.RoomID, optsJSON, opts.SetByUserID, t, t)
  427. return err
  428. }
  429. const updateBotOptionsSQL = `
  430. UPDATE bot_options SET bot_options_json = $1, set_by_user_id = $2, time_updated_ms = $3
  431. WHERE user_id = $4 AND room_id = $5
  432. `
  433. func updateBotOptionsTxn(txn *sql.Tx, now time.Time, opts types.BotOptions) error {
  434. t := now.UnixNano() / 1000000
  435. optsJSON, err := json.Marshal(&opts.Options)
  436. if err != nil {
  437. return err
  438. }
  439. _, err = txn.Exec(updateBotOptionsSQL, optsJSON, opts.SetByUserID, t, opts.UserID, opts.RoomID)
  440. return err
  441. }