diff --git a/src/github.com/matrix-org/go-neb/realms/github/github.go b/src/github.com/matrix-org/go-neb/realms/github/github.go index 2ef951a..6305b90 100644 --- a/src/github.com/matrix-org/go-neb/realms/github/github.go +++ b/src/github.com/matrix-org/go-neb/realms/github/github.go @@ -7,6 +7,7 @@ import ( log "github.com/Sirupsen/logrus" "github.com/matrix-org/go-neb/database" "github.com/matrix-org/go-neb/types" + "io/ioutil" "net/http" "net/url" ) @@ -19,10 +20,11 @@ type githubRealm struct { } type githubSession struct { - State string - id string - userID string - realmID string + AccessToken string + Scopes string + id string + userID string + realmID string } func (s *githubSession) UserID() string { @@ -76,6 +78,7 @@ func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) int } func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) { + // parse out params from the request code := req.URL.Query().Get("code") state := req.URL.Query().Get("state") logger := log.WithFields(log.Fields{ @@ -83,19 +86,49 @@ func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request }) logger.WithField("code", code).Print("GithubRealm: OnReceiveRedirect") if code == "" || state == "" { - w.WriteHeader(400) - w.Write([]byte("code and state are required")) + failWith(logger, w, 400, "code and state are required", nil) return } // load the session (we keyed off the state param) session, err := database.GetServiceDB().LoadAuthSessionByID(r.ID(), state) if err != nil { - logger.WithError(err).Print("Failed to load session") - w.WriteHeader(400) - w.Write([]byte("Provided ?state= param is not recognised.")) // most likely cause + // most likely cause + failWith(logger, w, 400, "Provided ?state= param is not recognised.", err) return } - logger.WithField("user_id", session.UserID()).Print("Mapped redirect to user") + ghSession := session.(*githubSession) + logger.WithField("user_id", ghSession.UserID()).Print("Mapped redirect to user") + + // exchange code for access_token + res, err := http.PostForm("https://github.com/login/oauth/access_token", + url.Values{"client_id": {r.ClientID}, "client_secret": {r.ClientSecret}, "code": {code}}) + if err != nil { + failWith(logger, w, 502, "Failed to exchange code for token", err) + return + } + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + if err != nil { + failWith(logger, w, 502, "Failed to read token response", err) + return + } + vals, err := url.ParseQuery(string(body)) + if err != nil { + failWith(logger, w, 502, "Failed to parse token response", err) + return + } + + // update database and return + ghSession.AccessToken = vals.Get("access_token") + ghSession.Scopes = vals.Get("scope") + logger.WithField("scope", ghSession.Scopes).Print("Scopes granted.") + _, err = database.GetServiceDB().StoreAuthSession(ghSession) + if err != nil { + failWith(logger, w, 500, "Failed to persist session", err) + return + } + w.WriteHeader(200) + w.Write([]byte("OK!")) } func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession { @@ -106,6 +139,12 @@ func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession } } +func failWith(logger *log.Entry, w http.ResponseWriter, code int, msg string, err error) { + logger.WithError(err).Print(msg) + w.WriteHeader(code) + w.Write([]byte(msg)) +} + // Generate a cryptographically secure pseudorandom string with the given number of bytes (length). // Returns a hex string of the bytes. func randomString(length int) (string, error) {