You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

437 lines
12 KiB

3 years ago
3 years ago
3 years ago
8 years ago
3 years ago
8 years ago
  1. package clients
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "reflect"
  8. "strings"
  9. "sync"
  10. "github.com/matrix-org/go-neb/api"
  11. "github.com/matrix-org/go-neb/database"
  12. "github.com/matrix-org/go-neb/matrix"
  13. "github.com/matrix-org/go-neb/metrics"
  14. "github.com/matrix-org/go-neb/types"
  15. shellwords "github.com/mattn/go-shellwords"
  16. log "github.com/sirupsen/logrus"
  17. "maunium.net/go/mautrix"
  18. mevt "maunium.net/go/mautrix/event"
  19. "maunium.net/go/mautrix/id"
  20. )
  21. // A Clients is a collection of clients used for bot services.
  22. type Clients struct {
  23. db database.Storer
  24. httpClient *http.Client
  25. dbMutex sync.Mutex
  26. mapMutex sync.Mutex
  27. clients map[id.UserID]BotClient
  28. }
  29. // New makes a new collection of matrix clients
  30. func New(db database.Storer, cli *http.Client) *Clients {
  31. clients := &Clients{
  32. db: db,
  33. httpClient: cli,
  34. clients: make(map[id.UserID]BotClient), // user_id => BotClient
  35. }
  36. return clients
  37. }
  38. // Client gets a client for the userID
  39. func (c *Clients) Client(userID id.UserID) (*BotClient, error) {
  40. entry := c.getClient(userID)
  41. if entry.Client != nil {
  42. return &entry, nil
  43. }
  44. entry, err := c.loadClientFromDB(userID)
  45. return &entry, err
  46. }
  47. // Update updates the config for a matrix client
  48. func (c *Clients) Update(config api.ClientConfig) (api.ClientConfig, error) {
  49. _, old, err := c.updateClientInDB(config)
  50. return old.config, err
  51. }
  52. // Start listening on client /sync streams
  53. func (c *Clients) Start() error {
  54. configs, err := c.db.LoadMatrixClientConfigs()
  55. if err != nil {
  56. return err
  57. }
  58. for _, cfg := range configs {
  59. if cfg.Sync {
  60. if _, err := c.Client(cfg.UserID); err != nil {
  61. return err
  62. }
  63. }
  64. }
  65. return nil
  66. }
  67. func (c *Clients) getClient(userID id.UserID) BotClient {
  68. c.mapMutex.Lock()
  69. defer c.mapMutex.Unlock()
  70. return c.clients[userID]
  71. }
  72. func (c *Clients) setClient(client BotClient) {
  73. c.mapMutex.Lock()
  74. defer c.mapMutex.Unlock()
  75. c.clients[client.config.UserID] = client
  76. }
  77. func (c *Clients) loadClientFromDB(userID id.UserID) (entry BotClient, err error) {
  78. c.dbMutex.Lock()
  79. defer c.dbMutex.Unlock()
  80. entry = c.getClient(userID)
  81. if entry.Client != nil {
  82. return
  83. }
  84. if entry.config, err = c.db.LoadMatrixClientConfig(userID); err != nil {
  85. if err == sql.ErrNoRows {
  86. err = fmt.Errorf("client with user ID %s does not exist", userID)
  87. }
  88. return
  89. }
  90. if err = c.initClient(&entry); err != nil {
  91. return
  92. }
  93. c.setClient(entry)
  94. return
  95. }
  96. func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new, old BotClient, err error) {
  97. c.dbMutex.Lock()
  98. defer c.dbMutex.Unlock()
  99. old = c.getClient(newConfig.UserID)
  100. if old.Client != nil && reflect.DeepEqual(old.config, newConfig) {
  101. // Already have a client with that config.
  102. new = old
  103. return
  104. }
  105. new.config = newConfig
  106. if err = c.initClient(&new); err != nil {
  107. return
  108. }
  109. // set the new display name if they differ
  110. if old.config.DisplayName != new.config.DisplayName {
  111. if err := new.SetDisplayName(new.config.DisplayName); err != nil {
  112. // whine about it but don't stop: this isn't fatal.
  113. log.WithFields(log.Fields{
  114. log.ErrorKey: err,
  115. "displayname": new.config.DisplayName,
  116. "user_id": new.config.UserID,
  117. }).Error("Failed to set display name")
  118. }
  119. }
  120. if old.config, err = c.db.StoreMatrixClientConfig(new.config); err != nil {
  121. new.StopSync()
  122. return
  123. }
  124. if old.Client != nil {
  125. old.Client.StopSync()
  126. return
  127. }
  128. c.setClient(new)
  129. return
  130. }
  131. func (c *Clients) onMessageEvent(botClient *BotClient, event *mevt.Event) {
  132. services, err := c.db.LoadServicesForUser(botClient.UserID)
  133. if err != nil {
  134. log.WithFields(log.Fields{
  135. log.ErrorKey: err,
  136. "room_id": event.RoomID,
  137. "service_user_id": botClient.UserID,
  138. }).Warn("Error loading services")
  139. }
  140. message := event.Content.AsMessage()
  141. body := message.Body
  142. if body == "" {
  143. return
  144. }
  145. // filter m.notice to prevent loops
  146. if message.MsgType == mevt.MsgNotice {
  147. return
  148. }
  149. // replace all smart quotes with their normal counterparts so shellwords can parse it
  150. body = strings.Replace(body, ``, `'`, -1)
  151. body = strings.Replace(body, ``, `'`, -1)
  152. body = strings.Replace(body, ``, `"`, -1)
  153. body = strings.Replace(body, ``, `"`, -1)
  154. var responses []interface{}
  155. for _, service := range services {
  156. if body[0] == '!' { // message is a command
  157. args, err := shellwords.Parse(body[1:])
  158. if err != nil {
  159. args = strings.Split(body[1:], " ")
  160. }
  161. if response := runCommandForService(service.Commands(botClient), event, args); response != nil {
  162. responses = append(responses, response)
  163. }
  164. } else { // message isn't a command, it might need expanding
  165. expansions := runExpansionsForService(service.Expansions(botClient), event, body)
  166. responses = append(responses, expansions...)
  167. }
  168. }
  169. for _, content := range responses {
  170. if _, err := botClient.SendMessageEvent(event.RoomID, mevt.EventMessage, content); err != nil {
  171. log.WithFields(log.Fields{
  172. "room_id": event.RoomID,
  173. "content": content,
  174. "sender": event.Sender,
  175. }).WithError(err).Error("Failed to send command response")
  176. }
  177. }
  178. }
  179. // runCommandForService runs a single command read from a matrix event. Runs
  180. // the matching command with the longest path. Returns the JSON encodable
  181. // content of a single matrix message event to use as a response or nil if no
  182. // response is appropriate.
  183. func runCommandForService(cmds []types.Command, event *mevt.Event, arguments []string) interface{} {
  184. var bestMatch *types.Command
  185. for i, command := range cmds {
  186. matches := command.Matches(arguments)
  187. betterMatch := bestMatch == nil || len(bestMatch.Path) < len(command.Path)
  188. if matches && betterMatch {
  189. bestMatch = &cmds[i]
  190. }
  191. }
  192. if bestMatch == nil {
  193. return nil
  194. }
  195. cmdArgs := arguments[len(bestMatch.Path):]
  196. log.WithFields(log.Fields{
  197. "room_id": event.RoomID,
  198. "user_id": event.Sender,
  199. "command": bestMatch.Path,
  200. }).Info("Executing command")
  201. content, err := bestMatch.Command(event.RoomID, event.Sender, cmdArgs)
  202. if err != nil {
  203. if content != nil {
  204. log.WithFields(log.Fields{
  205. log.ErrorKey: err,
  206. "room_id": event.RoomID,
  207. "user_id": event.Sender,
  208. "command": bestMatch.Path,
  209. "args": cmdArgs,
  210. }).Warn("Command returned both error and content.")
  211. }
  212. metrics.IncrementCommand(bestMatch.Path[0], metrics.StatusFailure)
  213. content = mevt.MessageEventContent{
  214. MsgType: mevt.MsgNotice,
  215. Body: err.Error(),
  216. }
  217. } else {
  218. metrics.IncrementCommand(bestMatch.Path[0], metrics.StatusSuccess)
  219. }
  220. return content
  221. }
  222. // run the expansions for a matrix event.
  223. func runExpansionsForService(expans []types.Expansion, event *mevt.Event, body string) []interface{} {
  224. var responses []interface{}
  225. for _, expansion := range expans {
  226. matches := map[string]bool{}
  227. for _, matchingGroups := range expansion.Regexp.FindAllStringSubmatch(body, -1) {
  228. matchingText := matchingGroups[0] // first element is always the complete match
  229. if matches[matchingText] {
  230. // Only expand the first occurance of a matching string
  231. continue
  232. }
  233. matches[matchingText] = true
  234. if response := expansion.Expand(event.RoomID, event.Sender, matchingGroups); response != nil {
  235. responses = append(responses, response)
  236. }
  237. }
  238. }
  239. return responses
  240. }
  241. func (c *Clients) onBotOptionsEvent(client *mautrix.Client, event *mevt.Event) {
  242. // see if these options are for us. The state key is the user ID with a leading _
  243. // to get around restrictions in the HS about having user IDs as state keys.
  244. if event.StateKey == nil {
  245. return
  246. }
  247. targetUserID := id.UserID(strings.TrimPrefix(*event.StateKey, "_"))
  248. if targetUserID != client.UserID {
  249. return
  250. }
  251. // these options fully clobber what was there previously.
  252. var options types.BotOptionsContent
  253. if err := json.Unmarshal(event.Content.VeryRaw, &options); err != nil {
  254. log.WithFields(log.Fields{
  255. log.ErrorKey: err,
  256. "room_id": event.RoomID,
  257. "bot_user_id": client.UserID,
  258. "set_by_user_id": event.Sender,
  259. }).Error("Failed to parse bot options")
  260. }
  261. opts := types.BotOptions{
  262. UserID: client.UserID,
  263. RoomID: event.RoomID,
  264. SetByUserID: event.Sender,
  265. Options: options,
  266. }
  267. if _, err := c.db.StoreBotOptions(opts); err != nil {
  268. log.WithFields(log.Fields{
  269. log.ErrorKey: err,
  270. "room_id": event.RoomID,
  271. "bot_user_id": client.UserID,
  272. "set_by_user_id": event.Sender,
  273. }).Error("Failed to persist bot options")
  274. }
  275. }
  276. func (c *Clients) onRoomMemberEvent(client *mautrix.Client, event *mevt.Event) {
  277. if event.StateKey == nil || *event.StateKey != client.UserID.String() {
  278. return // not our member event
  279. }
  280. membership := event.Content.AsMember().Membership
  281. if membership == "invite" {
  282. logger := log.WithFields(log.Fields{
  283. "room_id": event.RoomID,
  284. "service_user_id": client.UserID,
  285. "inviter": event.Sender,
  286. })
  287. logger.Print("Accepting invite from user")
  288. content := struct {
  289. Inviter id.UserID `json:"inviter"`
  290. }{event.Sender}
  291. if _, err := client.JoinRoom(event.RoomID.String(), "", content); err != nil {
  292. logger.WithError(err).Print("Failed to join room")
  293. } else {
  294. logger.Print("Joined room")
  295. }
  296. }
  297. }
  298. func (c *Clients) initClient(botClient *BotClient) error {
  299. config := botClient.config
  300. client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken)
  301. if err != nil {
  302. return err
  303. }
  304. client.Client = c.httpClient
  305. client.DeviceID = config.DeviceID
  306. if client.DeviceID == "" {
  307. log.Warn("Device ID is not set which will result in E2E encryption/decryption not working")
  308. }
  309. botClient.Client = client
  310. botClient.verificationSAS = &sync.Map{}
  311. syncer := client.Syncer.(*mautrix.DefaultSyncer)
  312. // Add m.room.bot.options to mautrix's TypeMap so that it parses it as a valid event
  313. var StateBotOptions = mevt.Type{Type: "m.room.bot.options", Class: mevt.StateEventType}
  314. mevt.TypeMap[StateBotOptions] = reflect.TypeOf(&types.BotOptionsContent{})
  315. nebStore := &matrix.NEBStore{
  316. InMemoryStore: *mautrix.NewInMemoryStore(),
  317. Database: c.db,
  318. ClientConfig: config,
  319. }
  320. client.Store = nebStore
  321. // TODO: Check that the access token is valid for the userID by peforming
  322. // a request against the server.
  323. if err = botClient.InitOlmMachine(client, nebStore); err != nil {
  324. return err
  325. }
  326. // Register sync callback for maintaining the state store and Olm machine state
  327. botClient.Register(syncer)
  328. syncer.OnEventType(mevt.EventMessage, func(_ mautrix.EventSource, event *mevt.Event) {
  329. c.onMessageEvent(botClient, event)
  330. })
  331. syncer.OnEventType(mevt.Type{Type: "m.room.bot.options", Class: mevt.StateEventType}, func(_ mautrix.EventSource, event *mevt.Event) {
  332. c.onBotOptionsEvent(botClient.Client, event)
  333. })
  334. if config.AutoJoinRooms {
  335. syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, event *mevt.Event) {
  336. c.onRoomMemberEvent(client, event)
  337. })
  338. }
  339. // When receiving an encrypted event, attempt to decrypt it using the BotClient's capabilities.
  340. // If successfully decrypted propagate the decrypted event to the clients.
  341. syncer.OnEventType(mevt.EventEncrypted, func(source mautrix.EventSource, evt *mevt.Event) {
  342. encContent := evt.Content.AsEncrypted()
  343. decrypted, err := botClient.DecryptMegolmEvent(evt)
  344. if err != nil {
  345. log.WithFields(log.Fields{
  346. "user_id": config.UserID,
  347. "device_id": encContent.DeviceID,
  348. "session_id": encContent.SessionID,
  349. "sender_key": encContent.SenderKey,
  350. }).WithError(err).Error("Failed to decrypt message")
  351. } else {
  352. if decrypted.Type == mevt.EventMessage {
  353. c.onMessageEvent(botClient, decrypted)
  354. }
  355. log.WithFields(log.Fields{
  356. "type": evt.Type,
  357. "sender": evt.Sender,
  358. "room_id": evt.RoomID,
  359. "state_key": evt.StateKey,
  360. }).Trace("Decrypted event successfully")
  361. }
  362. })
  363. // Ignore events before neb's join event.
  364. eventIgnorer := mautrix.OldEventIgnorer{UserID: config.UserID}
  365. eventIgnorer.Register(syncer)
  366. log.WithFields(log.Fields{
  367. "user_id": config.UserID,
  368. "device_id": config.DeviceID,
  369. "sync": config.Sync,
  370. "auto_join_rooms": config.AutoJoinRooms,
  371. "since": nebStore.LoadNextBatch(config.UserID),
  372. }).Info("Created new client")
  373. if config.Sync {
  374. go botClient.Sync()
  375. }
  376. return nil
  377. }