You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

435 lines
12 KiB

8 years ago
8 years ago
  1. package clients
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "net/http"
  6. "strings"
  7. "sync"
  8. "github.com/matrix-org/go-neb/api"
  9. "github.com/matrix-org/go-neb/database"
  10. "github.com/matrix-org/go-neb/matrix"
  11. "github.com/matrix-org/go-neb/metrics"
  12. "github.com/matrix-org/go-neb/types"
  13. shellwords "github.com/mattn/go-shellwords"
  14. log "github.com/sirupsen/logrus"
  15. "maunium.net/go/mautrix"
  16. mevt "maunium.net/go/mautrix/event"
  17. "maunium.net/go/mautrix/id"
  18. )
  19. // A Clients is a collection of clients used for bot services.
  20. type Clients struct {
  21. db database.Storer
  22. httpClient *http.Client
  23. dbMutex sync.Mutex
  24. mapMutex sync.Mutex
  25. clients map[id.UserID]BotClient
  26. }
  27. // New makes a new collection of matrix clients
  28. func New(db database.Storer, cli *http.Client) *Clients {
  29. clients := &Clients{
  30. db: db,
  31. httpClient: cli,
  32. clients: make(map[id.UserID]BotClient), // user_id => BotClient
  33. }
  34. return clients
  35. }
  36. // Client gets a client for the userID
  37. func (c *Clients) Client(userID id.UserID) (*BotClient, error) {
  38. entry := c.getClient(userID)
  39. if entry.Client != nil {
  40. return &entry, nil
  41. }
  42. entry, err := c.loadClientFromDB(userID)
  43. return &entry, err
  44. }
  45. // Update updates the config for a matrix client
  46. func (c *Clients) Update(config api.ClientConfig) (api.ClientConfig, error) {
  47. _, old, err := c.updateClientInDB(config)
  48. return old.config, err
  49. }
  50. // Start listening on client /sync streams
  51. func (c *Clients) Start() error {
  52. configs, err := c.db.LoadMatrixClientConfigs()
  53. if err != nil {
  54. return err
  55. }
  56. for _, cfg := range configs {
  57. if cfg.Sync {
  58. if _, err := c.Client(cfg.UserID); err != nil {
  59. return err
  60. }
  61. }
  62. }
  63. return nil
  64. }
  65. func (c *Clients) getClient(userID id.UserID) BotClient {
  66. c.mapMutex.Lock()
  67. defer c.mapMutex.Unlock()
  68. return c.clients[userID]
  69. }
  70. func (c *Clients) setClient(client BotClient) {
  71. c.mapMutex.Lock()
  72. defer c.mapMutex.Unlock()
  73. c.clients[client.config.UserID] = client
  74. }
  75. func (c *Clients) loadClientFromDB(userID id.UserID) (entry BotClient, err error) {
  76. c.dbMutex.Lock()
  77. defer c.dbMutex.Unlock()
  78. entry = c.getClient(userID)
  79. if entry.Client != nil {
  80. return
  81. }
  82. if entry.config, err = c.db.LoadMatrixClientConfig(userID); err != nil {
  83. if err == sql.ErrNoRows {
  84. err = fmt.Errorf("client with user ID %s does not exist", userID)
  85. }
  86. return
  87. }
  88. if err = c.initClient(&entry); err != nil {
  89. return
  90. }
  91. c.setClient(entry)
  92. return
  93. }
  94. func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new, old BotClient, err error) {
  95. c.dbMutex.Lock()
  96. defer c.dbMutex.Unlock()
  97. old = c.getClient(newConfig.UserID)
  98. if old.Client != nil && old.config == newConfig {
  99. // Already have a client with that config.
  100. new = old
  101. return
  102. }
  103. new.config = newConfig
  104. if err = c.initClient(&new); err != nil {
  105. return
  106. }
  107. // set the new display name if they differ
  108. if old.config.DisplayName != new.config.DisplayName {
  109. if err := new.SetDisplayName(new.config.DisplayName); err != nil {
  110. // whine about it but don't stop: this isn't fatal.
  111. log.WithFields(log.Fields{
  112. log.ErrorKey: err,
  113. "displayname": new.config.DisplayName,
  114. "user_id": new.config.UserID,
  115. }).Error("Failed to set display name")
  116. }
  117. }
  118. if old.config, err = c.db.StoreMatrixClientConfig(new.config); err != nil {
  119. new.StopSync()
  120. return
  121. }
  122. if old.Client != nil {
  123. old.Client.StopSync()
  124. return
  125. }
  126. c.setClient(new)
  127. return
  128. }
  129. func (c *Clients) onMessageEvent(botClient *BotClient, event *mevt.Event) {
  130. services, err := c.db.LoadServicesForUser(botClient.UserID)
  131. if err != nil {
  132. log.WithFields(log.Fields{
  133. log.ErrorKey: err,
  134. "room_id": event.RoomID,
  135. "service_user_id": botClient.UserID,
  136. }).Warn("Error loading services")
  137. }
  138. if err := event.Content.ParseRaw(mevt.EventMessage); err != nil {
  139. return
  140. }
  141. message := event.Content.AsMessage()
  142. body := message.Body
  143. if body == "" {
  144. return
  145. }
  146. // filter m.notice to prevent loops
  147. if message.MsgType == mevt.MsgNotice {
  148. return
  149. }
  150. // replace all smart quotes with their normal counterparts so shellwords can parse it
  151. body = strings.Replace(body, ``, `'`, -1)
  152. body = strings.Replace(body, ``, `'`, -1)
  153. body = strings.Replace(body, ``, `"`, -1)
  154. body = strings.Replace(body, ``, `"`, -1)
  155. var responses []interface{}
  156. for _, service := range services {
  157. if body[0] == '!' { // message is a command
  158. args, err := shellwords.Parse(body[1:])
  159. if err != nil {
  160. args = strings.Split(body[1:], " ")
  161. }
  162. if response := runCommandForService(service.Commands(botClient), event, args); response != nil {
  163. responses = append(responses, response)
  164. }
  165. } else { // message isn't a command, it might need expanding
  166. expansions := runExpansionsForService(service.Expansions(botClient), event, body)
  167. responses = append(responses, expansions...)
  168. }
  169. }
  170. for _, content := range responses {
  171. if _, err := botClient.SendMessageEvent(event.RoomID, mevt.EventMessage, content); err != nil {
  172. log.WithFields(log.Fields{
  173. "room_id": event.RoomID,
  174. "content": content,
  175. "sender": event.Sender,
  176. }).WithError(err).Error("Failed to send command response")
  177. }
  178. }
  179. }
  180. // runCommandForService runs a single command read from a matrix event. Runs
  181. // the matching command with the longest path. Returns the JSON encodable
  182. // content of a single matrix message event to use as a response or nil if no
  183. // response is appropriate.
  184. func runCommandForService(cmds []types.Command, event *mevt.Event, arguments []string) interface{} {
  185. var bestMatch *types.Command
  186. for i, command := range cmds {
  187. matches := command.Matches(arguments)
  188. betterMatch := bestMatch == nil || len(bestMatch.Path) < len(command.Path)
  189. if matches && betterMatch {
  190. bestMatch = &cmds[i]
  191. }
  192. }
  193. if bestMatch == nil {
  194. return nil
  195. }
  196. cmdArgs := arguments[len(bestMatch.Path):]
  197. log.WithFields(log.Fields{
  198. "room_id": event.RoomID,
  199. "user_id": event.Sender,
  200. "command": bestMatch.Path,
  201. }).Info("Executing command")
  202. content, err := bestMatch.Command(event.RoomID, event.Sender, cmdArgs)
  203. if err != nil {
  204. if content != nil {
  205. log.WithFields(log.Fields{
  206. log.ErrorKey: err,
  207. "room_id": event.RoomID,
  208. "user_id": event.Sender,
  209. "command": bestMatch.Path,
  210. "args": cmdArgs,
  211. }).Warn("Command returned both error and content.")
  212. }
  213. metrics.IncrementCommand(bestMatch.Path[0], metrics.StatusFailure)
  214. content = mevt.MessageEventContent{
  215. MsgType: mevt.MsgNotice,
  216. Body: err.Error(),
  217. }
  218. } else {
  219. metrics.IncrementCommand(bestMatch.Path[0], metrics.StatusSuccess)
  220. }
  221. return content
  222. }
  223. // run the expansions for a matrix event.
  224. func runExpansionsForService(expans []types.Expansion, event *mevt.Event, body string) []interface{} {
  225. var responses []interface{}
  226. for _, expansion := range expans {
  227. matches := map[string]bool{}
  228. for _, matchingGroups := range expansion.Regexp.FindAllStringSubmatch(body, -1) {
  229. matchingText := matchingGroups[0] // first element is always the complete match
  230. if matches[matchingText] {
  231. // Only expand the first occurance of a matching string
  232. continue
  233. }
  234. matches[matchingText] = true
  235. if response := expansion.Expand(event.RoomID, event.Sender, matchingGroups); response != nil {
  236. responses = append(responses, response)
  237. }
  238. }
  239. }
  240. return responses
  241. }
  242. func (c *Clients) onBotOptionsEvent(client *mautrix.Client, event *mevt.Event) {
  243. // see if these options are for us. The state key is the user ID with a leading _
  244. // to get around restrictions in the HS about having user IDs as state keys.
  245. if event.StateKey == nil {
  246. return
  247. }
  248. targetUserID := id.UserID(strings.TrimPrefix(*event.StateKey, "_"))
  249. if targetUserID != client.UserID {
  250. return
  251. }
  252. // these options fully clobber what was there previously.
  253. opts := types.BotOptions{
  254. UserID: client.UserID,
  255. RoomID: event.RoomID,
  256. SetByUserID: event.Sender,
  257. Options: event.Content.Raw,
  258. }
  259. if _, err := c.db.StoreBotOptions(opts); err != nil {
  260. log.WithFields(log.Fields{
  261. log.ErrorKey: err,
  262. "room_id": event.RoomID,
  263. "bot_user_id": client.UserID,
  264. "set_by_user_id": event.Sender,
  265. }).Error("Failed to persist bot options")
  266. }
  267. }
  268. func (c *Clients) onRoomMemberEvent(client *mautrix.Client, event *mevt.Event) {
  269. if err := event.Content.ParseRaw(mevt.StateMember); err != nil {
  270. return
  271. }
  272. if event.StateKey == nil || *event.StateKey != client.UserID.String() {
  273. return // not our member event
  274. }
  275. membership := event.Content.AsMember().Membership
  276. if membership == "invite" {
  277. logger := log.WithFields(log.Fields{
  278. "room_id": event.RoomID,
  279. "service_user_id": client.UserID,
  280. "inviter": event.Sender,
  281. })
  282. logger.Print("Accepting invite from user")
  283. content := struct {
  284. Inviter id.UserID `json:"inviter"`
  285. }{event.Sender}
  286. if _, err := client.JoinRoom(event.RoomID.String(), "", content); err != nil {
  287. logger.WithError(err).Print("Failed to join room")
  288. } else {
  289. logger.Print("Joined room")
  290. }
  291. }
  292. }
  293. func (c *Clients) initClient(botClient *BotClient) error {
  294. config := botClient.config
  295. client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken)
  296. if err != nil {
  297. return err
  298. }
  299. client.Client = c.httpClient
  300. client.DeviceID = config.DeviceID
  301. if client.DeviceID == "" {
  302. log.Warn("Device ID is not set which will result in E2E encryption/decryption not working")
  303. }
  304. botClient.Client = client
  305. syncer := client.Syncer.(*mautrix.DefaultSyncer)
  306. nebStore := &matrix.NEBStore{
  307. InMemoryStore: *mautrix.NewInMemoryStore(),
  308. Database: c.db,
  309. ClientConfig: config,
  310. }
  311. client.Store = nebStore
  312. // TODO: Check that the access token is valid for the userID by peforming
  313. // a request against the server.
  314. if err = botClient.InitOlmMachine(client, nebStore); err != nil {
  315. return err
  316. }
  317. // Register sync callback for maintaining the state store and Olm machine state
  318. botClient.Register(syncer)
  319. syncer.OnEventType(mevt.EventMessage, func(_ mautrix.EventSource, event *mevt.Event) {
  320. c.onMessageEvent(botClient, event)
  321. })
  322. syncer.OnEventType(mevt.Type{Type: "m.room.bot.options", Class: mevt.UnknownEventType}, func(_ mautrix.EventSource, event *mevt.Event) {
  323. c.onBotOptionsEvent(botClient.Client, event)
  324. })
  325. if config.AutoJoinRooms {
  326. syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, event *mevt.Event) {
  327. c.onRoomMemberEvent(client, event)
  328. })
  329. }
  330. // When receiving an encrypted event, attempt to decrypt it using the BotClient's capabilities.
  331. // If successfully decrypted propagate the decrypted event to the clients.
  332. syncer.OnEventType(mevt.EventEncrypted, func(source mautrix.EventSource, evt *mevt.Event) {
  333. if err := evt.Content.ParseRaw(mevt.EventEncrypted); err != nil {
  334. log.WithError(err).Error("Failed to parse encrypted message")
  335. return
  336. }
  337. encContent := evt.Content.AsEncrypted()
  338. decrypted, err := botClient.DecryptMegolmEvent(evt)
  339. if err != nil {
  340. log.WithFields(log.Fields{
  341. "user_id": config.UserID,
  342. "device_id": encContent.DeviceID,
  343. "session_id": encContent.SessionID,
  344. "sender_key": encContent.SenderKey,
  345. }).WithError(err).Error("Failed to decrypt message")
  346. } else {
  347. if decrypted.Type == mevt.EventMessage {
  348. err = decrypted.Content.ParseRaw(mevt.EventMessage)
  349. if err != nil {
  350. log.WithError(err).Error("Could not parse decrypted message event")
  351. } else {
  352. c.onMessageEvent(botClient, decrypted)
  353. }
  354. }
  355. log.WithFields(log.Fields{
  356. "type": evt.Type,
  357. "sender": evt.Sender,
  358. "room_id": evt.RoomID,
  359. "state_key": evt.StateKey,
  360. }).Trace("Decrypted event successfully")
  361. }
  362. })
  363. // Ignore events before neb's join event.
  364. eventIgnorer := mautrix.OldEventIgnorer{UserID: config.UserID}
  365. eventIgnorer.Register(syncer)
  366. log.WithFields(log.Fields{
  367. "user_id": config.UserID,
  368. "device_id": config.DeviceID,
  369. "sync": config.Sync,
  370. "auto_join_rooms": config.AutoJoinRooms,
  371. "since": nebStore.LoadNextBatch(config.UserID),
  372. }).Info("Created new client")
  373. if config.Sync {
  374. go botClient.Sync()
  375. }
  376. return nil
  377. }