Browse Source

Move some of the client and crypto logic to a new BotClient type

Signed-off-by: Nikos Filippakis <me@nfil.dev>
pull/324/head
Nikos Filippakis 4 years ago
parent
commit
584d674747
No known key found for this signature in database GPG Key ID: 7110E4356101F017
  1. 103
      clients/bot_client.go
  2. 171
      clients/clients.go
  3. 3
      clients/clients_test.go
  4. 28
      clients/crypto_logger.go
  5. 54
      clients/state_store.go
  6. 2
      services/jira/jira.go

103
clients/bot_client.go

@ -0,0 +1,103 @@
package clients
import (
"github.com/matrix-org/go-neb/api"
"github.com/matrix-org/go-neb/matrix"
log "github.com/sirupsen/logrus"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
mevt "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// BotClient represents one of the bot's sessions, with a specific User and Device ID.
// It can be used for sending messages and retrieving information about the rooms that
// the client has joined.
type BotClient struct {
config api.ClientConfig
client *mautrix.Client
olmMachine *crypto.OlmMachine
stateStore *NebStateStore
}
// InitOlmMachine initializes a BotClient's internal OlmMachine given a client object and a Neb store,
// which will be used to store room information.
func (botClient *BotClient) InitOlmMachine(client *mautrix.Client, nebStore *matrix.NEBStore) error {
gobStore, err := crypto.NewGobStore("crypto.gob")
if err != nil {
return err
}
botClient.stateStore = &NebStateStore{&nebStore.InMemoryStore}
olmMachine := crypto.NewOlmMachine(client, CryptoMachineLogger{}, gobStore, botClient.stateStore)
if err = olmMachine.Load(); err != nil {
return nil
}
botClient.olmMachine = olmMachine
return nil
}
// Register registers a BotClient's Sync and StateMember event callbacks to update its internal state
// when new events arrive.
func (botClient *BotClient) Register(syncer mautrix.ExtensibleSyncer) {
syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, evt *mevt.Event) {
botClient.olmMachine.HandleMemberEvent(evt)
})
syncer.OnSync(botClient.syncCallback)
}
func (botClient *BotClient) syncCallback(resp *mautrix.RespSync, since string) bool {
botClient.stateStore.UpdateStateStore(resp)
botClient.olmMachine.ProcessSyncResponse(resp, since)
if err := botClient.olmMachine.CryptoStore.Flush(); err != nil {
log.WithError(err).Error("Could not flush crypto store")
}
return true
}
// DecryptMegolmEvent attempts to decrypt an incoming m.room.encrypted message using the session information
// already present in the OlmMachine. The corresponding decrypted event is then returned.
// If it fails, usually because the session is not known, an error is returned.
func (botClient *BotClient) DecryptMegolmEvent(evt *mevt.Event) (*mevt.Event, error) {
return botClient.olmMachine.DecryptMegolmEvent(evt)
}
// SendMessageEvent sends the given content to the given room ID using this BotClient as a message event.
// If the target room has enabled encryption, a megolm session is created if one doesn't already exist
// and the message is sent after being encrypted.
func (botClient *BotClient) SendMessageEvent(content interface{}, roomID id.RoomID) error {
evtType := mevt.EventMessage
olmMachine := botClient.olmMachine
if olmMachine.StateStore.IsEncrypted(roomID) {
// Check if there is already a megolm session
if sess, err := olmMachine.CryptoStore.GetOutboundGroupSession(roomID); err != nil {
return err
} else if sess == nil || sess.Expired() || !sess.Shared {
// No error but valid, shared session does not exist
membs, err := botClient.client.JoinedMembers(roomID)
if err != nil {
return err
}
memberIDs := make([]id.UserID, 0, len(membs.Joined))
for member := range membs.Joined {
memberIDs = append(memberIDs, member)
}
// Share group session with room members
if err = olmMachine.ShareGroupSession(roomID, memberIDs); err != nil {
return err
}
}
msgContent := mevt.Content{Parsed: content}
enc, err := olmMachine.EncryptMegolmEvent(roomID, mevt.EventMessage, msgContent)
if err != nil {
return err
}
content = enc
evtType = mevt.EventEncrypted
}
if _, err := botClient.client.SendMessageEvent(roomID, evtType, content); err != nil {
return err
}
return nil
}

171
clients/clients.go

@ -16,7 +16,6 @@ import (
shellwords "github.com/mattn/go-shellwords"
log "github.com/sirupsen/logrus"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
mevt "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@ -27,7 +26,7 @@ type Clients struct {
httpClient *http.Client
dbMutex sync.Mutex
mapMutex sync.Mutex
clients map[id.UserID]clientEntry
clients map[id.UserID]BotClient
}
// New makes a new collection of matrix clients
@ -35,7 +34,7 @@ func New(db database.Storer, cli *http.Client) *Clients {
clients := &Clients{
db: db,
httpClient: cli,
clients: make(map[id.UserID]clientEntry), // user_id => clientEntry
clients: make(map[id.UserID]BotClient), // user_id => BotClient
}
return clients
}
@ -72,25 +71,19 @@ func (c *Clients) Start() error {
return nil
}
type clientEntry struct {
config api.ClientConfig
client *mautrix.Client
olmMachine *crypto.OlmMachine
}
func (c *Clients) getClient(userID id.UserID) clientEntry {
func (c *Clients) getClient(userID id.UserID) BotClient {
c.mapMutex.Lock()
defer c.mapMutex.Unlock()
return c.clients[userID]
}
func (c *Clients) setClient(client clientEntry) {
func (c *Clients) setClient(client BotClient) {
c.mapMutex.Lock()
defer c.mapMutex.Unlock()
c.clients[client.config.UserID] = client
}
func (c *Clients) loadClientFromDB(userID id.UserID) (entry clientEntry, err error) {
func (c *Clients) loadClientFromDB(userID id.UserID) (entry BotClient, err error) {
c.dbMutex.Lock()
defer c.dbMutex.Unlock()
@ -114,7 +107,7 @@ func (c *Clients) loadClientFromDB(userID id.UserID) (entry clientEntry, err err
return
}
func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new clientEntry, old clientEntry, err error) {
func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new, old BotClient, err error) {
c.dbMutex.Lock()
defer c.dbMutex.Unlock()
@ -157,13 +150,13 @@ func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new clientEntry,
return
}
func (c *Clients) onMessageEvent(client *mautrix.Client, event *mevt.Event) {
services, err := c.db.LoadServicesForUser(client.UserID)
func (c *Clients) onMessageEvent(botClient *BotClient, event *mevt.Event) {
services, err := c.db.LoadServicesForUser(botClient.client.UserID)
if err != nil {
log.WithFields(log.Fields{
log.ErrorKey: err,
"room_id": event.RoomID,
"service_user_id": client.UserID,
"service_user_id": botClient.client.UserID,
}).Warn("Error loading services")
}
@ -198,51 +191,22 @@ func (c *Clients) onMessageEvent(client *mautrix.Client, event *mevt.Event) {
args = strings.Split(body[1:], " ")
}
if response := runCommandForService(service.Commands(client), event, args); response != nil {
if response := runCommandForService(service.Commands(botClient.client), event, args); response != nil {
responses = append(responses, response)
}
} else { // message isn't a command, it might need expanding
expansions := runExpansionsForService(service.Expansions(client), event, body)
expansions := runExpansionsForService(service.Expansions(botClient.client), event, body)
responses = append(responses, expansions...)
}
}
for _, content := range responses {
evtType := mevt.EventMessage
curClient := c.clients[client.UserID]
olmMachine := curClient.olmMachine
if olmMachine.StateStore.IsEncrypted(event.RoomID) {
fmt.Println(event.RoomID, "is enc")
if sess, err := olmMachine.CryptoStore.GetOutboundGroupSession(event.RoomID); err != nil {
fmt.Println("Error getting outbound", err)
} else if sess == nil {
if membs, err := client.JoinedMembers(event.RoomID); err != nil {
fmt.Println(err)
} else {
memberIDs := make([]id.UserID, 0, len(membs.Joined))
for member := range membs.Joined {
memberIDs = append(memberIDs, member)
}
if err = olmMachine.ShareGroupSession(event.RoomID, memberIDs); err != nil {
fmt.Println(err)
}
}
}
msgContent := mevt.Content{Parsed: content}
if enc, err := olmMachine.EncryptMegolmEvent(event.RoomID, mevt.EventMessage, msgContent); err != nil {
fmt.Println("error encoding", err)
} else {
content = enc
evtType = mevt.EventEncrypted
}
}
if _, err := client.SendMessageEvent(event.RoomID, evtType, content); err != nil {
if err := botClient.SendMessageEvent(content, event.RoomID); err != nil {
log.WithFields(log.Fields{
log.ErrorKey: err,
"room_id": event.RoomID,
"user_id": event.Sender,
"content": content,
}).Print("Failed to send command response")
"room_id": event.RoomID,
"content": content,
"sender": event.Sender,
}).WithError(err).Error("Failed to send command response")
}
}
}
@ -371,8 +335,8 @@ func (c *Clients) onRoomMemberEvent(client *mautrix.Client, event *mevt.Event) {
}
}
func (c *Clients) initClient(clientEntry *clientEntry) error {
config := clientEntry.config
func (c *Clients) initClient(botClient *BotClient) error {
config := botClient.config
client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken)
if err != nil {
return err
@ -380,6 +344,8 @@ func (c *Clients) initClient(clientEntry *clientEntry) error {
client.Client = c.httpClient
client.DeviceID = config.DeviceID
botClient.client = client
syncer := client.Syncer.(*mautrix.DefaultSyncer)
nebStore := &matrix.NEBStore{
@ -392,12 +358,18 @@ func (c *Clients) initClient(clientEntry *clientEntry) error {
// TODO: Check that the access token is valid for the userID by peforming
// a request against the server.
if err = botClient.InitOlmMachine(client, nebStore); err != nil {
return err
}
botClient.Register(syncer)
syncer.OnEventType(mevt.EventMessage, func(_ mautrix.EventSource, event *mevt.Event) {
c.onMessageEvent(client, event)
c.onMessageEvent(botClient, event)
})
syncer.OnEventType(mevt.Type{Type: "m.room.bot.options", Class: mevt.UnknownEventType}, func(_ mautrix.EventSource, event *mevt.Event) {
c.onBotOptionsEvent(client, event)
c.onBotOptionsEvent(botClient.client, event)
})
if config.AutoJoinRooms {
@ -406,8 +378,47 @@ func (c *Clients) initClient(clientEntry *clientEntry) error {
})
}
// When receiving an encrypted event, attempt to decrypt it using the BotClient's capabilities.
// If successfully decrypted propagate the decrypted event to the clients.
syncer.OnEventType(mevt.EventEncrypted, func(source mautrix.EventSource, evt *mevt.Event) {
if err := evt.Content.ParseRaw(mevt.EventEncrypted); err != nil {
log.WithError(err).Error("Failed to parse encrypted message")
return
}
encContent := evt.Content.AsEncrypted()
decrypted, err := botClient.DecryptMegolmEvent(evt)
if err != nil {
log.WithFields(log.Fields{
"user_id": config.UserID,
"device_id": encContent.DeviceID,
"session_id": encContent.SessionID,
"sender_key": encContent.SenderKey,
}).WithError(err).Error("Failed to decrypt message")
} else {
if decrypted.Type == mevt.EventMessage {
err = decrypted.Content.ParseRaw(mevt.EventMessage)
if err != nil {
log.WithError(err).Error("Could not parse decrypted message event")
} else {
c.onMessageEvent(botClient, decrypted)
}
}
log.WithFields(log.Fields{
"type": evt.Type,
"sender": evt.Sender,
"room_id": evt.RoomID,
"state_key": evt.StateKey,
}).Trace("Decrypted event successfully")
}
})
// Ignore events before neb's join event.
eventIgnorer := mautrix.OldEventIgnorer{UserID: config.UserID}
eventIgnorer.Register(syncer)
log.WithFields(log.Fields{
"user_id": config.UserID,
"device_id": config.DeviceID,
"sync": config.Sync,
"auto_join_rooms": config.AutoJoinRooms,
"since": nebStore.LoadNextBatch(config.UserID),
@ -430,53 +441,5 @@ func (c *Clients) initClient(clientEntry *clientEntry) error {
}()
}
clientEntry.client = client
gobStore, err := crypto.NewGobStore("crypto.gob")
if err != nil {
return err
}
stateStore := StateStore{&nebStore.InMemoryStore}
olmMachine := crypto.NewOlmMachine(client, CryptoMachineLogger{}, gobStore, &stateStore)
olmMachine.Load()
clientEntry.olmMachine = olmMachine
syncer.OnSync(stateStore.UpdateStateStore)
// Process sync response with olm machine
syncer.OnSync(func(resp *mautrix.RespSync, since string) bool {
olmMachine.ProcessSyncResponse(resp, since)
if err := olmMachine.CryptoStore.Flush(); err != nil {
fmt.Println("cryptostore flush err", err)
}
return true
})
syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, evt *mevt.Event) {
olmMachine.HandleMemberEvent(evt)
})
syncer.OnEventType(mevt.EventEncrypted, func(source mautrix.EventSource, evt *mevt.Event) {
evt.Content.ParseRaw(mevt.EventEncrypted)
evt, err := olmMachine.DecryptMegolmEvent(evt)
if err != nil {
fmt.Println("decryption err", err)
} else {
if evt.Type == mevt.EventMessage {
err = evt.Content.ParseRaw(mevt.EventMessage)
if err != nil {
fmt.Println("parsing msg err", err)
} else {
c.onMessageEvent(client, evt)
}
}
fmt.Println("decrypted type", evt.Type)
}
})
// Ignore events before neb's join event.
eventIgnorer := mautrix.OldEventIgnorer{UserID: config.UserID}
eventIgnorer.Register(syncer)
return nil
}

3
clients/clients_test.go

@ -76,6 +76,7 @@ func TestCommandParsing(t *testing.T) {
clients := New(&store, cli)
mxCli, _ := mautrix.NewClient("https://someplace.somewhere", "@service:user", "token")
mxCli.Client = cli
botClient := BotClient{client: mxCli}
for _, input := range commandParseTests {
executedCmdArgs = []string{}
@ -94,7 +95,7 @@ func TestCommandParsing(t *testing.T) {
RoomID: "!foo:bar",
Content: content,
}
clients.onMessageEvent(mxCli, &event)
clients.onMessageEvent(&botClient, &event)
if !reflect.DeepEqual(executedCmdArgs, input.expectArgs) {
t.Errorf("TestCommandParsing want %s, got %s", input.expectArgs, executedCmdArgs)
}

28
clients/crypto_logger.go

@ -0,0 +1,28 @@
package clients
import (
log "github.com/sirupsen/logrus"
)
// CryptoMachineLogger wraps around the usual logger, implementing the Logger interface needed by OlmMachine.
type CryptoMachineLogger struct{}
// Error formats and logs an error message.
func (CryptoMachineLogger) Error(message string, args ...interface{}) {
log.Errorf(message, args...)
}
// Warn formats and logs a warning message.
func (CryptoMachineLogger) Warn(message string, args ...interface{}) {
log.Warnf(message, args...)
}
// Debug formats and logs a debug message.
func (CryptoMachineLogger) Debug(message string, args ...interface{}) {
log.Debugf(message, args...)
}
// Trace formats and logs a trace message.
func (CryptoMachineLogger) Trace(message string, args ...interface{}) {
log.Tracef(message, args...)
}

54
clients/state_store.go

@ -0,0 +1,54 @@
package clients
import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// NebStateStore implements the StateStore interface for OlmMachine.
// It is used to determine which rooms are encrypted and which rooms are shared with a user.
// The state is updated by /sync responses.
type NebStateStore struct {
Storer *mautrix.InMemoryStore
}
// IsEncrypted returns whether a room has been encrypted.
func (ss *NebStateStore) IsEncrypted(roomID id.RoomID) bool {
room := ss.Storer.LoadRoom(roomID)
if room == nil {
return false
}
_, ok := room.State[event.StateEncryption]
return ok
}
// FindSharedRooms returns a list of room IDs that the given user ID is also a member of.
func (ss *NebStateStore) FindSharedRooms(userID id.UserID) []id.RoomID {
sharedRooms := make([]id.RoomID, 0)
for roomID, room := range ss.Storer.Rooms {
if room.GetMembershipState(userID) != event.MembershipLeave {
sharedRooms = append(sharedRooms, roomID)
}
}
return sharedRooms
}
// UpdateStateStore updates the internal state of NebStateStore from a /sync response.
func (ss *NebStateStore) UpdateStateStore(resp *mautrix.RespSync) {
for roomID, evts := range resp.Rooms.Join {
room := ss.Storer.LoadRoom(roomID)
if room == nil {
room = mautrix.NewRoom(roomID)
ss.Storer.SaveRoom(room)
}
for _, i := range evts.State.Events {
room.UpdateState(i)
}
for _, i := range evts.Timeline.Events {
if i.Type.IsState() {
room.UpdateState(i)
}
}
}
}

2
services/jira/jira.go

@ -261,7 +261,7 @@ func (s *Service) Commands(cli *mautrix.Client) []types.Command {
// be chosen arbitrarily.
func (s *Service) Expansions(cli *mautrix.Client) []types.Expansion {
return []types.Expansion{
types.Expansion{
{
Regexp: issueKeyRegex,
Expand: func(roomID id.RoomID, userID id.UserID, issueKeyGroups []string) interface{} {
return s.expandIssue(roomID, userID, issueKeyGroups)

Loading…
Cancel
Save