You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

353 lines
11 KiB

9 years ago
9 years ago
  1. package database
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "fmt"
  6. "time"
  7. "github.com/matrix-org/go-neb/api"
  8. "github.com/matrix-org/go-neb/types"
  9. "maunium.net/go/mautrix/id"
  10. )
  // ServiceDB is the SQL-backed store for bot service configuration. It pairs
  // the opened *sql.DB handle with the name of the driver ("dialect") that
  // handle was opened with, so callers can tell which SQL flavour is in use.
  type ServiceDB struct {
  	db      *sql.DB // handle returned by sql.Open
  	dialect string  // driver name passed to Open, e.g. "sqlite3"
  }
  16. // A single global instance of the service DB.
  17. var globalServiceDB Storer
  18. // SetServiceDB sets the global service DB instance.
  19. func SetServiceDB(db Storer) {
  20. globalServiceDB = db
  21. }
  22. // GetServiceDB gets the global service DB instance.
  23. func GetServiceDB() Storer {
  24. return globalServiceDB
  25. }
  26. // Open a SQL database to use as a ServiceDB. This will automatically create
  27. // the necessary database tables if they aren't already present.
  28. func Open(databaseType, databaseURL string) (serviceDB *ServiceDB, err error) {
  29. db, err := sql.Open(databaseType, databaseURL)
  30. if err != nil {
  31. return
  32. }
  33. if _, err = db.Exec(schemaSQL); err != nil {
  34. return
  35. }
  36. if databaseType == "sqlite3" {
  37. // Fix for "database is locked" errors
  38. // https://github.com/mattn/go-sqlite3/issues/274
  39. db.SetMaxOpenConns(1)
  40. }
  41. serviceDB = &ServiceDB{db: db, dialect: databaseType}
  42. return
  43. }
  44. // StoreMatrixClientConfig stores the Matrix client config for a bot service.
  45. // If a config already exists then it will be updated, otherwise a new config
  46. // will be inserted. The previous config is returned.
  47. func (d *ServiceDB) StoreMatrixClientConfig(config api.ClientConfig) (oldConfig api.ClientConfig, err error) {
  48. err = runTransaction(d.db, func(txn *sql.Tx) error {
  49. oldConfig, err = selectMatrixClientConfigTxn(txn, config.UserID)
  50. now := time.Now()
  51. if err == nil {
  52. return updateMatrixClientConfigTxn(txn, now, config)
  53. } else if err == sql.ErrNoRows {
  54. return insertMatrixClientConfigTxn(txn, now, config)
  55. } else {
  56. return err
  57. }
  58. })
  59. return
  60. }
  61. // LoadMatrixClientConfigs loads all Matrix client configs from the database.
  62. func (d *ServiceDB) LoadMatrixClientConfigs() (configs []api.ClientConfig, err error) {
  63. err = runTransaction(d.db, func(txn *sql.Tx) error {
  64. configs, err = selectMatrixClientConfigsTxn(txn)
  65. return err
  66. })
  67. return
  68. }
  69. // LoadMatrixClientConfig loads a Matrix client config from the database.
  70. // Returns sql.ErrNoRows if the client isn't in the database.
  71. func (d *ServiceDB) LoadMatrixClientConfig(userID id.UserID) (config api.ClientConfig, err error) {
  72. err = runTransaction(d.db, func(txn *sql.Tx) error {
  73. config, err = selectMatrixClientConfigTxn(txn, userID)
  74. return err
  75. })
  76. return
  77. }
  78. // UpdateNextBatch updates the next_batch token for the given user.
  79. func (d *ServiceDB) UpdateNextBatch(userID id.UserID, nextBatch string) (err error) {
  80. err = runTransaction(d.db, func(txn *sql.Tx) error {
  81. return updateNextBatchTxn(txn, userID, nextBatch)
  82. })
  83. return
  84. }
  85. // LoadNextBatch loads the next_batch token for the given user.
  86. func (d *ServiceDB) LoadNextBatch(userID id.UserID) (nextBatch string, err error) {
  87. err = runTransaction(d.db, func(txn *sql.Tx) error {
  88. nextBatch, err = selectNextBatchTxn(txn, userID)
  89. return err
  90. })
  91. return
  92. }
  93. // LoadService loads a service from the database.
  94. // Returns sql.ErrNoRows if the service isn't in the database.
  95. func (d *ServiceDB) LoadService(serviceID string) (service types.Service, err error) {
  96. err = runTransaction(d.db, func(txn *sql.Tx) error {
  97. service, err = selectServiceTxn(txn, serviceID)
  98. return err
  99. })
  100. return
  101. }
  102. // DeleteService deletes the given service from the database.
  103. func (d *ServiceDB) DeleteService(serviceID string) (err error) {
  104. err = runTransaction(d.db, func(txn *sql.Tx) error {
  105. return deleteServiceTxn(txn, serviceID)
  106. })
  107. return
  108. }
  109. // LoadServicesForUser loads all the bot services configured for a given user.
  110. // Returns an empty list if there aren't any services configured.
  111. func (d *ServiceDB) LoadServicesForUser(serviceUserID id.UserID) (services []types.Service, err error) {
  112. err = runTransaction(d.db, func(txn *sql.Tx) error {
  113. services, err = selectServicesForUserTxn(txn, serviceUserID)
  114. if err != nil {
  115. return err
  116. }
  117. return nil
  118. })
  119. return
  120. }
  121. // LoadServicesByType loads all the bot services configured for a given type.
  122. // Returns an empty list if there aren't any services configured.
  123. func (d *ServiceDB) LoadServicesByType(serviceType string) (services []types.Service, err error) {
  124. err = runTransaction(d.db, func(txn *sql.Tx) error {
  125. services, err = selectServicesByTypeTxn(txn, serviceType)
  126. if err != nil {
  127. return err
  128. }
  129. return nil
  130. })
  131. return
  132. }
  133. // StoreService stores a service into the database either by inserting a new
  134. // service or updating an existing service. Returns the old service if there
  135. // was one.
  136. func (d *ServiceDB) StoreService(service types.Service) (oldService types.Service, err error) {
  137. err = runTransaction(d.db, func(txn *sql.Tx) error {
  138. oldService, err = selectServiceTxn(txn, service.ServiceID())
  139. if err == sql.ErrNoRows {
  140. return insertServiceTxn(txn, time.Now(), service)
  141. } else if err != nil {
  142. return err
  143. } else {
  144. return updateServiceTxn(txn, time.Now(), service)
  145. }
  146. })
  147. return
  148. }
  149. // LoadAuthRealm loads an AuthRealm from the database.
  150. // Returns sql.ErrNoRows if the realm isn't in the database.
  151. func (d *ServiceDB) LoadAuthRealm(realmID string) (realm types.AuthRealm, err error) {
  152. err = runTransaction(d.db, func(txn *sql.Tx) error {
  153. realm, err = selectRealmTxn(txn, realmID)
  154. return err
  155. })
  156. return
  157. }
  158. // LoadAuthRealmsByType loads all auth realms with the given type from the database.
  159. // The realms are ordered based on their realm ID.
  160. // Returns an empty list if there are no realms with that type.
  161. func (d *ServiceDB) LoadAuthRealmsByType(realmType string) (realms []types.AuthRealm, err error) {
  162. err = runTransaction(d.db, func(txn *sql.Tx) error {
  163. realms, err = selectRealmsByTypeTxn(txn, realmType)
  164. return err
  165. })
  166. return
  167. }
  168. // StoreAuthRealm stores the given AuthRealm, clobbering based on the realm ID.
  169. // This function updates the time added/updated values. The previous realm, if any, is
  170. // returned.
  171. func (d *ServiceDB) StoreAuthRealm(realm types.AuthRealm) (old types.AuthRealm, err error) {
  172. err = runTransaction(d.db, func(txn *sql.Tx) error {
  173. old, err = selectRealmTxn(txn, realm.ID())
  174. if err == sql.ErrNoRows {
  175. return insertRealmTxn(txn, time.Now(), realm)
  176. } else if err != nil {
  177. return err
  178. } else {
  179. return updateRealmTxn(txn, time.Now(), realm)
  180. }
  181. })
  182. return
  183. }
  184. // StoreAuthSession stores the given AuthSession, clobbering based on the tuple of
  185. // user ID and realm ID. This function updates the time added/updated values.
  186. // The previous session, if any, is returned.
  187. func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthSession, err error) {
  188. err = runTransaction(d.db, func(txn *sql.Tx) error {
  189. old, err = selectAuthSessionByUserTxn(txn, session.RealmID(), session.UserID())
  190. if err == sql.ErrNoRows {
  191. return insertAuthSessionTxn(txn, time.Now(), session)
  192. } else if err != nil {
  193. return err
  194. } else {
  195. return updateAuthSessionTxn(txn, time.Now(), session)
  196. }
  197. })
  198. return
  199. }
  200. // RemoveAuthSession removes the auth session for the given user on the given realm.
  201. // No error is returned if the session did not exist in the first place.
  202. func (d *ServiceDB) RemoveAuthSession(realmID string, userID id.UserID) error {
  203. return runTransaction(d.db, func(txn *sql.Tx) error {
  204. return deleteAuthSessionTxn(txn, realmID, userID)
  205. })
  206. }
  207. // LoadAuthSessionByUser loads an AuthSession from the database based on the given
  208. // realm and user ID.
  209. // Returns sql.ErrNoRows if the session isn't in the database.
  210. func (d *ServiceDB) LoadAuthSessionByUser(realmID string, userID id.UserID) (session types.AuthSession, err error) {
  211. err = runTransaction(d.db, func(txn *sql.Tx) error {
  212. session, err = selectAuthSessionByUserTxn(txn, realmID, userID)
  213. return err
  214. })
  215. return
  216. }
  217. // LoadAuthSessionByID loads an AuthSession from the database based on the given
  218. // realm and session ID.
  219. // Returns sql.ErrNoRows if the session isn't in the database.
  220. func (d *ServiceDB) LoadAuthSessionByID(realmID, sessionID string) (session types.AuthSession, err error) {
  221. err = runTransaction(d.db, func(txn *sql.Tx) error {
  222. session, err = selectAuthSessionByIDTxn(txn, realmID, sessionID)
  223. return err
  224. })
  225. return
  226. }
  227. // LoadBotOptions loads bot options from the database.
  228. // Returns sql.ErrNoRows if the bot options isn't in the database.
  229. func (d *ServiceDB) LoadBotOptions(userID id.UserID, roomID id.RoomID) (opts types.BotOptions, err error) {
  230. err = runTransaction(d.db, func(txn *sql.Tx) error {
  231. opts, err = selectBotOptionsTxn(txn, userID, roomID)
  232. return err
  233. })
  234. return
  235. }
  236. // StoreBotOptions stores a BotOptions into the database either by inserting a new
  237. // bot options or updating an existing bot options. Returns the old bot options if there
  238. // was one.
  239. func (d *ServiceDB) StoreBotOptions(opts types.BotOptions) (oldOpts types.BotOptions, err error) {
  240. err = runTransaction(d.db, func(txn *sql.Tx) error {
  241. oldOpts, err = selectBotOptionsTxn(txn, opts.UserID, opts.RoomID)
  242. if err == sql.ErrNoRows {
  243. return insertBotOptionsTxn(txn, time.Now(), opts)
  244. } else if err != nil {
  245. return err
  246. } else {
  247. return updateBotOptionsTxn(txn, time.Now(), opts)
  248. }
  249. })
  250. return
  251. }
  252. // InsertFromConfig inserts entries from the config file into the database. This only really
  253. // makes sense for in-memory databases.
  254. func (d *ServiceDB) InsertFromConfig(cfg *api.ConfigFile) error {
  255. // Insert clients
  256. for _, cli := range cfg.Clients {
  257. if _, err := d.StoreMatrixClientConfig(cli); err != nil {
  258. return err
  259. }
  260. }
  261. // Keep a map of realms for inserting sessions
  262. realms := map[string]types.AuthRealm{} // by realm ID
  263. // Insert realms
  264. for _, r := range cfg.Realms {
  265. if err := r.Check(); err != nil {
  266. return err
  267. }
  268. realm, err := types.CreateAuthRealm(r.ID, r.Type, r.Config)
  269. if err != nil {
  270. return err
  271. }
  272. if _, err := d.StoreAuthRealm(realm); err != nil {
  273. return err
  274. }
  275. realms[realm.ID()] = realm
  276. }
  277. // Insert sessions
  278. for _, s := range cfg.Sessions {
  279. if err := s.Check(); err != nil {
  280. return err
  281. }
  282. r := realms[s.RealmID]
  283. if r == nil {
  284. return fmt.Errorf("Session %s specifies an unknown realm ID %s", s.SessionID, s.RealmID)
  285. }
  286. session := r.AuthSession(s.SessionID, s.UserID, s.RealmID)
  287. // dump the raw JSON config directly into the session. This is what
  288. // selectAuthSessionByUserTxn does.
  289. if err := json.Unmarshal(s.Config, session); err != nil {
  290. return err
  291. }
  292. if _, err := d.StoreAuthSession(session); err != nil {
  293. return err
  294. }
  295. }
  296. // Do not insert services yet, they require more work to set up.
  297. return nil
  298. }
  299. // GetSQLDb retrieves the SQL database instance of a ServiceDB and the dialect it uses (sqlite3 or postgres).
  300. func (d *ServiceDB) GetSQLDb() (*sql.DB, string) {
  301. return d.db, d.dialect
  302. }
  // runTransaction runs fn inside a fresh transaction on db and finishes that
  // transaction according to how fn ended:
  //   - fn returned nil: the transaction is committed and Commit's error, if
  //     any, becomes the result;
  //   - fn returned an error: the transaction is rolled back and fn's error is
  //     returned unchanged (the Rollback error is discarded);
  //   - fn panicked: the transaction is rolled back and the panic is re-raised.
  // If Begin fails, its error is returned and fn is never called.
  func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
  	txn, err := db.Begin()
  	if err != nil {
  		return err
  	}
  	defer func() {
  		// recover has to be called directly by this deferred func to catch a
  		// panic raised inside fn.
  		r := recover()
  		switch {
  		case r != nil:
  			txn.Rollback()
  			panic(r)
  		case err != nil:
  			txn.Rollback()
  		default:
  			// err is the named result, so a failed Commit is surfaced.
  			err = txn.Commit()
  		}
  	}()
  	return fn(txn)
  }