Browse Source

Store access_tokens in the DB

kegan/github-auth
Kegan Dougal 9 years ago
parent
commit
ba803ccd00
  1. 59
      src/github.com/matrix-org/go-neb/realms/github/github.go

59
src/github.com/matrix-org/go-neb/realms/github/github.go

@ -7,6 +7,7 @@ import (
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/matrix-org/go-neb/database" "github.com/matrix-org/go-neb/database"
"github.com/matrix-org/go-neb/types" "github.com/matrix-org/go-neb/types"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
) )
@ -19,10 +20,11 @@ type githubRealm struct {
} }
type githubSession struct { type githubSession struct {
State string
id string
userID string
realmID string
AccessToken string
Scopes string
id string
userID string
realmID string
} }
func (s *githubSession) UserID() string { func (s *githubSession) UserID() string {
@ -76,6 +78,7 @@ func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) int
} }
func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) { func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) {
// parse out params from the request
code := req.URL.Query().Get("code") code := req.URL.Query().Get("code")
state := req.URL.Query().Get("state") state := req.URL.Query().Get("state")
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
@ -83,19 +86,49 @@ func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request
}) })
logger.WithField("code", code).Print("GithubRealm: OnReceiveRedirect") logger.WithField("code", code).Print("GithubRealm: OnReceiveRedirect")
if code == "" || state == "" { if code == "" || state == "" {
w.WriteHeader(400)
w.Write([]byte("code and state are required"))
failWith(logger, w, 400, "code and state are required", nil)
return return
} }
// load the session (we keyed off the state param) // load the session (we keyed off the state param)
session, err := database.GetServiceDB().LoadAuthSessionByID(r.ID(), state) session, err := database.GetServiceDB().LoadAuthSessionByID(r.ID(), state)
if err != nil { if err != nil {
logger.WithError(err).Print("Failed to load session")
w.WriteHeader(400)
w.Write([]byte("Provided ?state= param is not recognised.")) // most likely cause
// most likely cause
failWith(logger, w, 400, "Provided ?state= param is not recognised.", err)
return return
} }
logger.WithField("user_id", session.UserID()).Print("Mapped redirect to user")
ghSession := session.(*githubSession)
logger.WithField("user_id", ghSession.UserID()).Print("Mapped redirect to user")
// exchange code for access_token
res, err := http.PostForm("https://github.com/login/oauth/access_token",
url.Values{"client_id": {r.ClientID}, "client_secret": {r.ClientSecret}, "code": {code}})
if err != nil {
failWith(logger, w, 502, "Failed to exchange code for token", err)
return
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
failWith(logger, w, 502, "Failed to read token response", err)
return
}
vals, err := url.ParseQuery(string(body))
if err != nil {
failWith(logger, w, 502, "Failed to parse token response", err)
return
}
// update database and return
ghSession.AccessToken = vals.Get("access_token")
ghSession.Scopes = vals.Get("scope")
logger.WithField("scope", ghSession.Scopes).Print("Scopes granted.")
_, err = database.GetServiceDB().StoreAuthSession(ghSession)
if err != nil {
failWith(logger, w, 500, "Failed to persist session", err)
return
}
w.WriteHeader(200)
w.Write([]byte("OK!"))
} }
func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession { func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession {
@ -106,6 +139,12 @@ func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession
} }
} }
func failWith(logger *log.Entry, w http.ResponseWriter, code int, msg string, err error) {
logger.WithError(err).Print(msg)
w.WriteHeader(code)
w.Write([]byte(msg))
}
// Generate a cryptographically secure pseudorandom string with the given number of bytes (length). // Generate a cryptographically secure pseudorandom string with the given number of bytes (length).
// Returns a hex string of the bytes. // Returns a hex string of the bytes.
func randomString(length int) (string, error) { func randomString(length int) (string, error) {

Loading…
Cancel
Save