You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

427 lines
12 KiB

8 years ago
3 years ago
3 years ago
8 years ago
  1. package clients
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "net/http"
  6. "reflect"
  7. "strings"
  8. "sync"
  9. "github.com/matrix-org/go-neb/api"
  10. "github.com/matrix-org/go-neb/database"
  11. "github.com/matrix-org/go-neb/matrix"
  12. "github.com/matrix-org/go-neb/metrics"
  13. "github.com/matrix-org/go-neb/types"
  14. shellwords "github.com/mattn/go-shellwords"
  15. log "github.com/sirupsen/logrus"
  16. "maunium.net/go/mautrix"
  17. mevt "maunium.net/go/mautrix/event"
  18. "maunium.net/go/mautrix/id"
  19. )
  20. // A Clients is a collection of clients used for bot services.
  21. type Clients struct {
  22. db database.Storer
  23. httpClient *http.Client
  24. dbMutex sync.Mutex
  25. mapMutex sync.Mutex
  26. clients map[id.UserID]BotClient
  27. }
  28. // New makes a new collection of matrix clients
  29. func New(db database.Storer, cli *http.Client) *Clients {
  30. clients := &Clients{
  31. db: db,
  32. httpClient: cli,
  33. clients: make(map[id.UserID]BotClient), // user_id => BotClient
  34. }
  35. return clients
  36. }
  37. // Client gets a client for the userID
  38. func (c *Clients) Client(userID id.UserID) (*BotClient, error) {
  39. entry := c.getClient(userID)
  40. if entry.Client != nil {
  41. return &entry, nil
  42. }
  43. entry, err := c.loadClientFromDB(userID)
  44. return &entry, err
  45. }
  46. // Update updates the config for a matrix client
  47. func (c *Clients) Update(config api.ClientConfig) (api.ClientConfig, error) {
  48. _, old, err := c.updateClientInDB(config)
  49. return old.config, err
  50. }
  51. // Start listening on client /sync streams
  52. func (c *Clients) Start() error {
  53. configs, err := c.db.LoadMatrixClientConfigs()
  54. if err != nil {
  55. return err
  56. }
  57. for _, cfg := range configs {
  58. if cfg.Sync {
  59. if _, err := c.Client(cfg.UserID); err != nil {
  60. return err
  61. }
  62. }
  63. }
  64. return nil
  65. }
  66. func (c *Clients) getClient(userID id.UserID) BotClient {
  67. c.mapMutex.Lock()
  68. defer c.mapMutex.Unlock()
  69. return c.clients[userID]
  70. }
  71. func (c *Clients) setClient(client BotClient) {
  72. c.mapMutex.Lock()
  73. defer c.mapMutex.Unlock()
  74. c.clients[client.config.UserID] = client
  75. }
  76. func (c *Clients) loadClientFromDB(userID id.UserID) (entry BotClient, err error) {
  77. c.dbMutex.Lock()
  78. defer c.dbMutex.Unlock()
  79. entry = c.getClient(userID)
  80. if entry.Client != nil {
  81. return
  82. }
  83. if entry.config, err = c.db.LoadMatrixClientConfig(userID); err != nil {
  84. if err == sql.ErrNoRows {
  85. err = fmt.Errorf("client with user ID %s does not exist", userID)
  86. }
  87. return
  88. }
  89. if err = c.initClient(&entry); err != nil {
  90. return
  91. }
  92. c.setClient(entry)
  93. return
  94. }
  95. func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new, old BotClient, err error) {
  96. c.dbMutex.Lock()
  97. defer c.dbMutex.Unlock()
  98. old = c.getClient(newConfig.UserID)
  99. if old.Client != nil && reflect.DeepEqual(old.config, newConfig) {
  100. // Already have a client with that config.
  101. new = old
  102. return
  103. }
  104. new.config = newConfig
  105. if err = c.initClient(&new); err != nil {
  106. return
  107. }
  108. // set the new display name if they differ
  109. if old.config.DisplayName != new.config.DisplayName {
  110. if err := new.SetDisplayName(new.config.DisplayName); err != nil {
  111. // whine about it but don't stop: this isn't fatal.
  112. log.WithFields(log.Fields{
  113. log.ErrorKey: err,
  114. "displayname": new.config.DisplayName,
  115. "user_id": new.config.UserID,
  116. }).Error("Failed to set display name")
  117. }
  118. }
  119. if old.config, err = c.db.StoreMatrixClientConfig(new.config); err != nil {
  120. new.StopSync()
  121. return
  122. }
  123. if old.Client != nil {
  124. old.Client.StopSync()
  125. return
  126. }
  127. c.setClient(new)
  128. return
  129. }
  130. func (c *Clients) onMessageEvent(botClient *BotClient, event *mevt.Event) {
  131. services, err := c.db.LoadServicesForUser(botClient.UserID)
  132. if err != nil {
  133. log.WithFields(log.Fields{
  134. log.ErrorKey: err,
  135. "room_id": event.RoomID,
  136. "service_user_id": botClient.UserID,
  137. }).Warn("Error loading services")
  138. }
  139. message := event.Content.AsMessage()
  140. body := message.Body
  141. if body == "" {
  142. return
  143. }
  144. // filter m.notice to prevent loops
  145. if message.MsgType == mevt.MsgNotice {
  146. return
  147. }
  148. // replace all smart quotes with their normal counterparts so shellwords can parse it
  149. body = strings.Replace(body, ``, `'`, -1)
  150. body = strings.Replace(body, ``, `'`, -1)
  151. body = strings.Replace(body, ``, `"`, -1)
  152. body = strings.Replace(body, ``, `"`, -1)
  153. var responses []interface{}
  154. for _, service := range services {
  155. if body[0] == '!' { // message is a command
  156. args, err := shellwords.Parse(body[1:])
  157. if err != nil {
  158. args = strings.Split(body[1:], " ")
  159. }
  160. if response := runCommandForService(service.Commands(botClient), event, args); response != nil {
  161. responses = append(responses, response)
  162. }
  163. } else { // message isn't a command, it might need expanding
  164. expansions := runExpansionsForService(service.Expansions(botClient), event, body)
  165. responses = append(responses, expansions...)
  166. }
  167. }
  168. for _, content := range responses {
  169. if _, err := botClient.SendMessageEvent(event.RoomID, mevt.EventMessage, content); err != nil {
  170. log.WithFields(log.Fields{
  171. "room_id": event.RoomID,
  172. "content": content,
  173. "sender": event.Sender,
  174. }).WithError(err).Error("Failed to send command response")
  175. }
  176. }
  177. }
  178. // runCommandForService runs a single command read from a matrix event. Runs
  179. // the matching command with the longest path. Returns the JSON encodable
  180. // content of a single matrix message event to use as a response or nil if no
  181. // response is appropriate.
  182. func runCommandForService(cmds []types.Command, event *mevt.Event, arguments []string) interface{} {
  183. var bestMatch *types.Command
  184. for i, command := range cmds {
  185. matches := command.Matches(arguments)
  186. betterMatch := bestMatch == nil || len(bestMatch.Path) < len(command.Path)
  187. if matches && betterMatch {
  188. bestMatch = &cmds[i]
  189. }
  190. }
  191. if bestMatch == nil {
  192. return nil
  193. }
  194. cmdArgs := arguments[len(bestMatch.Path):]
  195. log.WithFields(log.Fields{
  196. "room_id": event.RoomID,
  197. "user_id": event.Sender,
  198. "command": bestMatch.Path,
  199. }).Info("Executing command")
  200. content, err := bestMatch.Command(event.RoomID, event.Sender, cmdArgs)
  201. if err != nil {
  202. if content != nil {
  203. log.WithFields(log.Fields{
  204. log.ErrorKey: err,
  205. "room_id": event.RoomID,
  206. "user_id": event.Sender,
  207. "command": bestMatch.Path,
  208. "args": cmdArgs,
  209. }).Warn("Command returned both error and content.")
  210. }
  211. metrics.IncrementCommand(bestMatch.Path[0], metrics.StatusFailure)
  212. content = mevt.MessageEventContent{
  213. MsgType: mevt.MsgNotice,
  214. Body: err.Error(),
  215. }
  216. } else {
  217. metrics.IncrementCommand(bestMatch.Path[0], metrics.StatusSuccess)
  218. }
  219. return content
  220. }
  221. // run the expansions for a matrix event.
  222. func runExpansionsForService(expans []types.Expansion, event *mevt.Event, body string) []interface{} {
  223. var responses []interface{}
  224. for _, expansion := range expans {
  225. matches := map[string]bool{}
  226. for _, matchingGroups := range expansion.Regexp.FindAllStringSubmatch(body, -1) {
  227. matchingText := matchingGroups[0] // first element is always the complete match
  228. if matches[matchingText] {
  229. // Only expand the first occurance of a matching string
  230. continue
  231. }
  232. matches[matchingText] = true
  233. if response := expansion.Expand(event.RoomID, event.Sender, matchingGroups); response != nil {
  234. responses = append(responses, response)
  235. }
  236. }
  237. }
  238. return responses
  239. }
  240. func (c *Clients) onBotOptionsEvent(client *mautrix.Client, event *mevt.Event) {
  241. // see if these options are for us. The state key is the user ID with a leading _
  242. // to get around restrictions in the HS about having user IDs as state keys.
  243. if event.StateKey == nil {
  244. return
  245. }
  246. targetUserID := id.UserID(strings.TrimPrefix(*event.StateKey, "_"))
  247. if targetUserID != client.UserID {
  248. return
  249. }
  250. // these options fully clobber what was there previously.
  251. opts := types.BotOptions{
  252. UserID: client.UserID,
  253. RoomID: event.RoomID,
  254. SetByUserID: event.Sender,
  255. Options: event.Content.Raw,
  256. }
  257. if _, err := c.db.StoreBotOptions(opts); err != nil {
  258. log.WithFields(log.Fields{
  259. log.ErrorKey: err,
  260. "room_id": event.RoomID,
  261. "bot_user_id": client.UserID,
  262. "set_by_user_id": event.Sender,
  263. }).Error("Failed to persist bot options")
  264. }
  265. }
  266. func (c *Clients) onRoomMemberEvent(client *mautrix.Client, event *mevt.Event) {
  267. if event.StateKey == nil || *event.StateKey != client.UserID.String() {
  268. return // not our member event
  269. }
  270. membership := event.Content.AsMember().Membership
  271. if membership == "invite" {
  272. logger := log.WithFields(log.Fields{
  273. "room_id": event.RoomID,
  274. "service_user_id": client.UserID,
  275. "inviter": event.Sender,
  276. })
  277. logger.Print("Accepting invite from user")
  278. content := struct {
  279. Inviter id.UserID `json:"inviter"`
  280. }{event.Sender}
  281. if _, err := client.JoinRoom(event.RoomID.String(), "", content); err != nil {
  282. logger.WithError(err).Print("Failed to join room")
  283. } else {
  284. logger.Print("Joined room")
  285. }
  286. }
  287. }
  288. func (c *Clients) initClient(botClient *BotClient) error {
  289. config := botClient.config
  290. client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken)
  291. if err != nil {
  292. return err
  293. }
  294. client.Client = c.httpClient
  295. client.DeviceID = config.DeviceID
  296. if client.DeviceID == "" {
  297. log.Warn("Device ID is not set which will result in E2E encryption/decryption not working")
  298. }
  299. botClient.Client = client
  300. botClient.verificationSAS = &sync.Map{}
  301. syncer := client.Syncer.(*mautrix.DefaultSyncer)
  302. syncer.ParseErrorHandler = func(evt *mevt.Event, err error) bool {
  303. // Events of type m.room.bot.options will be flagged as errors as this isn't an event type
  304. // recognised by the Syncer, but we need to process them so the bot can accept options
  305. // through this event (see onBotOptionsEvent)
  306. return evt.Type.Type == "m.room.bot.options"
  307. }
  308. nebStore := &matrix.NEBStore{
  309. InMemoryStore: *mautrix.NewInMemoryStore(),
  310. Database: c.db,
  311. ClientConfig: config,
  312. }
  313. client.Store = nebStore
  314. // TODO: Check that the access token is valid for the userID by peforming
  315. // a request against the server.
  316. if err = botClient.InitOlmMachine(client, nebStore); err != nil {
  317. return err
  318. }
  319. // Register sync callback for maintaining the state store and Olm machine state
  320. botClient.Register(syncer)
  321. syncer.OnEventType(mevt.EventMessage, func(_ mautrix.EventSource, event *mevt.Event) {
  322. c.onMessageEvent(botClient, event)
  323. })
  324. syncer.OnEventType(mevt.Type{Type: "m.room.bot.options", Class: mevt.StateEventType}, func(_ mautrix.EventSource, event *mevt.Event) {
  325. c.onBotOptionsEvent(botClient.Client, event)
  326. })
  327. if config.AutoJoinRooms {
  328. syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, event *mevt.Event) {
  329. c.onRoomMemberEvent(client, event)
  330. })
  331. }
  332. // When receiving an encrypted event, attempt to decrypt it using the BotClient's capabilities.
  333. // If successfully decrypted propagate the decrypted event to the clients.
  334. syncer.OnEventType(mevt.EventEncrypted, func(source mautrix.EventSource, evt *mevt.Event) {
  335. encContent := evt.Content.AsEncrypted()
  336. decrypted, err := botClient.DecryptMegolmEvent(evt)
  337. if err != nil {
  338. log.WithFields(log.Fields{
  339. "user_id": config.UserID,
  340. "device_id": encContent.DeviceID,
  341. "session_id": encContent.SessionID,
  342. "sender_key": encContent.SenderKey,
  343. }).WithError(err).Error("Failed to decrypt message")
  344. } else {
  345. if decrypted.Type == mevt.EventMessage {
  346. c.onMessageEvent(botClient, decrypted)
  347. }
  348. log.WithFields(log.Fields{
  349. "type": evt.Type,
  350. "sender": evt.Sender,
  351. "room_id": evt.RoomID,
  352. "state_key": evt.StateKey,
  353. }).Trace("Decrypted event successfully")
  354. }
  355. })
  356. // Ignore events before neb's join event.
  357. eventIgnorer := mautrix.OldEventIgnorer{UserID: config.UserID}
  358. eventIgnorer.Register(syncer)
  359. log.WithFields(log.Fields{
  360. "user_id": config.UserID,
  361. "device_id": config.DeviceID,
  362. "sync": config.Sync,
  363. "auto_join_rooms": config.AutoJoinRooms,
  364. "since": nebStore.LoadNextBatch(config.UserID),
  365. }).Info("Created new client")
  366. if config.Sync {
  367. go botClient.Sync()
  368. }
  369. return nil
  370. }