You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

318 lines
12 KiB

  1. package clients
  2. import (
  3. "errors"
  4. "regexp"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. "github.com/matrix-org/go-neb/api"
  9. "github.com/matrix-org/go-neb/database"
  10. "github.com/matrix-org/go-neb/matrix"
  11. log "github.com/sirupsen/logrus"
  12. "golang.org/x/net/context"
  13. "maunium.net/go/mautrix"
  14. "maunium.net/go/mautrix/crypto"
  15. "maunium.net/go/mautrix/event"
  16. mevt "maunium.net/go/mautrix/event"
  17. "maunium.net/go/mautrix/id"
  18. )
  19. // maximumVerifications is the number of maximum ongoing SAS verifications at a time.
  20. // After this limit we start ignoring verification requests.
  21. const maximumVerifications = 100
  22. // BotClient represents one of the bot's sessions, with a specific User and Device ID.
  23. // It can be used for sending messages and retrieving information about the rooms that
  24. // the client has joined.
  25. type BotClient struct {
  26. *mautrix.Client
  27. config api.ClientConfig
  28. olmMachine *crypto.OlmMachine
  29. stateStore *NebStateStore
  30. verificationSAS *sync.Map
  31. ongoingVerificationCount int32
  32. }
  33. // InitOlmMachine initializes a BotClient's internal OlmMachine given a client object and a Neb store,
  34. // which will be used to store room information.
  35. func (botClient *BotClient) InitOlmMachine(client *mautrix.Client, nebStore *matrix.NEBStore) (err error) {
  36. var cryptoStore crypto.Store
  37. cryptoLogger := CryptoMachineLogger{}
  38. if sdb, ok := database.GetServiceDB().(*database.ServiceDB); ok {
  39. // Create an SQL crypto store based on the ServiceDB used
  40. db, dialect := sdb.GetSQLDb()
  41. accountID := botClient.config.UserID.String() + "-" + client.DeviceID.String()
  42. sqlCryptoStore := crypto.NewSQLCryptoStore(db, dialect, accountID, client.DeviceID, []byte(client.DeviceID.String()+"pickle"), cryptoLogger)
  43. // Try to create the tables if they are missing
  44. if err = sqlCryptoStore.CreateTables(); err != nil {
  45. return
  46. }
  47. cryptoStore = sqlCryptoStore
  48. cryptoLogger.Debug("Using SQL backend as the crypto store")
  49. } else {
  50. deviceID := client.DeviceID.String()
  51. if deviceID == "" {
  52. deviceID = "_empty_device_id"
  53. }
  54. cryptoStore, err = crypto.NewGobStore(deviceID + ".gob")
  55. if err != nil {
  56. return
  57. }
  58. cryptoLogger.Debug("Using gob storage as the crypto store")
  59. }
  60. botClient.stateStore = &NebStateStore{&nebStore.InMemoryStore}
  61. olmMachine := crypto.NewOlmMachine(client, cryptoLogger, cryptoStore, botClient.stateStore)
  62. regexes := make([]*regexp.Regexp, 0, len(botClient.config.AcceptVerificationFromUsers))
  63. for _, userRegex := range botClient.config.AcceptVerificationFromUsers {
  64. regex, err := regexp.Compile(userRegex)
  65. if err != nil {
  66. cryptoLogger.Error("Error compiling regex %v: %v", userRegex, err)
  67. } else {
  68. regexes = append(regexes, regex)
  69. }
  70. }
  71. olmMachine.AcceptVerificationFrom = func(_ string, otherDevice *crypto.DeviceIdentity) (crypto.VerificationRequestResponse, crypto.VerificationHooks) {
  72. for _, regex := range regexes {
  73. if regex.MatchString(otherDevice.UserID.String()) {
  74. if atomic.LoadInt32(&botClient.ongoingVerificationCount) >= maximumVerifications {
  75. cryptoLogger.Trace("User ID %v matches regex %v but we are currently at maximum verifications, ignoring...", otherDevice.UserID, regex)
  76. return crypto.IgnoreRequest, botClient
  77. }
  78. cryptoLogger.Trace("User ID %v matches regex %v, accepting SAS request", otherDevice.UserID, regex)
  79. atomic.AddInt32(&botClient.ongoingVerificationCount, 1)
  80. return crypto.AcceptRequest, botClient
  81. }
  82. }
  83. cryptoLogger.Trace("User ID %v does not match any regex, rejecting SAS request", otherDevice.UserID)
  84. return crypto.RejectRequest, botClient
  85. }
  86. if err = olmMachine.Load(); err != nil {
  87. return
  88. }
  89. botClient.olmMachine = olmMachine
  90. return nil
  91. }
  92. // Register registers a BotClient's Sync and StateMember event callbacks to update its internal state
  93. // when new events arrive.
  94. func (botClient *BotClient) Register(syncer mautrix.ExtensibleSyncer) {
  95. syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, evt *mevt.Event) {
  96. botClient.olmMachine.HandleMemberEvent(evt)
  97. })
  98. syncer.OnSync(botClient.syncCallback)
  99. }
  100. func (botClient *BotClient) syncCallback(resp *mautrix.RespSync, since string) bool {
  101. botClient.stateStore.UpdateStateStore(resp)
  102. botClient.olmMachine.ProcessSyncResponse(resp, since)
  103. if err := botClient.olmMachine.CryptoStore.Flush(); err != nil {
  104. log.WithError(err).Error("Could not flush crypto store")
  105. }
  106. return true
  107. }
  108. // DecryptMegolmEvent attempts to decrypt an incoming m.room.encrypted message using the session information
  109. // already present in the OlmMachine. The corresponding decrypted event is then returned.
  110. // If it fails, usually because the session is not known, an error is returned.
  111. func (botClient *BotClient) DecryptMegolmEvent(evt *mevt.Event) (*mevt.Event, error) {
  112. return botClient.olmMachine.DecryptMegolmEvent(evt)
  113. }
  114. // SendMessageEvent sends the given content to the given room ID using this BotClient as a message event.
  115. // If the target room has enabled encryption, a megolm session is created if one doesn't already exist
  116. // and the message is sent after being encrypted.
  117. func (botClient *BotClient) SendMessageEvent(roomID id.RoomID, evtType mevt.Type, content interface{},
  118. extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
  119. olmMachine := botClient.olmMachine
  120. if olmMachine.StateStore.IsEncrypted(roomID) {
  121. // Check if there is already a megolm session
  122. if sess, err := olmMachine.CryptoStore.GetOutboundGroupSession(roomID); err != nil {
  123. return nil, err
  124. } else if sess == nil || sess.Expired() || !sess.Shared {
  125. // No error but valid, shared session does not exist
  126. memberIDs, err := botClient.stateStore.GetJoinedMembers(roomID)
  127. if err != nil {
  128. return nil, err
  129. }
  130. // Share group session with room members
  131. if err = olmMachine.ShareGroupSession(roomID, memberIDs); err != nil {
  132. return nil, err
  133. }
  134. }
  135. enc, err := olmMachine.EncryptMegolmEvent(roomID, mevt.EventMessage, content)
  136. if err != nil {
  137. return nil, err
  138. }
  139. content = enc
  140. evtType = mevt.EventEncrypted
  141. }
  142. return botClient.Client.SendMessageEvent(roomID, evtType, content, extra...)
  143. }
  144. // Sync loops to keep syncing the client with the homeserver by calling the /sync endpoint.
  145. func (botClient *BotClient) Sync() {
  146. // Get the state store up to date
  147. resp, err := botClient.SyncRequest(30000, "", "", true, mevt.PresenceOnline)
  148. if err != nil {
  149. log.WithError(err).Error("Error performing initial sync")
  150. return
  151. }
  152. botClient.stateStore.UpdateStateStore(resp)
  153. for {
  154. if e := botClient.Client.Sync(); e != nil {
  155. log.WithFields(log.Fields{
  156. log.ErrorKey: e,
  157. "user_id": botClient.config.UserID,
  158. }).Error("Fatal Sync() error")
  159. time.Sleep(10 * time.Second)
  160. } else {
  161. log.WithField("user_id", botClient.config.UserID).Info("Stopping Sync()")
  162. return
  163. }
  164. }
  165. }
  166. // VerifySASMatch returns whether the received SAS matches the SAS that the bot generated.
  167. // It retrieves the SAS of the other device from the bot client's SAS sync map, where it was stored by the `SubmitDecimalSAS` function.
  168. func (botClient *BotClient) VerifySASMatch(otherDevice *crypto.DeviceIdentity, sas crypto.SASData) bool {
  169. log.WithFields(log.Fields{
  170. "otherUser": otherDevice.UserID,
  171. "otherDevice": otherDevice.DeviceID,
  172. }).Infof("Waiting for SAS")
  173. if sas.Type() != event.SASDecimal {
  174. log.Warnf("Unsupported SAS type: %v", sas.Type())
  175. return false
  176. }
  177. key := otherDevice.UserID.String() + ":" + otherDevice.DeviceID.String()
  178. sasChan, loaded := botClient.verificationSAS.LoadOrStore(key, make(chan crypto.DecimalSASData))
  179. if !loaded {
  180. // if we created the chan, delete it after the timeout duration
  181. defer botClient.verificationSAS.Delete(key)
  182. }
  183. select {
  184. case otherSAS := <-sasChan.(chan crypto.DecimalSASData):
  185. ourSAS := sas.(crypto.DecimalSASData)
  186. log.WithFields(log.Fields{
  187. "otherUser": otherDevice.UserID,
  188. "otherDevice": otherDevice.DeviceID,
  189. }).Warnf("Our SAS: %v, Received SAS: %v, Match: %v", ourSAS, otherSAS, ourSAS == otherSAS)
  190. return ourSAS == otherSAS
  191. case <-time.After(botClient.olmMachine.DefaultSASTimeout):
  192. log.Warnf("Timed out while waiting for SAS from device %v", otherDevice.DeviceID)
  193. }
  194. return false
  195. }
  196. // SubmitDecimalSAS stores the received decimal SAS from another device to compare to the local one.
  197. // It stores the SAS in the bot client's SAS sync map to be retrieved from the `VerifySASMatch` function.
  198. func (botClient *BotClient) SubmitDecimalSAS(otherUser id.UserID, otherDevice id.DeviceID, sas crypto.DecimalSASData) {
  199. key := otherUser.String() + ":" + otherDevice.String()
  200. sasChan, loaded := botClient.verificationSAS.LoadOrStore(key, make(chan crypto.DecimalSASData))
  201. go func() {
  202. if !loaded {
  203. // if we created the chan, delete it after the timeout duration
  204. defer botClient.verificationSAS.Delete(key)
  205. }
  206. // insert to channel in goroutine to avoid blocking if we are not expecting a SAS for this user/device right now
  207. select {
  208. case sasChan.(chan crypto.DecimalSASData) <- crypto.DecimalSASData(sas):
  209. case <-time.After(botClient.olmMachine.DefaultSASTimeout):
  210. log.Warnf("Timed out while trying to send SAS for device %v", otherDevice)
  211. }
  212. }()
  213. }
  214. // VerificationMethods returns the supported SAS verification methods.
  215. // As a bot we only support decimal as it's easier to understand.
  216. func (botClient *BotClient) VerificationMethods() []crypto.VerificationMethod {
  217. return []crypto.VerificationMethod{
  218. crypto.VerificationMethodDecimal{},
  219. }
  220. }
  221. // OnCancel is called when a SAS verification is canceled.
  222. func (botClient *BotClient) OnCancel(cancelledByUs bool, reason string, reasonCode event.VerificationCancelCode) {
  223. atomic.AddInt32(&botClient.ongoingVerificationCount, -1)
  224. log.Tracef("Verification cancelled with reason: %v", reason)
  225. }
  226. // OnSuccess is called when a SAS verification is successful.
  227. func (botClient *BotClient) OnSuccess() {
  228. atomic.AddInt32(&botClient.ongoingVerificationCount, -1)
  229. log.Trace("Verification was successful")
  230. }
  231. // InvalidateRoomSession invalidates the outbound group session for the given room.
  232. func (botClient *BotClient) InvalidateRoomSession(roomID id.RoomID) (id.SessionID, error) {
  233. outbound, err := botClient.olmMachine.CryptoStore.GetOutboundGroupSession(roomID)
  234. if err != nil {
  235. return "", err
  236. }
  237. if outbound == nil {
  238. return "", errors.New("No group session found for this room")
  239. }
  240. return outbound.ID(), botClient.olmMachine.CryptoStore.RemoveOutboundGroupSession(roomID)
  241. }
  242. // StartSASVerification starts a new SAS verification with the given user and device ID and returns the transaction ID if successful.
  243. func (botClient *BotClient) StartSASVerification(userID id.UserID, deviceID id.DeviceID) (string, error) {
  244. device, err := botClient.olmMachine.GetOrFetchDevice(userID, deviceID)
  245. if err != nil {
  246. return "", err
  247. }
  248. return botClient.olmMachine.NewSimpleSASVerificationWith(device, botClient)
  249. }
  250. // SendRoomKeyRequest sends a room key request to another device.
  251. func (botClient *BotClient) SendRoomKeyRequest(userID id.UserID, deviceID id.DeviceID, roomID id.RoomID,
  252. senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) (chan bool, error) {
  253. ctx, _ := context.WithTimeout(context.Background(), timeout)
  254. return botClient.olmMachine.RequestRoomKey(ctx, userID, deviceID, roomID, senderKey, sessionID)
  255. }
  256. // ForwardRoomKeyToDevice sends a room key to another device.
  257. func (botClient *BotClient) ForwardRoomKeyToDevice(userID id.UserID, deviceID id.DeviceID, roomID id.RoomID, senderKey id.SenderKey,
  258. sessionID id.SessionID) error {
  259. device, err := botClient.olmMachine.GetOrFetchDevice(userID, deviceID)
  260. if err != nil {
  261. return err
  262. }
  263. igs, err := botClient.olmMachine.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
  264. if err != nil {
  265. return err
  266. } else if igs == nil {
  267. return errors.New("Group session not found")
  268. }
  269. exportedKey, err := igs.Internal.Export(igs.Internal.FirstKnownIndex())
  270. if err != nil {
  271. return err
  272. }
  273. forwardedRoomKey := event.Content{
  274. Parsed: &event.ForwardedRoomKeyEventContent{
  275. RoomKeyEventContent: event.RoomKeyEventContent{
  276. Algorithm: id.AlgorithmMegolmV1,
  277. RoomID: igs.RoomID,
  278. SessionID: igs.ID(),
  279. SessionKey: exportedKey,
  280. },
  281. SenderKey: senderKey,
  282. ForwardingKeyChain: igs.ForwardingChains,
  283. SenderClaimedKey: igs.SigningKey,
  284. },
  285. }
  286. return botClient.olmMachine.SendEncryptedToDevice(device, forwardedRoomKey)
  287. }