diff --git a/clients/bot_client.go b/clients/bot_client.go new file mode 100644 index 0000000..0573e29 --- /dev/null +++ b/clients/bot_client.go @@ -0,0 +1,103 @@ +package clients + +import ( + "github.com/matrix-org/go-neb/api" + "github.com/matrix-org/go-neb/matrix" + log "github.com/sirupsen/logrus" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + mevt "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// BotClient represents one of the bot's sessions, with a specific User and Device ID. +// It can be used for sending messages and retrieving information about the rooms that +// the client has joined. +type BotClient struct { + config api.ClientConfig + client *mautrix.Client + olmMachine *crypto.OlmMachine + stateStore *NebStateStore +} + +// InitOlmMachine initializes a BotClient's internal OlmMachine given a client object and a Neb store, +// which will be used to store room information. +func (botClient *BotClient) InitOlmMachine(client *mautrix.Client, nebStore *matrix.NEBStore) error { + gobStore, err := crypto.NewGobStore("crypto.gob") + if err != nil { + return err + } + + botClient.stateStore = &NebStateStore{&nebStore.InMemoryStore} + olmMachine := crypto.NewOlmMachine(client, CryptoMachineLogger{}, gobStore, botClient.stateStore) + if err = olmMachine.Load(); err != nil { + return nil + } + botClient.olmMachine = olmMachine + + return nil +} + +// Register registers a BotClient's Sync and StateMember event callbacks to update its internal state +// when new events arrive. +func (botClient *BotClient) Register(syncer mautrix.ExtensibleSyncer) { + syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, evt *mevt.Event) { + botClient.olmMachine.HandleMemberEvent(evt) + }) + syncer.OnSync(botClient.syncCallback) +} + +func (botClient *BotClient) syncCallback(resp *mautrix.RespSync, since string) bool { + botClient.stateStore.UpdateStateStore(resp) + botClient.olmMachine.ProcessSyncResponse(resp, since) + if err := botClient.olmMachine.CryptoStore.Flush(); err != nil { + log.WithError(err).Error("Could not flush crypto store") + } + return true +} + +// DecryptMegolmEvent attempts to decrypt an incoming m.room.encrypted message using the session information +// already present in the OlmMachine. The corresponding decrypted event is then returned. +// If it fails, usually because the session is not known, an error is returned. +func (botClient *BotClient) DecryptMegolmEvent(evt *mevt.Event) (*mevt.Event, error) { + return botClient.olmMachine.DecryptMegolmEvent(evt) +} + +// SendMessageEvent sends the given content to the given room ID using this BotClient as a message event. +// If the target room has enabled encryption, a megolm session is created if one doesn't already exist +// and the message is sent after being encrypted. +func (botClient *BotClient) SendMessageEvent(content interface{}, roomID id.RoomID) error { + evtType := mevt.EventMessage + olmMachine := botClient.olmMachine + if olmMachine.StateStore.IsEncrypted(roomID) { + // Check if there is already a megolm session + if sess, err := olmMachine.CryptoStore.GetOutboundGroupSession(roomID); err != nil { + return err + } else if sess == nil || sess.Expired() || !sess.Shared { + // No error but valid, shared session does not exist + membs, err := botClient.client.JoinedMembers(roomID) + if err != nil { + return err + } + memberIDs := make([]id.UserID, 0, len(membs.Joined)) + for member := range membs.Joined { + memberIDs = append(memberIDs, member) + } + // Share group session with room members + if err = olmMachine.ShareGroupSession(roomID, memberIDs); err != nil { + return err + } + } + msgContent := mevt.Content{Parsed: content} + enc, err := olmMachine.EncryptMegolmEvent(roomID, mevt.EventMessage, msgContent) + if err != nil { + return err + } + content = enc + evtType = mevt.EventEncrypted + } + if _, err := botClient.client.SendMessageEvent(roomID, evtType, content); err != nil { + return err + } + return nil +} diff --git a/clients/clients.go b/clients/clients.go index 31a1560..f562939 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -16,7 +16,6 @@ import ( shellwords "github.com/mattn/go-shellwords" log "github.com/sirupsen/logrus" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto" mevt "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -27,7 +26,7 @@ type Clients struct { httpClient *http.Client dbMutex sync.Mutex mapMutex sync.Mutex - clients map[id.UserID]clientEntry + clients map[id.UserID]BotClient } // New makes a new collection of matrix clients @@ -35,7 +34,7 @@ func New(db database.Storer, cli *http.Client) *Clients { clients := &Clients{ db: db, httpClient: cli, - clients: make(map[id.UserID]clientEntry), // user_id => clientEntry + clients: make(map[id.UserID]BotClient), // user_id => BotClient } return clients } @@ -72,25 +71,19 @@ func (c *Clients) Start() error { return nil } -type clientEntry struct { - config api.ClientConfig - client *mautrix.Client - olmMachine *crypto.OlmMachine -} - -func (c *Clients) getClient(userID id.UserID) clientEntry { +func (c *Clients) getClient(userID id.UserID) BotClient { c.mapMutex.Lock() defer c.mapMutex.Unlock() return c.clients[userID] } -func (c *Clients) setClient(client clientEntry) { +func (c *Clients) setClient(client BotClient) { c.mapMutex.Lock() defer c.mapMutex.Unlock() c.clients[client.config.UserID] = client } -func (c *Clients) loadClientFromDB(userID id.UserID) (entry clientEntry, err error) { +func (c *Clients) loadClientFromDB(userID id.UserID) (entry BotClient, err error) { c.dbMutex.Lock() defer c.dbMutex.Unlock() @@ -114,7 +107,7 @@ func (c *Clients) loadClientFromDB(userID id.UserID) (entry clientEntry, err err return } -func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new clientEntry, old clientEntry, err error) { +func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new, old BotClient, err error) { c.dbMutex.Lock() defer c.dbMutex.Unlock() @@ -157,13 +150,13 @@ func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new clientEntry, return } -func (c *Clients) onMessageEvent(client *mautrix.Client, event *mevt.Event) { - services, err := c.db.LoadServicesForUser(client.UserID) +func (c *Clients) onMessageEvent(botClient *BotClient, event *mevt.Event) { + services, err := c.db.LoadServicesForUser(botClient.client.UserID) if err != nil { log.WithFields(log.Fields{ log.ErrorKey: err, "room_id": event.RoomID, - "service_user_id": client.UserID, + "service_user_id": botClient.client.UserID, }).Warn("Error loading services") } @@ -198,51 +191,22 @@ func (c *Clients) onMessageEvent(client *mautrix.Client, event *mevt.Event) { args = strings.Split(body[1:], " ") } - if response := runCommandForService(service.Commands(client), event, args); response != nil { + if response := runCommandForService(service.Commands(botClient.client), event, args); response != nil { responses = append(responses, response) } } else { // message isn't a command, it might need expanding - expansions := runExpansionsForService(service.Expansions(client), event, body) + expansions := runExpansionsForService(service.Expansions(botClient.client), event, body) responses = append(responses, expansions...) } } for _, content := range responses { - evtType := mevt.EventMessage - curClient := c.clients[client.UserID] - olmMachine := curClient.olmMachine - if olmMachine.StateStore.IsEncrypted(event.RoomID) { - fmt.Println(event.RoomID, "is enc") - if sess, err := olmMachine.CryptoStore.GetOutboundGroupSession(event.RoomID); err != nil { - fmt.Println("Error getting outbound", err) - } else if sess == nil { - if membs, err := client.JoinedMembers(event.RoomID); err != nil { - fmt.Println(err) - } else { - memberIDs := make([]id.UserID, 0, len(membs.Joined)) - for member := range membs.Joined { - memberIDs = append(memberIDs, member) - } - if err = olmMachine.ShareGroupSession(event.RoomID, memberIDs); err != nil { - fmt.Println(err) - } - } - } - msgContent := mevt.Content{Parsed: content} - if enc, err := olmMachine.EncryptMegolmEvent(event.RoomID, mevt.EventMessage, msgContent); err != nil { - fmt.Println("error encoding", err) - } else { - content = enc - evtType = mevt.EventEncrypted - } - } - if _, err := client.SendMessageEvent(event.RoomID, evtType, content); err != nil { + if err := botClient.SendMessageEvent(content, event.RoomID); err != nil { log.WithFields(log.Fields{ - log.ErrorKey: err, - "room_id": event.RoomID, - "user_id": event.Sender, - "content": content, - }).Print("Failed to send command response") + "room_id": event.RoomID, + "content": content, + "sender": event.Sender, + }).WithError(err).Error("Failed to send command response") } } } @@ -371,8 +335,8 @@ func (c *Clients) onRoomMemberEvent(client *mautrix.Client, event *mevt.Event) { } } -func (c *Clients) initClient(clientEntry *clientEntry) error { - config := clientEntry.config +func (c *Clients) initClient(botClient *BotClient) error { + config := botClient.config client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) if err != nil { return err @@ -380,6 +344,8 @@ func (c *Clients) initClient(clientEntry *clientEntry) error { client.Client = c.httpClient client.DeviceID = config.DeviceID + botClient.client = client + syncer := client.Syncer.(*mautrix.DefaultSyncer) nebStore := &matrix.NEBStore{ @@ -392,12 +358,18 @@ func (c *Clients) initClient(clientEntry *clientEntry) error { // TODO: Check that the access token is valid for the userID by peforming // a request against the server. + if err = botClient.InitOlmMachine(client, nebStore); err != nil { + return err + } + + botClient.Register(syncer) + syncer.OnEventType(mevt.EventMessage, func(_ mautrix.EventSource, event *mevt.Event) { - c.onMessageEvent(client, event) + c.onMessageEvent(botClient, event) }) syncer.OnEventType(mevt.Type{Type: "m.room.bot.options", Class: mevt.UnknownEventType}, func(_ mautrix.EventSource, event *mevt.Event) { - c.onBotOptionsEvent(client, event) + c.onBotOptionsEvent(botClient.client, event) }) if config.AutoJoinRooms { @@ -406,8 +378,47 @@ func (c *Clients) initClient(clientEntry *clientEntry) error { }) } + // When receiving an encrypted event, attempt to decrypt it using the BotClient's capabilities. + // If successfully decrypted propagate the decrypted event to the clients. + syncer.OnEventType(mevt.EventEncrypted, func(source mautrix.EventSource, evt *mevt.Event) { + if err := evt.Content.ParseRaw(mevt.EventEncrypted); err != nil { + log.WithError(err).Error("Failed to parse encrypted message") + return + } + encContent := evt.Content.AsEncrypted() + decrypted, err := botClient.DecryptMegolmEvent(evt) + if err != nil { + log.WithFields(log.Fields{ + "user_id": config.UserID, + "device_id": encContent.DeviceID, + "session_id": encContent.SessionID, + "sender_key": encContent.SenderKey, + }).WithError(err).Error("Failed to decrypt message") + } else { + if decrypted.Type == mevt.EventMessage { + err = decrypted.Content.ParseRaw(mevt.EventMessage) + if err != nil { + log.WithError(err).Error("Could not parse decrypted message event") + } else { + c.onMessageEvent(botClient, decrypted) + } + } + log.WithFields(log.Fields{ + "type": evt.Type, + "sender": evt.Sender, + "room_id": evt.RoomID, + "state_key": evt.StateKey, + }).Trace("Decrypted event successfully") + } + }) + + // Ignore events before neb's join event. + eventIgnorer := mautrix.OldEventIgnorer{UserID: config.UserID} + eventIgnorer.Register(syncer) + log.WithFields(log.Fields{ "user_id": config.UserID, + "device_id": config.DeviceID, "sync": config.Sync, "auto_join_rooms": config.AutoJoinRooms, "since": nebStore.LoadNextBatch(config.UserID), @@ -430,53 +441,5 @@ func (c *Clients) initClient(clientEntry *clientEntry) error { }() } - clientEntry.client = client - - gobStore, err := crypto.NewGobStore("crypto.gob") - if err != nil { - return err - } - - stateStore := StateStore{&nebStore.InMemoryStore} - olmMachine := crypto.NewOlmMachine(client, CryptoMachineLogger{}, gobStore, &stateStore) - olmMachine.Load() - clientEntry.olmMachine = olmMachine - - syncer.OnSync(stateStore.UpdateStateStore) - // Process sync response with olm machine - syncer.OnSync(func(resp *mautrix.RespSync, since string) bool { - olmMachine.ProcessSyncResponse(resp, since) - if err := olmMachine.CryptoStore.Flush(); err != nil { - fmt.Println("cryptostore flush err", err) - } - return true - }) - - syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, evt *mevt.Event) { - olmMachine.HandleMemberEvent(evt) - }) - - syncer.OnEventType(mevt.EventEncrypted, func(source mautrix.EventSource, evt *mevt.Event) { - evt.Content.ParseRaw(mevt.EventEncrypted) - evt, err := olmMachine.DecryptMegolmEvent(evt) - if err != nil { - fmt.Println("decryption err", err) - } else { - if evt.Type == mevt.EventMessage { - err = evt.Content.ParseRaw(mevt.EventMessage) - if err != nil { - fmt.Println("parsing msg err", err) - } else { - c.onMessageEvent(client, evt) - } - } - fmt.Println("decrypted type", evt.Type) - } - }) - - // Ignore events before neb's join event. - eventIgnorer := mautrix.OldEventIgnorer{UserID: config.UserID} - eventIgnorer.Register(syncer) - return nil } diff --git a/clients/clients_test.go b/clients/clients_test.go index d000f4e..dd04c2a 100644 --- a/clients/clients_test.go +++ b/clients/clients_test.go @@ -76,6 +76,7 @@ func TestCommandParsing(t *testing.T) { clients := New(&store, cli) mxCli, _ := mautrix.NewClient("https://someplace.somewhere", "@service:user", "token") mxCli.Client = cli + botClient := BotClient{client: mxCli} for _, input := range commandParseTests { executedCmdArgs = []string{} @@ -94,7 +95,7 @@ func TestCommandParsing(t *testing.T) { RoomID: "!foo:bar", Content: content, } - clients.onMessageEvent(mxCli, &event) + clients.onMessageEvent(&botClient, &event) if !reflect.DeepEqual(executedCmdArgs, input.expectArgs) { t.Errorf("TestCommandParsing want %s, got %s", input.expectArgs, executedCmdArgs) } diff --git a/clients/crypto_logger.go b/clients/crypto_logger.go new file mode 100644 index 0000000..75026e7 --- /dev/null +++ b/clients/crypto_logger.go @@ -0,0 +1,28 @@ +package clients + +import ( + log "github.com/sirupsen/logrus" +) + +// CryptoMachineLogger wraps around the usual logger, implementing the Logger interface needed by OlmMachine. +type CryptoMachineLogger struct{} + +// Error formats and logs an error message. +func (CryptoMachineLogger) Error(message string, args ...interface{}) { + log.Errorf(message, args...) +} + +// Warn formats and logs a warning message. +func (CryptoMachineLogger) Warn(message string, args ...interface{}) { + log.Warnf(message, args...) +} + +// Debug formats and logs a debug message. +func (CryptoMachineLogger) Debug(message string, args ...interface{}) { + log.Debugf(message, args...) +} + +// Trace formats and logs a trace message. +func (CryptoMachineLogger) Trace(message string, args ...interface{}) { + log.Tracef(message, args...) +} diff --git a/clients/state_store.go b/clients/state_store.go new file mode 100644 index 0000000..0be3a21 --- /dev/null +++ b/clients/state_store.go @@ -0,0 +1,54 @@ +package clients + +import ( + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// NebStateStore implements the StateStore interface for OlmMachine. +// It is used to determine which rooms are encrypted and which rooms are shared with a user. +// The state is updated by /sync responses. +type NebStateStore struct { + Storer *mautrix.InMemoryStore +} + +// IsEncrypted returns whether a room has been encrypted. +func (ss *NebStateStore) IsEncrypted(roomID id.RoomID) bool { + room := ss.Storer.LoadRoom(roomID) + if room == nil { + return false + } + _, ok := room.State[event.StateEncryption] + return ok +} + +// FindSharedRooms returns a list of room IDs that the given user ID is also a member of. +func (ss *NebStateStore) FindSharedRooms(userID id.UserID) []id.RoomID { + sharedRooms := make([]id.RoomID, 0) + for roomID, room := range ss.Storer.Rooms { + if room.GetMembershipState(userID) != event.MembershipLeave { + sharedRooms = append(sharedRooms, roomID) + } + } + return sharedRooms +} + +// UpdateStateStore updates the internal state of NebStateStore from a /sync response. +func (ss *NebStateStore) UpdateStateStore(resp *mautrix.RespSync) { + for roomID, evts := range resp.Rooms.Join { + room := ss.Storer.LoadRoom(roomID) + if room == nil { + room = mautrix.NewRoom(roomID) + ss.Storer.SaveRoom(room) + } + for _, i := range evts.State.Events { + room.UpdateState(i) + } + for _, i := range evts.Timeline.Events { + if i.Type.IsState() { + room.UpdateState(i) + } + } + } +} diff --git a/services/jira/jira.go b/services/jira/jira.go index 9d1b0f6..9b9bf4a 100644 --- a/services/jira/jira.go +++ b/services/jira/jira.go @@ -261,7 +261,7 @@ func (s *Service) Commands(cli *mautrix.Client) []types.Command { // be chosen arbitrarily. func (s *Service) Expansions(cli *mautrix.Client) []types.Expansion { return []types.Expansion{ - types.Expansion{ + { Regexp: issueKeyRegex, Expand: func(roomID id.RoomID, userID id.UserID, issueKeyGroups []string) interface{} { return s.expandIssue(roomID, userID, issueKeyGroups)