You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

153 lines
4.4 KiB

8 years ago
8 years ago
  1. package clients
  2. import (
  3. "fmt"
  4. "net/http"
  5. "reflect"
  6. "sync"
  7. "testing"
  8. "time"
  9. "github.com/matrix-org/go-neb/database"
  10. "github.com/matrix-org/go-neb/types"
  11. "maunium.net/go/mautrix"
  12. "maunium.net/go/mautrix/crypto"
  13. mevt "maunium.net/go/mautrix/event"
  14. "maunium.net/go/mautrix/id"
  15. )
  // commandParseTests is the table that TestCommandParsing iterates over.
  // Each entry pairs a raw "!test ..." message body with the argument list
  // the "test" command is expected to receive for it. Quoted phrases (ASCII
  // single/double quotes as well as typographic "smart" quotes) are expected
  // to arrive as one argument each.
  var commandParseTests = []struct {
  	body       string
  	expectArgs []string
  }{
  	{body: "!test word", expectArgs: []string{"word"}},
  	{body: "!test two words", expectArgs: []string{"two", "words"}},
  	{body: `!test "words with double quotes"`, expectArgs: []string{"words with double quotes"}},
  	{body: "!test 'words with single quotes'", expectArgs: []string{"words with single quotes"}},
  	{body: `!test 'single quotes' "double quotes"`, expectArgs: []string{"single quotes", "double quotes"}},
  	{body: `!test ‘smart single quotes’ “smart double quotes”`, expectArgs: []string{"smart single quotes", "smart double quotes"}},
  }
  27. type MockService struct {
  28. types.DefaultService
  29. commands []types.Command
  30. }
  31. func (s *MockService) Commands(cli types.MatrixClient) []types.Command {
  32. return s.commands
  33. }
  34. type MockStore struct {
  35. database.NopStorage
  36. service types.Service
  37. }
  38. func (d *MockStore) LoadServicesForUser(userID id.UserID) ([]types.Service, error) {
  39. return []types.Service{d.service}, nil
  40. }
  // MockTransport is an http.RoundTripper whose behaviour is supplied by the
  // test through the roundTrip callback.
  type MockTransport struct {
  	roundTrip func(*http.Request) (*http.Response, error)
  }

  // Compile-time check that MockTransport can be used as an http.RoundTripper.
  var _ http.RoundTripper = MockTransport{}

  // RoundTrip hands req to the roundTrip callback and returns its results
  // unchanged. Like any call through a nil func value, it panics if roundTrip
  // was never set.
  func (mt MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  	handler := mt.roundTrip
  	return handler(req)
  }
  47. func TestCommandParsing(t *testing.T) {
  48. var executedCmdArgs []string
  49. cmds := []types.Command{
  50. types.Command{
  51. Path: []string{"test"},
  52. Command: func(roomID id.RoomID, userID id.UserID, args []string) (interface{}, error) {
  53. executedCmdArgs = args
  54. return nil, nil
  55. },
  56. },
  57. }
  58. s := MockService{commands: cmds}
  59. store := MockStore{service: &s}
  60. database.SetServiceDB(&store)
  61. trans := struct{ MockTransport }{}
  62. trans.roundTrip = func(*http.Request) (*http.Response, error) {
  63. return nil, fmt.Errorf("unhandled test path")
  64. }
  65. cli := &http.Client{
  66. Transport: trans,
  67. }
  68. clients := New(&store, cli)
  69. mxCli, _ := mautrix.NewClient("https://someplace.somewhere", "@service:user", "token")
  70. mxCli.Client = cli
  71. botClient := BotClient{Client: mxCli}
  72. for _, input := range commandParseTests {
  73. executedCmdArgs = []string{}
  74. content := mevt.Content{Raw: map[string]interface{}{
  75. "body": input.body,
  76. "msgtype": "m.text",
  77. }}
  78. if veryRaw, err := content.MarshalJSON(); err != nil {
  79. t.Errorf("Error marshalling JSON: %s", err)
  80. } else {
  81. content.VeryRaw = veryRaw
  82. }
  83. content.ParseRaw(mevt.EventMessage)
  84. event := mevt.Event{
  85. Type: mevt.EventMessage,
  86. Sender: "@someone:somewhere",
  87. RoomID: "!foo:bar",
  88. Content: content,
  89. }
  90. clients.onMessageEvent(&botClient, &event)
  91. if !reflect.DeepEqual(executedCmdArgs, input.expectArgs) {
  92. t.Errorf("TestCommandParsing want %s, got %s", input.expectArgs, executedCmdArgs)
  93. }
  94. }
  95. }
  96. func TestSASVerificationHandling(t *testing.T) {
  97. botClient := BotClient{verificationSAS: &sync.Map{}}
  98. botClient.olmMachine = &crypto.OlmMachine{
  99. DefaultSASTimeout: time.Minute,
  100. }
  101. otherUserID := id.UserID("otherUser")
  102. otherDeviceID := id.DeviceID("otherDevice")
  103. otherDevice := &crypto.DeviceIdentity{
  104. UserID: otherUserID,
  105. DeviceID: otherDeviceID,
  106. }
  107. botClient.SubmitDecimalSAS(otherUserID, otherDeviceID, crypto.DecimalSASData([3]uint{4, 5, 6}))
  108. matched := botClient.VerifySASMatch(otherDevice, crypto.DecimalSASData([3]uint{1, 2, 3}))
  109. if matched {
  110. t.Error("SAS matched when they shouldn't have")
  111. }
  112. botClient.SubmitDecimalSAS(otherUserID, otherDeviceID, crypto.DecimalSASData([3]uint{1, 2, 3}))
  113. matched = botClient.VerifySASMatch(otherDevice, crypto.DecimalSASData([3]uint{1, 2, 3}))
  114. if !matched {
  115. t.Error("Expected SAS to match but they didn't")
  116. }
  117. botClient.SubmitDecimalSAS(otherUserID+"wrong", otherDeviceID, crypto.DecimalSASData([3]uint{4, 5, 6}))
  118. finished := make(chan bool)
  119. go func() {
  120. matched := botClient.VerifySASMatch(otherDevice, crypto.DecimalSASData([3]uint{1, 2, 3}))
  121. finished <- true
  122. if !matched {
  123. t.Error("SAS didn't match when it should have (receiving SAS after calling verification func)")
  124. }
  125. }()
  126. select {
  127. case <-finished:
  128. t.Error("Verification finished before receiving the SAS from the correct user")
  129. default:
  130. }
  131. botClient.SubmitDecimalSAS(otherUserID, otherDeviceID, crypto.DecimalSASData([3]uint{1, 2, 3}))
  132. select {
  133. case <-finished:
  134. case <-time.After(10 * time.Second):
  135. t.Error("Verification did not finish after receiving the SAS from the correct user")
  136. }
  137. }