You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

347 lines
11 KiB

8 years ago
8 years ago
  1. package database
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "fmt"
  6. "time"
  7. "github.com/matrix-org/go-neb/api"
  8. "github.com/matrix-org/go-neb/types"
  9. "maunium.net/go/mautrix/id"
  10. )
  11. // A ServiceDB stores the configuration for the services
  12. type ServiceDB struct {
  13. db *sql.DB
  14. }
  15. // A single global instance of the service DB.
  16. var globalServiceDB Storer
  17. // SetServiceDB sets the global service DB instance.
  18. func SetServiceDB(db Storer) {
  19. globalServiceDB = db
  20. }
  21. // GetServiceDB gets the global service DB instance.
  22. func GetServiceDB() Storer {
  23. return globalServiceDB
  24. }
  25. // Open a SQL database to use as a ServiceDB. This will automatically create
  26. // the necessary database tables if they aren't already present.
  27. func Open(databaseType, databaseURL string) (serviceDB *ServiceDB, err error) {
  28. db, err := sql.Open(databaseType, databaseURL)
  29. if err != nil {
  30. return
  31. }
  32. if _, err = db.Exec(schemaSQL); err != nil {
  33. return
  34. }
  35. if databaseType == "sqlite3" {
  36. // Fix for "database is locked" errors
  37. // https://github.com/mattn/go-sqlite3/issues/274
  38. db.SetMaxOpenConns(1)
  39. }
  40. serviceDB = &ServiceDB{db: db}
  41. return
  42. }
  43. // StoreMatrixClientConfig stores the Matrix client config for a bot service.
  44. // If a config already exists then it will be updated, otherwise a new config
  45. // will be inserted. The previous config is returned.
  46. func (d *ServiceDB) StoreMatrixClientConfig(config api.ClientConfig) (oldConfig api.ClientConfig, err error) {
  47. err = runTransaction(d.db, func(txn *sql.Tx) error {
  48. oldConfig, err = selectMatrixClientConfigTxn(txn, config.UserID)
  49. now := time.Now()
  50. if err == nil {
  51. return updateMatrixClientConfigTxn(txn, now, config)
  52. } else if err == sql.ErrNoRows {
  53. return insertMatrixClientConfigTxn(txn, now, config)
  54. } else {
  55. return err
  56. }
  57. })
  58. return
  59. }
  60. // LoadMatrixClientConfigs loads all Matrix client configs from the database.
  61. func (d *ServiceDB) LoadMatrixClientConfigs() (configs []api.ClientConfig, err error) {
  62. err = runTransaction(d.db, func(txn *sql.Tx) error {
  63. configs, err = selectMatrixClientConfigsTxn(txn)
  64. return err
  65. })
  66. return
  67. }
  68. // LoadMatrixClientConfig loads a Matrix client config from the database.
  69. // Returns sql.ErrNoRows if the client isn't in the database.
  70. func (d *ServiceDB) LoadMatrixClientConfig(userID id.UserID) (config api.ClientConfig, err error) {
  71. err = runTransaction(d.db, func(txn *sql.Tx) error {
  72. config, err = selectMatrixClientConfigTxn(txn, userID)
  73. return err
  74. })
  75. return
  76. }
  77. // UpdateNextBatch updates the next_batch token for the given user.
  78. func (d *ServiceDB) UpdateNextBatch(userID id.UserID, nextBatch string) (err error) {
  79. err = runTransaction(d.db, func(txn *sql.Tx) error {
  80. return updateNextBatchTxn(txn, userID, nextBatch)
  81. })
  82. return
  83. }
  84. // LoadNextBatch loads the next_batch token for the given user.
  85. func (d *ServiceDB) LoadNextBatch(userID id.UserID) (nextBatch string, err error) {
  86. err = runTransaction(d.db, func(txn *sql.Tx) error {
  87. nextBatch, err = selectNextBatchTxn(txn, userID)
  88. return err
  89. })
  90. return
  91. }
  92. // LoadService loads a service from the database.
  93. // Returns sql.ErrNoRows if the service isn't in the database.
  94. func (d *ServiceDB) LoadService(serviceID string) (service types.Service, err error) {
  95. err = runTransaction(d.db, func(txn *sql.Tx) error {
  96. service, err = selectServiceTxn(txn, serviceID)
  97. return err
  98. })
  99. return
  100. }
  101. // DeleteService deletes the given service from the database.
  102. func (d *ServiceDB) DeleteService(serviceID string) (err error) {
  103. err = runTransaction(d.db, func(txn *sql.Tx) error {
  104. return deleteServiceTxn(txn, serviceID)
  105. })
  106. return
  107. }
  108. // LoadServicesForUser loads all the bot services configured for a given user.
  109. // Returns an empty list if there aren't any services configured.
  110. func (d *ServiceDB) LoadServicesForUser(serviceUserID id.UserID) (services []types.Service, err error) {
  111. err = runTransaction(d.db, func(txn *sql.Tx) error {
  112. services, err = selectServicesForUserTxn(txn, serviceUserID)
  113. if err != nil {
  114. return err
  115. }
  116. return nil
  117. })
  118. return
  119. }
  120. // LoadServicesByType loads all the bot services configured for a given type.
  121. // Returns an empty list if there aren't any services configured.
  122. func (d *ServiceDB) LoadServicesByType(serviceType string) (services []types.Service, err error) {
  123. err = runTransaction(d.db, func(txn *sql.Tx) error {
  124. services, err = selectServicesByTypeTxn(txn, serviceType)
  125. if err != nil {
  126. return err
  127. }
  128. return nil
  129. })
  130. return
  131. }
  132. // StoreService stores a service into the database either by inserting a new
  133. // service or updating an existing service. Returns the old service if there
  134. // was one.
  135. func (d *ServiceDB) StoreService(service types.Service) (oldService types.Service, err error) {
  136. err = runTransaction(d.db, func(txn *sql.Tx) error {
  137. oldService, err = selectServiceTxn(txn, service.ServiceID())
  138. if err == sql.ErrNoRows {
  139. return insertServiceTxn(txn, time.Now(), service)
  140. } else if err != nil {
  141. return err
  142. } else {
  143. return updateServiceTxn(txn, time.Now(), service)
  144. }
  145. })
  146. return
  147. }
  148. // LoadAuthRealm loads an AuthRealm from the database.
  149. // Returns sql.ErrNoRows if the realm isn't in the database.
  150. func (d *ServiceDB) LoadAuthRealm(realmID string) (realm types.AuthRealm, err error) {
  151. err = runTransaction(d.db, func(txn *sql.Tx) error {
  152. realm, err = selectRealmTxn(txn, realmID)
  153. return err
  154. })
  155. return
  156. }
  157. // LoadAuthRealmsByType loads all auth realms with the given type from the database.
  158. // The realms are ordered based on their realm ID.
  159. // Returns an empty list if there are no realms with that type.
  160. func (d *ServiceDB) LoadAuthRealmsByType(realmType string) (realms []types.AuthRealm, err error) {
  161. err = runTransaction(d.db, func(txn *sql.Tx) error {
  162. realms, err = selectRealmsByTypeTxn(txn, realmType)
  163. return err
  164. })
  165. return
  166. }
  167. // StoreAuthRealm stores the given AuthRealm, clobbering based on the realm ID.
  168. // This function updates the time added/updated values. The previous realm, if any, is
  169. // returned.
  170. func (d *ServiceDB) StoreAuthRealm(realm types.AuthRealm) (old types.AuthRealm, err error) {
  171. err = runTransaction(d.db, func(txn *sql.Tx) error {
  172. old, err = selectRealmTxn(txn, realm.ID())
  173. if err == sql.ErrNoRows {
  174. return insertRealmTxn(txn, time.Now(), realm)
  175. } else if err != nil {
  176. return err
  177. } else {
  178. return updateRealmTxn(txn, time.Now(), realm)
  179. }
  180. })
  181. return
  182. }
  183. // StoreAuthSession stores the given AuthSession, clobbering based on the tuple of
  184. // user ID and realm ID. This function updates the time added/updated values.
  185. // The previous session, if any, is returned.
  186. func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthSession, err error) {
  187. err = runTransaction(d.db, func(txn *sql.Tx) error {
  188. old, err = selectAuthSessionByUserTxn(txn, session.RealmID(), session.UserID())
  189. if err == sql.ErrNoRows {
  190. return insertAuthSessionTxn(txn, time.Now(), session)
  191. } else if err != nil {
  192. return err
  193. } else {
  194. return updateAuthSessionTxn(txn, time.Now(), session)
  195. }
  196. })
  197. return
  198. }
  199. // RemoveAuthSession removes the auth session for the given user on the given realm.
  200. // No error is returned if the session did not exist in the first place.
  201. func (d *ServiceDB) RemoveAuthSession(realmID string, userID id.UserID) error {
  202. return runTransaction(d.db, func(txn *sql.Tx) error {
  203. return deleteAuthSessionTxn(txn, realmID, userID)
  204. })
  205. }
  206. // LoadAuthSessionByUser loads an AuthSession from the database based on the given
  207. // realm and user ID.
  208. // Returns sql.ErrNoRows if the session isn't in the database.
  209. func (d *ServiceDB) LoadAuthSessionByUser(realmID string, userID id.UserID) (session types.AuthSession, err error) {
  210. err = runTransaction(d.db, func(txn *sql.Tx) error {
  211. session, err = selectAuthSessionByUserTxn(txn, realmID, userID)
  212. return err
  213. })
  214. return
  215. }
  216. // LoadAuthSessionByID loads an AuthSession from the database based on the given
  217. // realm and session ID.
  218. // Returns sql.ErrNoRows if the session isn't in the database.
  219. func (d *ServiceDB) LoadAuthSessionByID(realmID, sessionID string) (session types.AuthSession, err error) {
  220. err = runTransaction(d.db, func(txn *sql.Tx) error {
  221. session, err = selectAuthSessionByIDTxn(txn, realmID, sessionID)
  222. return err
  223. })
  224. return
  225. }
  226. // LoadBotOptions loads bot options from the database.
  227. // Returns sql.ErrNoRows if the bot options isn't in the database.
  228. func (d *ServiceDB) LoadBotOptions(userID id.UserID, roomID id.RoomID) (opts types.BotOptions, err error) {
  229. err = runTransaction(d.db, func(txn *sql.Tx) error {
  230. opts, err = selectBotOptionsTxn(txn, userID, roomID)
  231. return err
  232. })
  233. return
  234. }
  235. // StoreBotOptions stores a BotOptions into the database either by inserting a new
  236. // bot options or updating an existing bot options. Returns the old bot options if there
  237. // was one.
  238. func (d *ServiceDB) StoreBotOptions(opts types.BotOptions) (oldOpts types.BotOptions, err error) {
  239. err = runTransaction(d.db, func(txn *sql.Tx) error {
  240. oldOpts, err = selectBotOptionsTxn(txn, opts.UserID, opts.RoomID)
  241. if err == sql.ErrNoRows {
  242. return insertBotOptionsTxn(txn, time.Now(), opts)
  243. } else if err != nil {
  244. return err
  245. } else {
  246. return updateBotOptionsTxn(txn, time.Now(), opts)
  247. }
  248. })
  249. return
  250. }
  251. // InsertFromConfig inserts entries from the config file into the database. This only really
  252. // makes sense for in-memory databases.
  253. func (d *ServiceDB) InsertFromConfig(cfg *api.ConfigFile) error {
  254. // Insert clients
  255. for _, cli := range cfg.Clients {
  256. if _, err := d.StoreMatrixClientConfig(cli); err != nil {
  257. return err
  258. }
  259. }
  260. // Keep a map of realms for inserting sessions
  261. realms := map[string]types.AuthRealm{} // by realm ID
  262. // Insert realms
  263. for _, r := range cfg.Realms {
  264. if err := r.Check(); err != nil {
  265. return err
  266. }
  267. realm, err := types.CreateAuthRealm(r.ID, r.Type, r.Config)
  268. if err != nil {
  269. return err
  270. }
  271. if _, err := d.StoreAuthRealm(realm); err != nil {
  272. return err
  273. }
  274. realms[realm.ID()] = realm
  275. }
  276. // Insert sessions
  277. for _, s := range cfg.Sessions {
  278. if err := s.Check(); err != nil {
  279. return err
  280. }
  281. r := realms[s.RealmID]
  282. if r == nil {
  283. return fmt.Errorf("Session %s specifies an unknown realm ID %s", s.SessionID, s.RealmID)
  284. }
  285. session := r.AuthSession(s.SessionID, s.UserID, s.RealmID)
  286. // dump the raw JSON config directly into the session. This is what
  287. // selectAuthSessionByUserTxn does.
  288. if err := json.Unmarshal(s.Config, session); err != nil {
  289. return err
  290. }
  291. if _, err := d.StoreAuthSession(session); err != nil {
  292. return err
  293. }
  294. }
  295. // Do not insert services yet, they require more work to set up.
  296. return nil
  297. }
  298. func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
  299. txn, err := db.Begin()
  300. if err != nil {
  301. return
  302. }
  303. defer func() {
  304. if r := recover(); r != nil {
  305. txn.Rollback()
  306. panic(r)
  307. } else if err != nil {
  308. txn.Rollback()
  309. } else {
  310. err = txn.Commit()
  311. }
  312. }()
  313. err = fn(txn)
  314. return
  315. }