|
@ -16,6 +16,7 @@ import ( |
|
|
shellwords "github.com/mattn/go-shellwords" |
|
|
shellwords "github.com/mattn/go-shellwords" |
|
|
log "github.com/sirupsen/logrus" |
|
|
log "github.com/sirupsen/logrus" |
|
|
"maunium.net/go/mautrix" |
|
|
"maunium.net/go/mautrix" |
|
|
|
|
|
"maunium.net/go/mautrix/crypto" |
|
|
mevt "maunium.net/go/mautrix/event" |
|
|
mevt "maunium.net/go/mautrix/event" |
|
|
"maunium.net/go/mautrix/id" |
|
|
"maunium.net/go/mautrix/id" |
|
|
) |
|
|
) |
|
@ -72,8 +73,9 @@ func (c *Clients) Start() error { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
type clientEntry struct { |
|
|
type clientEntry struct { |
|
|
config api.ClientConfig |
|
|
|
|
|
client *mautrix.Client |
|
|
|
|
|
|
|
|
config api.ClientConfig |
|
|
|
|
|
client *mautrix.Client |
|
|
|
|
|
olmMachine *crypto.OlmMachine |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func (c *Clients) getClient(userID id.UserID) clientEntry { |
|
|
func (c *Clients) getClient(userID id.UserID) clientEntry { |
|
@ -104,7 +106,7 @@ func (c *Clients) loadClientFromDB(userID id.UserID) (entry clientEntry, err err |
|
|
return |
|
|
return |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if entry.client, err = c.newClient(entry.config); err != nil { |
|
|
|
|
|
|
|
|
if err = c.initClient(&entry); err != nil { |
|
|
return |
|
|
return |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@ -125,7 +127,7 @@ func (c *Clients) updateClientInDB(newConfig api.ClientConfig) (new clientEntry, |
|
|
|
|
|
|
|
|
new.config = newConfig |
|
|
new.config = newConfig |
|
|
|
|
|
|
|
|
if new.client, err = c.newClient(new.config); err != nil { |
|
|
|
|
|
|
|
|
if err = c.initClient(&new); err != nil { |
|
|
return |
|
|
return |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@ -206,7 +208,35 @@ func (c *Clients) onMessageEvent(client *mautrix.Client, event *mevt.Event) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for _, content := range responses { |
|
|
for _, content := range responses { |
|
|
if _, err := client.SendMessageEvent(event.RoomID, mevt.EventMessage, content); err != nil { |
|
|
|
|
|
|
|
|
evtType := mevt.EventMessage |
|
|
|
|
|
curClient := c.clients[client.UserID] |
|
|
|
|
|
olmMachine := curClient.olmMachine |
|
|
|
|
|
if olmMachine.StateStore.IsEncrypted(event.RoomID) { |
|
|
|
|
|
fmt.Println(event.RoomID, "is enc") |
|
|
|
|
|
if sess, err := olmMachine.CryptoStore.GetOutboundGroupSession(event.RoomID); err != nil { |
|
|
|
|
|
fmt.Println("Error getting outbound", err) |
|
|
|
|
|
} else if sess == nil { |
|
|
|
|
|
if membs, err := client.JoinedMembers(event.RoomID); err != nil { |
|
|
|
|
|
fmt.Println(err) |
|
|
|
|
|
} else { |
|
|
|
|
|
memberIDs := make([]id.UserID, 0, len(membs.Joined)) |
|
|
|
|
|
for member := range membs.Joined { |
|
|
|
|
|
memberIDs = append(memberIDs, member) |
|
|
|
|
|
} |
|
|
|
|
|
if err = olmMachine.ShareGroupSession(event.RoomID, memberIDs); err != nil { |
|
|
|
|
|
fmt.Println(err) |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
msgContent := mevt.Content{Parsed: content} |
|
|
|
|
|
if enc, err := olmMachine.EncryptMegolmEvent(event.RoomID, mevt.EventMessage, msgContent); err != nil { |
|
|
|
|
|
fmt.Println("error encoding", err) |
|
|
|
|
|
} else { |
|
|
|
|
|
content = enc |
|
|
|
|
|
evtType = mevt.EventEncrypted |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
if _, err := client.SendMessageEvent(event.RoomID, evtType, content); err != nil { |
|
|
log.WithFields(log.Fields{ |
|
|
log.WithFields(log.Fields{ |
|
|
log.ErrorKey: err, |
|
|
log.ErrorKey: err, |
|
|
"room_id": event.RoomID, |
|
|
"room_id": event.RoomID, |
|
@ -341,35 +371,37 @@ func (c *Clients) onRoomMemberEvent(client *mautrix.Client, event *mevt.Event) { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func (c *Clients) newClient(config api.ClientConfig) (*mautrix.Client, error) { |
|
|
|
|
|
|
|
|
func (c *Clients) initClient(clientEntry *clientEntry) error { |
|
|
|
|
|
config := clientEntry.config |
|
|
client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) |
|
|
client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) |
|
|
if err != nil { |
|
|
if err != nil { |
|
|
return nil, err |
|
|
|
|
|
|
|
|
return err |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
client.Client = c.httpClient |
|
|
client.Client = c.httpClient |
|
|
client.DeviceID = config.DeviceID |
|
|
client.DeviceID = config.DeviceID |
|
|
syncer := client.Syncer.(*mautrix.DefaultSyncer) |
|
|
syncer := client.Syncer.(*mautrix.DefaultSyncer) |
|
|
|
|
|
|
|
|
nebStore := &matrix.NEBStore{ |
|
|
nebStore := &matrix.NEBStore{ |
|
|
InMemoryStore: *mautrix.NewInMemoryStore(), |
|
|
InMemoryStore: *mautrix.NewInMemoryStore(), |
|
|
Database: c.db, |
|
|
Database: c.db, |
|
|
ClientConfig: config, |
|
|
ClientConfig: config, |
|
|
} |
|
|
} |
|
|
client.Store = nebStore |
|
|
client.Store = nebStore |
|
|
syncer.Store = nebStore |
|
|
|
|
|
|
|
|
|
|
|
// TODO: Check that the access token is valid for the userID by peforming
|
|
|
// TODO: Check that the access token is valid for the userID by peforming
|
|
|
// a request against the server.
|
|
|
// a request against the server.
|
|
|
|
|
|
|
|
|
syncer.OnEventType(mevt.EventMessage, func(event *mevt.Event) { |
|
|
|
|
|
|
|
|
syncer.OnEventType(mevt.EventMessage, func(_ mautrix.EventSource, event *mevt.Event) { |
|
|
c.onMessageEvent(client, event) |
|
|
c.onMessageEvent(client, event) |
|
|
}) |
|
|
}) |
|
|
|
|
|
|
|
|
syncer.OnEventType(mevt.Type{Type: "m.room.bot.options", Class: mevt.UnknownEventType}, func(event *mevt.Event) { |
|
|
|
|
|
|
|
|
syncer.OnEventType(mevt.Type{Type: "m.room.bot.options", Class: mevt.UnknownEventType}, func(_ mautrix.EventSource, event *mevt.Event) { |
|
|
c.onBotOptionsEvent(client, event) |
|
|
c.onBotOptionsEvent(client, event) |
|
|
}) |
|
|
}) |
|
|
|
|
|
|
|
|
if config.AutoJoinRooms { |
|
|
if config.AutoJoinRooms { |
|
|
syncer.OnEventType(mevt.StateMember, func(event *mevt.Event) { |
|
|
|
|
|
|
|
|
syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, event *mevt.Event) { |
|
|
c.onRoomMemberEvent(client, event) |
|
|
c.onRoomMemberEvent(client, event) |
|
|
}) |
|
|
}) |
|
|
} |
|
|
} |
|
@ -398,5 +430,53 @@ func (c *Clients) newClient(config api.ClientConfig) (*mautrix.Client, error) { |
|
|
}() |
|
|
}() |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
return client, nil |
|
|
|
|
|
|
|
|
clientEntry.client = client |
|
|
|
|
|
|
|
|
|
|
|
gobStore, err := crypto.NewGobStore("crypto.gob") |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
return err |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
stateStore := StateStore{&nebStore.InMemoryStore} |
|
|
|
|
|
olmMachine := crypto.NewOlmMachine(client, CryptoMachineLogger{}, gobStore, &stateStore) |
|
|
|
|
|
olmMachine.Load() |
|
|
|
|
|
clientEntry.olmMachine = olmMachine |
|
|
|
|
|
|
|
|
|
|
|
syncer.OnSync(stateStore.UpdateStateStore) |
|
|
|
|
|
// Process sync response with olm machine
|
|
|
|
|
|
syncer.OnSync(func(resp *mautrix.RespSync, since string) bool { |
|
|
|
|
|
olmMachine.ProcessSyncResponse(resp, since) |
|
|
|
|
|
if err := olmMachine.CryptoStore.Flush(); err != nil { |
|
|
|
|
|
fmt.Println("cryptostore flush err", err) |
|
|
|
|
|
} |
|
|
|
|
|
return true |
|
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
syncer.OnEventType(mevt.StateMember, func(_ mautrix.EventSource, evt *mevt.Event) { |
|
|
|
|
|
olmMachine.HandleMemberEvent(evt) |
|
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
syncer.OnEventType(mevt.EventEncrypted, func(source mautrix.EventSource, evt *mevt.Event) { |
|
|
|
|
|
evt.Content.ParseRaw(mevt.EventEncrypted) |
|
|
|
|
|
evt, err := olmMachine.DecryptMegolmEvent(evt) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
fmt.Println("decryption err", err) |
|
|
|
|
|
} else { |
|
|
|
|
|
if evt.Type == mevt.EventMessage { |
|
|
|
|
|
err = evt.Content.ParseRaw(mevt.EventMessage) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
fmt.Println("parsing msg err", err) |
|
|
|
|
|
} else { |
|
|
|
|
|
c.onMessageEvent(client, evt) |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
fmt.Println("decrypted type", evt.Type) |
|
|
|
|
|
} |
|
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
// Ignore events before neb's join event.
|
|
|
|
|
|
eventIgnorer := mautrix.OldEventIgnorer{UserID: config.UserID} |
|
|
|
|
|
eventIgnorer.Register(syncer) |
|
|
|
|
|
|
|
|
|
|
|
return nil |
|
|
} |
|
|
} |