You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

266 lines
8.4 KiB

8 years ago
8 years ago
  1. package database
  2. import (
  3. "database/sql"
  4. "github.com/matrix-org/go-neb/types"
  5. "time"
  6. )
  7. // A ServiceDB stores the configuration for the services
  8. type ServiceDB struct {
  9. db *sql.DB
  10. }
  11. // A single global instance of the service DB.
  12. // XXX: I can't think of any way of doing this without one without creating
  13. // cyclical dependencies somewhere -- Kegan
  14. var globalServiceDB *ServiceDB
  15. // SetServiceDB sets the global service DB instance.
  16. func SetServiceDB(db *ServiceDB) {
  17. globalServiceDB = db
  18. }
  19. // GetServiceDB gets the global service DB instance.
  20. func GetServiceDB() *ServiceDB {
  21. return globalServiceDB
  22. }
  23. // Open a SQL database to use as a ServiceDB. This will automatically create
  24. // the necessary database tables if they aren't already present.
  25. func Open(databaseType, databaseURL string) (serviceDB *ServiceDB, err error) {
  26. db, err := sql.Open(databaseType, databaseURL)
  27. if err != nil {
  28. return
  29. }
  30. if _, err = db.Exec(schemaSQL); err != nil {
  31. return
  32. }
  33. serviceDB = &ServiceDB{db: db}
  34. return
  35. }
  36. // StoreMatrixClientConfig stores the Matrix client config for a bot service.
  37. // If a config already exists then it will be updated, otherwise a new config
  38. // will be inserted. The previous config is returned.
  39. func (d *ServiceDB) StoreMatrixClientConfig(config types.ClientConfig) (oldConfig types.ClientConfig, err error) {
  40. err = runTransaction(d.db, func(txn *sql.Tx) error {
  41. oldConfig, err = selectMatrixClientConfigTxn(txn, config.UserID)
  42. now := time.Now()
  43. if err == nil {
  44. return updateMatrixClientConfigTxn(txn, now, config)
  45. } else if err == sql.ErrNoRows {
  46. return insertMatrixClientConfigTxn(txn, now, config)
  47. } else {
  48. return err
  49. }
  50. })
  51. return
  52. }
  53. // LoadMatrixClientConfigs loads all Matrix client configs from the database.
  54. func (d *ServiceDB) LoadMatrixClientConfigs() (configs []types.ClientConfig, err error) {
  55. err = runTransaction(d.db, func(txn *sql.Tx) error {
  56. configs, err = selectMatrixClientConfigsTxn(txn)
  57. return err
  58. })
  59. return
  60. }
  61. // LoadMatrixClientConfig loads a Matrix client config from the database.
  62. // Returns sql.ErrNoRows if the client isn't in the database.
  63. func (d *ServiceDB) LoadMatrixClientConfig(userID string) (config types.ClientConfig, err error) {
  64. err = runTransaction(d.db, func(txn *sql.Tx) error {
  65. config, err = selectMatrixClientConfigTxn(txn, userID)
  66. return err
  67. })
  68. return
  69. }
  70. // UpdateNextBatch updates the next_batch token for the given user.
  71. func (d *ServiceDB) UpdateNextBatch(userID, nextBatch string) (err error) {
  72. err = runTransaction(d.db, func(txn *sql.Tx) error {
  73. return updateNextBatchTxn(txn, userID, nextBatch)
  74. })
  75. return
  76. }
  77. // LoadNextBatch loads the next_batch token for the given user.
  78. func (d *ServiceDB) LoadNextBatch(userID string) (nextBatch string, err error) {
  79. err = runTransaction(d.db, func(txn *sql.Tx) error {
  80. nextBatch, err = selectNextBatchTxn(txn, userID)
  81. return err
  82. })
  83. return
  84. }
  85. // LoadService loads a service from the database.
  86. // Returns sql.ErrNoRows if the service isn't in the database.
  87. func (d *ServiceDB) LoadService(serviceID string) (service types.Service, err error) {
  88. err = runTransaction(d.db, func(txn *sql.Tx) error {
  89. service, err = selectServiceTxn(txn, serviceID)
  90. return err
  91. })
  92. return
  93. }
  94. // DeleteService deletes the given service from the database.
  95. func (d *ServiceDB) DeleteService(serviceID string) (err error) {
  96. err = runTransaction(d.db, func(txn *sql.Tx) error {
  97. return deleteServiceTxn(txn, serviceID)
  98. })
  99. return
  100. }
  101. // LoadServicesForUser loads all the bot services configured for a given user.
  102. // Returns an empty list if there aren't any services configured.
  103. func (d *ServiceDB) LoadServicesForUser(serviceUserID string) (services []types.Service, err error) {
  104. err = runTransaction(d.db, func(txn *sql.Tx) error {
  105. services, err = selectServicesForUserTxn(txn, serviceUserID)
  106. if err != nil {
  107. return err
  108. }
  109. return nil
  110. })
  111. return
  112. }
  113. // StoreService stores a service into the database either by inserting a new
  114. // service or updating an existing service. Returns the old service if there
  115. // was one.
  116. func (d *ServiceDB) StoreService(service types.Service) (oldService types.Service, err error) {
  117. err = runTransaction(d.db, func(txn *sql.Tx) error {
  118. oldService, err = selectServiceTxn(txn, service.ServiceID())
  119. if err == sql.ErrNoRows {
  120. return insertServiceTxn(txn, time.Now(), service)
  121. } else if err != nil {
  122. return err
  123. } else {
  124. return updateServiceTxn(txn, time.Now(), service)
  125. }
  126. })
  127. return
  128. }
  129. // LoadAuthRealm loads an AuthRealm from the database.
  130. // Returns sql.ErrNoRows if the realm isn't in the database.
  131. func (d *ServiceDB) LoadAuthRealm(realmID string) (realm types.AuthRealm, err error) {
  132. err = runTransaction(d.db, func(txn *sql.Tx) error {
  133. realm, err = selectRealmTxn(txn, realmID)
  134. return err
  135. })
  136. return
  137. }
  138. // LoadAuthRealmsByType loads all auth realms with the given type from the database.
  139. // The realms are ordered based on their realm ID.
  140. // Returns an empty list if there are no realms with that type.
  141. func (d *ServiceDB) LoadAuthRealmsByType(realmType string) (realms []types.AuthRealm, err error) {
  142. err = runTransaction(d.db, func(txn *sql.Tx) error {
  143. realms, err = selectRealmsByTypeTxn(txn, realmType)
  144. return err
  145. })
  146. return
  147. }
  148. // StoreAuthRealm stores the given AuthRealm, clobbering based on the realm ID.
  149. // This function updates the time added/updated values. The previous realm, if any, is
  150. // returned.
  151. func (d *ServiceDB) StoreAuthRealm(realm types.AuthRealm) (old types.AuthRealm, err error) {
  152. err = runTransaction(d.db, func(txn *sql.Tx) error {
  153. old, err = selectRealmTxn(txn, realm.ID())
  154. if err == sql.ErrNoRows {
  155. return insertRealmTxn(txn, time.Now(), realm)
  156. } else if err != nil {
  157. return err
  158. } else {
  159. return updateRealmTxn(txn, time.Now(), realm)
  160. }
  161. })
  162. return
  163. }
  164. // StoreAuthSession stores the given AuthSession, clobbering based on the tuple of
  165. // user ID and realm ID. This function updates the time added/updated values.
  166. // The previous session, if any, is returned.
  167. func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthSession, err error) {
  168. err = runTransaction(d.db, func(txn *sql.Tx) error {
  169. old, err = selectAuthSessionByUserTxn(txn, session.RealmID(), session.UserID())
  170. if err == sql.ErrNoRows {
  171. return insertAuthSessionTxn(txn, time.Now(), session)
  172. } else if err != nil {
  173. return err
  174. } else {
  175. return updateAuthSessionTxn(txn, time.Now(), session)
  176. }
  177. })
  178. return
  179. }
  180. // LoadAuthSessionByUser loads an AuthSession from the database based on the given
  181. // realm and user ID.
  182. // Returns sql.ErrNoRows if the session isn't in the database.
  183. func (d *ServiceDB) LoadAuthSessionByUser(realmID, userID string) (session types.AuthSession, err error) {
  184. err = runTransaction(d.db, func(txn *sql.Tx) error {
  185. session, err = selectAuthSessionByUserTxn(txn, realmID, userID)
  186. return err
  187. })
  188. return
  189. }
  190. // LoadAuthSessionByID loads an AuthSession from the database based on the given
  191. // realm and session ID.
  192. // Returns sql.ErrNoRows if the session isn't in the database.
  193. func (d *ServiceDB) LoadAuthSessionByID(realmID, sessionID string) (session types.AuthSession, err error) {
  194. err = runTransaction(d.db, func(txn *sql.Tx) error {
  195. session, err = selectAuthSessionByIDTxn(txn, realmID, sessionID)
  196. return err
  197. })
  198. return
  199. }
  200. // LoadBotOptions loads bot options from the database.
  201. // Returns sql.ErrNoRows if the bot options isn't in the database.
  202. func (d *ServiceDB) LoadBotOptions(userID, roomID string) (opts types.BotOptions, err error) {
  203. err = runTransaction(d.db, func(txn *sql.Tx) error {
  204. opts, err = selectBotOptionsTxn(txn, userID, roomID)
  205. return err
  206. })
  207. return
  208. }
  209. // StoreBotOptions stores a BotOptions into the database either by inserting a new
  210. // bot options or updating an existing bot options. Returns the old bot options if there
  211. // was one.
  212. func (d *ServiceDB) StoreBotOptions(opts types.BotOptions) (oldOpts types.BotOptions, err error) {
  213. err = runTransaction(d.db, func(txn *sql.Tx) error {
  214. oldOpts, err = selectBotOptionsTxn(txn, opts.UserID, opts.RoomID)
  215. if err == sql.ErrNoRows {
  216. return insertBotOptionsTxn(txn, time.Now(), opts)
  217. } else if err != nil {
  218. return err
  219. } else {
  220. return updateBotOptionsTxn(txn, time.Now(), opts)
  221. }
  222. })
  223. return
  224. }
  225. func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
  226. txn, err := db.Begin()
  227. if err != nil {
  228. return
  229. }
  230. defer func() {
  231. if r := recover(); r != nil {
  232. txn.Rollback()
  233. panic(r)
  234. } else if err != nil {
  235. txn.Rollback()
  236. } else {
  237. err = txn.Commit()
  238. }
  239. }()
  240. err = fn(txn)
  241. return
  242. }