You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

431 lines
12 KiB

8 years ago
8 years ago
  1. package clients
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "net/http"
  6. "reflect"
  7. "strings"
  8. "sync"
  9. "github.com/matrix-org/go-neb/api"
  10. "github.com/matrix-org/go-neb/database"
  11. "github.com/matrix-org/go-neb/matrix"
  12. "github.com/matrix-org/go-neb/metrics"
  13. "github.com/matrix-org/go-neb/types"
  14. shellwords "github.com/mattn/go-shellwords"
  15. log "github.com/sirupsen/logrus"
  16. "maunium.net/go/mautrix"
  17. mevt "maunium.net/go/mautrix/event"
  18. "maunium.net/go/mautrix/id"
  19. )
  20. // A Clients is a collection of clients used for bot services.
  21. type Clients struct {
  22. db database.Storer
  23. httpClient *http.Client
  24. dbMutex sync.Mutex
  25. mapMutex sync.Mutex
  26. clients map[id.UserID]BotClient
  27. }
  28. // New makes a new collection of matrix clients
  29. func New(db database.Storer, cli *http.Client) *Clients {
  30. clients := &Clients{
  31. db: db,
  32. httpClient: cli,
  33. clients: make(map[id.UserID]BotClient), // user_id => BotClient
  34. }
  35. return clients
  36. }
  37. // Client gets a client for the userID
  38. func (c *Clients) Client(userID id.UserID) (*BotClient, error) {
  39. entry := c.getClient(userID)
  40. if entry.Client != nil {
  41. return &entry, nil
  42. }
  43. entry, err := c.loadClientFromDB(userID)
  44. return &entry, err
  45. }
  46. // Update updates the config for a matrix client
  47. func (c *Clients) Update(config api.ClientConfig) (api.ClientConfig, error) {
  48. _, old, err := c.updateClientInDB(config)
  49. return old.config, err
  50. }
  51. type MautrixLogger struct{}
  52. func (d *MautrixLogger) Debugfln(message string, args ...interface{}) {
  53. log.Debugf(message, args...)
  54. }
  55. func (d *MautrixLogger) Warnfln(message string, args ...interface{}) {
  56. log.Warnf(message, args...)
  57. }
  58. // Start listening on client /sync streams
  59. func (c *Clients) Start() error {
  60. configs, err := c.db.LoadMatrixClientConfigs()
  61. if err != nil {
  62. return err
  63. }
  64. for _, cfg := range configs {
  65. if cfg.Sync {
  66. if _, err := c.Client(cfg.UserID); err != nil {
  67. return err
  68. }
  69. }
  70. }
  71. return nil
  72. }
  73. func (c *Clients) getClient(userID id.UserID) BotClient {
  74. c.mapMutex.Lock()
  75. defer c.mapMutex.Unlock()
  76. return c.clients[userID]
  77. }
  78. func (c *Clients) setClient(client BotClient) {
  79. c.mapMutex.Lock()
  80. defer c.mapMutex.Unlock()
  81. c.clients[client.config.UserID] = client
  82. }
  83. func (c *Clients) loadClientFromDB(userID id.UserID) (entry BotClient, err error) {
  84. c.dbMutex.Lock()
  85. defer c.dbMutex.Unlock()
  86. entry = c.getClient(userID)
  87. if entry.Client != nil {
  88. return
  89. }
  90. if entry.config, err = c.db.LoadMatrixClientConfig(userID); err != nil {
  91. if err == sql.ErrNoRows {
  92. err = fmt.Errorf("client with user ID %s does not exist", userID)
  93. }
  94. return
  95. }
  96. if err = c.initClient(&entry); err != nil {
  97. return
  98. }
  99. c.setClient(entry)
  100. return
  101. }
  102. func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new, old BotClient, err error) {
  103. c.dbMutex.Lock()
  104. defer c.dbMutex.Unlock()
  105. old = c.getClient(newConfig.UserID)
  106. if old.Client != nil && reflect.DeepEqual(old.config, newConfig) {
  107. // Already have a client with that config.
  108. new = old
  109. return
  110. }
  111. new.config = newConfig
  112. if err = c.initClient(&new); err != nil {
  113. return
  114. }
  115. // set the new display name if they differ
  116. if old.config.DisplayName != new.config.DisplayName {
  117. if err := new.SetDisplayName(new.config.DisplayName); err != nil {
  118. // whine about it but don't stop: this isn't fatal.
  119. log.WithFields(log.Fields{
  120. log.ErrorKey: err,
  121. "displayname": new.config.DisplayName,
  122. "user_id": new.config.UserID,
  123. }).Error("Failed to set display name")
  124. }
  125. }
  126. if old.config, err = c.db.StoreMatrixClientConfig(new.config); err != nil {
  127. new.StopSync()
  128. return
  129. }
  130. if old.Client != nil {
  131. old.Client.StopSync()
  132. return
  133. }
  134. c.setClient(new)
  135. return
  136. }
  137. func (c *Clients) onMessageEvent(botClient *BotClient, event *mevt.Event) {
  138. services, err := c.db.LoadServicesForUser(botClient.UserID)
  139. if err != nil {
  140. log.WithFields(log.Fields{
  141. log.ErrorKey: err,
  142. "room_id": event.RoomID,
  143. "service_user_id": botClient.UserID,
  144. }).Warn("Error loading services")
  145. }
  146. message := event.Content.AsMessage()
  147. body := message.Body
  148. if body == "" {
  149. return
  150. }
  151. // filter m.notice to prevent loops
  152. if message.MsgType == mevt.MsgNotice {
  153. return
  154. }
  155. // replace all smart quotes with their normal counterparts so shellwords can parse it
  156. body = strings.Replace(body, ``, `'`, -1)
  157. body = strings.Replace(body, ``, `'`, -1)
  158. body = strings.Replace(body, ``, `"`, -1)
  159. body = strings.Replace(body, ``, `"`, -1)
  160. var responses []interface{}
  161. for _, service := range services {
  162. if body[0] == '!' { // message is a command
  163. args, err := shellwords.Parse(body[1:])
  164. if err != nil {
  165. args = strings.Split(body[1:], " ")
  166. }
  167. if response := runCommandForService(service.Commands(botClient), event, args); response != nil {
  168. responses = append(responses, response)
  169. }
  170. } else { // message isn't a command, it might need expanding
  171. expansions := runExpansionsForService(service.Expansions(botClient), event, body)
  172. responses = append(responses, expansions...)
  173. }
  174. }
  175. for _, content := range responses {
  176. if _, err := botClient.SendMessageEvent(event.RoomID, mevt.EventMessage, content); err != nil {
  177. log.WithFields(log.Fields{
  178. "room_id": event.RoomID,
  179. "content": content,
  180. "sender": event.Sender,
  181. }).WithError(err).Error("Failed to send command response")
  182. }
  183. }
  184. }
  185. // runCommandForService runs a single command read from a matrix event. Runs
  186. // the matching command with the longest path. Returns the JSON encodable
  187. // content of a single matrix message event to use as a response or nil if no
  188. // response is appropriate.
  189. func runCommandForService(cmds []types.Command, event *mevt.Event, arguments []string) interface{} {
  190. var bestMatch *types.Command
  191. for i, command := range cmds {
  192. matches := command.Matches(arguments)
  193. betterMatch := bestMatch == nil || len(bestMatch.Path) < len(command.Path)
  194. if matches && betterMatch {
  195. bestMatch = &cmds[i]
  196. }
  197. }
  198. if bestMatch == nil {
  199. return nil
  200. }
  201. cmdArgs := arguments[len(bestMatch.Path):]
  202. log.WithFields(log.Fields{
  203. "room_id": event.RoomID,
  204. "user_id": event.Sender,
  205. "command": bestMatch.Path,
  206. }).Info("Executing command")
  207. content, err := bestMatch.Command(event.RoomID, event.Sender, cmdArgs)
  208. if err != nil {
  209. if content != nil {
  210. log.WithFields(log.Fields{
  211. log.ErrorKey: err,
  212. "room_id": event.RoomID,
  213. "user_id": event.Sender,
  214. "command": bestMatch.Path,
  215. "args": cmdArgs,
  216. }).Warn("Command returned both error and content.")
  217. }
  218. metrics.IncrementCommand(bestMatch.Path[0], metrics.StatusFailure)
  219. content = mevt.MessageEventContent{
  220. MsgType: mevt.MsgNotice,
  221. Body: err.Error(),
  222. }
  223. } else {
  224. metrics.IncrementCommand(bestMatch.Path[0], metrics.StatusSuccess)
  225. }
  226. return content
  227. }
  228. // run the expansions for a matrix event.
  229. func runExpansionsForService(expans []types.Expansion, event *mevt.Event, body string) []interface{} {
  230. var responses []interface{}
  231. for _, expansion := range expans {
  232. matches := map[string]bool{}
  233. for _, matchingGroups := range expansion.Regexp.FindAllStringSubmatch(body, -1) {
  234. matchingText := matchingGroups[0] // first element is always the complete match
  235. if matches[matchingText] {
  236. // Only expand the first occurance of a matching string
  237. continue
  238. }
  239. matches[matchingText] = true
  240. if response := expansion.Expand(event.RoomID, event.Sender, matchingGroups); response != nil {
  241. responses = append(responses, response)
  242. }
  243. }
  244. }
  245. return responses
  246. }
  247. func (c *Clients) onBotOptionsEvent(client *mautrix.Client, event *mevt.Event) {
  248. // see if these options are for us. The state key is the user ID with a leading _
  249. // to get around restrictions in the HS about having user IDs as state keys.
  250. if event.StateKey == nil {
  251. return
  252. }
  253. targetUserID := id.UserID(strings.TrimPrefix(*event.StateKey, "_"))
  254. if targetUserID != client.UserID {
  255. return
  256. }
  257. // these options fully clobber what was there previously.
  258. opts := types.BotOptions{
  259. UserID: client.UserID,
  260. RoomID: event.RoomID,
  261. SetByUserID: event.Sender,
  262. Options: event.Content.Raw,
  263. }
  264. if _, err := c.db.StoreBotOptions(opts); err != nil {
  265. log.WithFields(log.Fields{
  266. log.ErrorKey: err,
  267. "room_id": event.RoomID,
  268. "bot_user_id": client.UserID,
  269. "set_by_user_id": event.Sender,
  270. }).Error("Failed to persist bot options")
  271. }
  272. }
  273. func (c *Clients) onRoomMemberEvent(client *mautrix.Client, event *mevt.Event) {
  274. if event.StateKey == nil || *event.StateKey != client.UserID.String() {
  275. return // not our member event
  276. }
  277. membership := event.Content.AsMember().Membership
  278. if membership == "invite" {
  279. logger := log.WithFields(log.Fields{
  280. "room_id": event.RoomID,
  281. "service_user_id": client.UserID,
  282. "inviter": event.Sender,
  283. })
  284. logger.Print("Accepting invite from user")
  285. content := struct {
  286. Inviter id.UserID `json:"inviter"`
  287. }{event.Sender}
  288. if _, err := client.JoinRoom(event.RoomID.String(), "", content); err != nil {
  289. logger.WithError(err).Print("Failed to join room")
  290. } else {
  291. logger.Print("Joined room")
  292. }
  293. }
  294. }
  295. func (c *Clients) initClient(botClient *BotClient) error {
  296. config := botClient.config
  297. client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken)
  298. if err != nil {
  299. return err
  300. }
  301. client.Logger = &MautrixLogger{}
  302. client.Client = c.httpClient
  303. client.DeviceID = config.DeviceID
  304. if client.DeviceID == "" {
  305. log.Warn("Device ID is not set which will result in E2E encryption/decryption not working")
  306. }
  307. botClient.Client = client
  308. botClient.verificationSAS = &sync.Map{}
  309. syncer := client.Syncer.(*mautrix.DefaultSyncer)
  310. nebStore := &matrix.NEBStore{
  311. InMemoryStore: *mautrix.NewInMemoryStore(),
  312. Database: c.db,
  313. ClientConfig: config,
  314. }
  315. client.Store = nebStore
  316. // TODO: Check that the access token is valid for the userID by peforming
  317. // a request against the server.
  318. if err = botClient.InitOlmMachine(client, nebStore); err != nil {
  319. return err
  320. }
  321. // Register sync callback for maintaining the state store and Olm machine state
  322. botClient.Register(syncer)
  323. syncer.OnEventType(mevt.EventMessage, func(_ mautrix.EventSource, event *mevt.Event) {
  324. c.onMessageEvent(botClient, event)
  325. })
  326. syncer.OnEventType(mevt.Type{Type: "m.room.bot.options", Class: mevt.UnknownEventType}, func(_ mautrix.EventSource, event *mevt.Event) {
  327. c.onBotOptionsEvent(botClient.Client, event)
  328. })
  329. if config.AutoJoinRooms {
  330. syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, event *mevt.Event) {
  331. c.onRoomMemberEvent(client, event)
  332. })
  333. }
  334. // When receiving an encrypted event, attempt to decrypt it using the BotClient's capabilities.
  335. // If successfully decrypted propagate the decrypted event to the clients.
  336. syncer.OnEventType(mevt.EventEncrypted, func(source mautrix.EventSource, evt *mevt.Event) {
  337. encContent := evt.Content.AsEncrypted()
  338. decrypted, err := botClient.DecryptMegolmEvent(evt)
  339. if err != nil {
  340. log.WithFields(log.Fields{
  341. "user_id": config.UserID,
  342. "device_id": encContent.DeviceID,
  343. "session_id": encContent.SessionID,
  344. "sender_key": encContent.SenderKey,
  345. }).WithError(err).Error("Failed to decrypt message")
  346. } else {
  347. if decrypted.Type == mevt.EventMessage {
  348. c.onMessageEvent(botClient, decrypted)
  349. }
  350. log.WithFields(log.Fields{
  351. "type": evt.Type,
  352. "sender": evt.Sender,
  353. "room_id": evt.RoomID,
  354. "state_key": evt.StateKey,
  355. }).Trace("Decrypted event successfully")
  356. }
  357. })
  358. // Ignore events before neb's join event.
  359. eventIgnorer := mautrix.OldEventIgnorer{UserID: config.UserID}
  360. eventIgnorer.Register(syncer)
  361. log.WithFields(log.Fields{
  362. "user_id": config.UserID,
  363. "device_id": config.DeviceID,
  364. "sync": config.Sync,
  365. "auto_join_rooms": config.AutoJoinRooms,
  366. "since": nebStore.LoadNextBatch(config.UserID),
  367. }).Info("Created new client")
  368. if config.Sync {
  369. go botClient.Sync()
  370. }
  371. return nil
  372. }