You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

345 lines
11 KiB

9 years ago
9 years ago
  1. package database
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/matrix-org/go-neb/api"
  7. "github.com/matrix-org/go-neb/types"
  8. "time"
  9. )
  // ServiceDB persists service configuration (Matrix client configs, bot
// services, auth realms and sessions, bot options) in a SQL database
// accessed through database/sql. Construct one with Open.
type ServiceDB struct {
	db *sql.DB // connection pool all queries run against
}
  14. // A single global instance of the service DB.
  15. var globalServiceDB Storer
  16. // SetServiceDB sets the global service DB instance.
  17. func SetServiceDB(db Storer) {
  18. globalServiceDB = db
  19. }
  20. // GetServiceDB gets the global service DB instance.
  21. func GetServiceDB() Storer {
  22. return globalServiceDB
  23. }
  24. // Open a SQL database to use as a ServiceDB. This will automatically create
  25. // the necessary database tables if they aren't already present.
  26. func Open(databaseType, databaseURL string) (serviceDB *ServiceDB, err error) {
  27. db, err := sql.Open(databaseType, databaseURL)
  28. if err != nil {
  29. return
  30. }
  31. if _, err = db.Exec(schemaSQL); err != nil {
  32. return
  33. }
  34. if databaseType == "sqlite3" {
  35. // Fix for "database is locked" errors
  36. // https://github.com/mattn/go-sqlite3/issues/274
  37. db.SetMaxOpenConns(1)
  38. }
  39. serviceDB = &ServiceDB{db: db}
  40. return
  41. }
  42. // StoreMatrixClientConfig stores the Matrix client config for a bot service.
  43. // If a config already exists then it will be updated, otherwise a new config
  44. // will be inserted. The previous config is returned.
  45. func (d *ServiceDB) StoreMatrixClientConfig(config api.ClientConfig) (oldConfig api.ClientConfig, err error) {
  46. err = runTransaction(d.db, func(txn *sql.Tx) error {
  47. oldConfig, err = selectMatrixClientConfigTxn(txn, config.UserID)
  48. now := time.Now()
  49. if err == nil {
  50. return updateMatrixClientConfigTxn(txn, now, config)
  51. } else if err == sql.ErrNoRows {
  52. return insertMatrixClientConfigTxn(txn, now, config)
  53. } else {
  54. return err
  55. }
  56. })
  57. return
  58. }
  59. // LoadMatrixClientConfigs loads all Matrix client configs from the database.
  60. func (d *ServiceDB) LoadMatrixClientConfigs() (configs []api.ClientConfig, err error) {
  61. err = runTransaction(d.db, func(txn *sql.Tx) error {
  62. configs, err = selectMatrixClientConfigsTxn(txn)
  63. return err
  64. })
  65. return
  66. }
  67. // LoadMatrixClientConfig loads a Matrix client config from the database.
  68. // Returns sql.ErrNoRows if the client isn't in the database.
  69. func (d *ServiceDB) LoadMatrixClientConfig(userID string) (config api.ClientConfig, err error) {
  70. err = runTransaction(d.db, func(txn *sql.Tx) error {
  71. config, err = selectMatrixClientConfigTxn(txn, userID)
  72. return err
  73. })
  74. return
  75. }
  76. // UpdateNextBatch updates the next_batch token for the given user.
  77. func (d *ServiceDB) UpdateNextBatch(userID, nextBatch string) (err error) {
  78. err = runTransaction(d.db, func(txn *sql.Tx) error {
  79. return updateNextBatchTxn(txn, userID, nextBatch)
  80. })
  81. return
  82. }
  83. // LoadNextBatch loads the next_batch token for the given user.
  84. func (d *ServiceDB) LoadNextBatch(userID string) (nextBatch string, err error) {
  85. err = runTransaction(d.db, func(txn *sql.Tx) error {
  86. nextBatch, err = selectNextBatchTxn(txn, userID)
  87. return err
  88. })
  89. return
  90. }
  91. // LoadService loads a service from the database.
  92. // Returns sql.ErrNoRows if the service isn't in the database.
  93. func (d *ServiceDB) LoadService(serviceID string) (service types.Service, err error) {
  94. err = runTransaction(d.db, func(txn *sql.Tx) error {
  95. service, err = selectServiceTxn(txn, serviceID)
  96. return err
  97. })
  98. return
  99. }
  100. // DeleteService deletes the given service from the database.
  101. func (d *ServiceDB) DeleteService(serviceID string) (err error) {
  102. err = runTransaction(d.db, func(txn *sql.Tx) error {
  103. return deleteServiceTxn(txn, serviceID)
  104. })
  105. return
  106. }
  107. // LoadServicesForUser loads all the bot services configured for a given user.
  108. // Returns an empty list if there aren't any services configured.
  109. func (d *ServiceDB) LoadServicesForUser(serviceUserID string) (services []types.Service, err error) {
  110. err = runTransaction(d.db, func(txn *sql.Tx) error {
  111. services, err = selectServicesForUserTxn(txn, serviceUserID)
  112. if err != nil {
  113. return err
  114. }
  115. return nil
  116. })
  117. return
  118. }
  119. // LoadServicesByType loads all the bot services configured for a given type.
  120. // Returns an empty list if there aren't any services configured.
  121. func (d *ServiceDB) LoadServicesByType(serviceType string) (services []types.Service, err error) {
  122. err = runTransaction(d.db, func(txn *sql.Tx) error {
  123. services, err = selectServicesByTypeTxn(txn, serviceType)
  124. if err != nil {
  125. return err
  126. }
  127. return nil
  128. })
  129. return
  130. }
  131. // StoreService stores a service into the database either by inserting a new
  132. // service or updating an existing service. Returns the old service if there
  133. // was one.
  134. func (d *ServiceDB) StoreService(service types.Service) (oldService types.Service, err error) {
  135. err = runTransaction(d.db, func(txn *sql.Tx) error {
  136. oldService, err = selectServiceTxn(txn, service.ServiceID())
  137. if err == sql.ErrNoRows {
  138. return insertServiceTxn(txn, time.Now(), service)
  139. } else if err != nil {
  140. return err
  141. } else {
  142. return updateServiceTxn(txn, time.Now(), service)
  143. }
  144. })
  145. return
  146. }
  147. // LoadAuthRealm loads an AuthRealm from the database.
  148. // Returns sql.ErrNoRows if the realm isn't in the database.
  149. func (d *ServiceDB) LoadAuthRealm(realmID string) (realm types.AuthRealm, err error) {
  150. err = runTransaction(d.db, func(txn *sql.Tx) error {
  151. realm, err = selectRealmTxn(txn, realmID)
  152. return err
  153. })
  154. return
  155. }
  156. // LoadAuthRealmsByType loads all auth realms with the given type from the database.
  157. // The realms are ordered based on their realm ID.
  158. // Returns an empty list if there are no realms with that type.
  159. func (d *ServiceDB) LoadAuthRealmsByType(realmType string) (realms []types.AuthRealm, err error) {
  160. err = runTransaction(d.db, func(txn *sql.Tx) error {
  161. realms, err = selectRealmsByTypeTxn(txn, realmType)
  162. return err
  163. })
  164. return
  165. }
  166. // StoreAuthRealm stores the given AuthRealm, clobbering based on the realm ID.
  167. // This function updates the time added/updated values. The previous realm, if any, is
  168. // returned.
  169. func (d *ServiceDB) StoreAuthRealm(realm types.AuthRealm) (old types.AuthRealm, err error) {
  170. err = runTransaction(d.db, func(txn *sql.Tx) error {
  171. old, err = selectRealmTxn(txn, realm.ID())
  172. if err == sql.ErrNoRows {
  173. return insertRealmTxn(txn, time.Now(), realm)
  174. } else if err != nil {
  175. return err
  176. } else {
  177. return updateRealmTxn(txn, time.Now(), realm)
  178. }
  179. })
  180. return
  181. }
  182. // StoreAuthSession stores the given AuthSession, clobbering based on the tuple of
  183. // user ID and realm ID. This function updates the time added/updated values.
  184. // The previous session, if any, is returned.
  185. func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthSession, err error) {
  186. err = runTransaction(d.db, func(txn *sql.Tx) error {
  187. old, err = selectAuthSessionByUserTxn(txn, session.RealmID(), session.UserID())
  188. if err == sql.ErrNoRows {
  189. return insertAuthSessionTxn(txn, time.Now(), session)
  190. } else if err != nil {
  191. return err
  192. } else {
  193. return updateAuthSessionTxn(txn, time.Now(), session)
  194. }
  195. })
  196. return
  197. }
  198. // RemoveAuthSession removes the auth session for the given user on the given realm.
  199. // No error is returned if the session did not exist in the first place.
  200. func (d *ServiceDB) RemoveAuthSession(realmID, userID string) error {
  201. return runTransaction(d.db, func(txn *sql.Tx) error {
  202. return deleteAuthSessionTxn(txn, realmID, userID)
  203. })
  204. }
  205. // LoadAuthSessionByUser loads an AuthSession from the database based on the given
  206. // realm and user ID.
  207. // Returns sql.ErrNoRows if the session isn't in the database.
  208. func (d *ServiceDB) LoadAuthSessionByUser(realmID, userID string) (session types.AuthSession, err error) {
  209. err = runTransaction(d.db, func(txn *sql.Tx) error {
  210. session, err = selectAuthSessionByUserTxn(txn, realmID, userID)
  211. return err
  212. })
  213. return
  214. }
  215. // LoadAuthSessionByID loads an AuthSession from the database based on the given
  216. // realm and session ID.
  217. // Returns sql.ErrNoRows if the session isn't in the database.
  218. func (d *ServiceDB) LoadAuthSessionByID(realmID, sessionID string) (session types.AuthSession, err error) {
  219. err = runTransaction(d.db, func(txn *sql.Tx) error {
  220. session, err = selectAuthSessionByIDTxn(txn, realmID, sessionID)
  221. return err
  222. })
  223. return
  224. }
  225. // LoadBotOptions loads bot options from the database.
  226. // Returns sql.ErrNoRows if the bot options isn't in the database.
  227. func (d *ServiceDB) LoadBotOptions(userID, roomID string) (opts types.BotOptions, err error) {
  228. err = runTransaction(d.db, func(txn *sql.Tx) error {
  229. opts, err = selectBotOptionsTxn(txn, userID, roomID)
  230. return err
  231. })
  232. return
  233. }
  234. // StoreBotOptions stores a BotOptions into the database either by inserting a new
  235. // bot options or updating an existing bot options. Returns the old bot options if there
  236. // was one.
  237. func (d *ServiceDB) StoreBotOptions(opts types.BotOptions) (oldOpts types.BotOptions, err error) {
  238. err = runTransaction(d.db, func(txn *sql.Tx) error {
  239. oldOpts, err = selectBotOptionsTxn(txn, opts.UserID, opts.RoomID)
  240. if err == sql.ErrNoRows {
  241. return insertBotOptionsTxn(txn, time.Now(), opts)
  242. } else if err != nil {
  243. return err
  244. } else {
  245. return updateBotOptionsTxn(txn, time.Now(), opts)
  246. }
  247. })
  248. return
  249. }
  250. // InsertFromConfig inserts entries from the config file into the database. This only really
  251. // makes sense for in-memory databases.
  252. func (d *ServiceDB) InsertFromConfig(cfg *api.ConfigFile) error {
  253. // Insert clients
  254. for _, cli := range cfg.Clients {
  255. if _, err := d.StoreMatrixClientConfig(cli); err != nil {
  256. return err
  257. }
  258. }
  259. // Keep a map of realms for inserting sessions
  260. realms := map[string]types.AuthRealm{} // by realm ID
  261. // Insert realms
  262. for _, r := range cfg.Realms {
  263. if err := r.Check(); err != nil {
  264. return err
  265. }
  266. realm, err := types.CreateAuthRealm(r.ID, r.Type, r.Config)
  267. if err != nil {
  268. return err
  269. }
  270. if _, err := d.StoreAuthRealm(realm); err != nil {
  271. return err
  272. }
  273. realms[realm.ID()] = realm
  274. }
  275. // Insert sessions
  276. for _, s := range cfg.Sessions {
  277. if err := s.Check(); err != nil {
  278. return err
  279. }
  280. r := realms[s.RealmID]
  281. if r == nil {
  282. return fmt.Errorf("Session %s specifies an unknown realm ID %s", s.SessionID, s.RealmID)
  283. }
  284. session := r.AuthSession(s.SessionID, s.UserID, s.RealmID)
  285. // dump the raw JSON config directly into the session. This is what
  286. // selectAuthSessionByUserTxn does.
  287. if err := json.Unmarshal(s.Config, session); err != nil {
  288. return err
  289. }
  290. if _, err := d.StoreAuthSession(session); err != nil {
  291. return err
  292. }
  293. }
  294. // Do not insert services yet, they require more work to set up.
  295. return nil
  296. }
  // runTransaction runs fn inside a new transaction begun on db.
//
// If the transaction cannot be started, fn is not called and the Begin error
// is returned. If fn returns nil, the transaction is committed and the Commit
// error (nil on success) is returned. If fn returns an error, the transaction
// is rolled back and fn's error is returned; the Rollback error is ignored.
// If fn panics, the transaction is rolled back and the panic is re-raised.
func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
	tx, beginErr := db.Begin()
	if beginErr != nil {
		return beginErr
	}
	defer func() {
		if p := recover(); p != nil {
			tx.Rollback()
			panic(p)
		}
		if err != nil {
			tx.Rollback()
			return
		}
		// fn succeeded, so the transaction's outcome is the Commit result.
		err = tx.Commit()
	}()
	return fn(tx)
}