|
@ -7,6 +7,7 @@ import ( |
|
|
log "github.com/Sirupsen/logrus" |
|
|
log "github.com/Sirupsen/logrus" |
|
|
"github.com/matrix-org/go-neb/database" |
|
|
"github.com/matrix-org/go-neb/database" |
|
|
"github.com/matrix-org/go-neb/types" |
|
|
"github.com/matrix-org/go-neb/types" |
|
|
|
|
|
"io/ioutil" |
|
|
"net/http" |
|
|
"net/http" |
|
|
"net/url" |
|
|
"net/url" |
|
|
) |
|
|
) |
|
@ -19,10 +20,11 @@ type githubRealm struct { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
type githubSession struct { |
|
|
type githubSession struct { |
|
|
State string |
|
|
|
|
|
id string |
|
|
|
|
|
userID string |
|
|
|
|
|
realmID string |
|
|
|
|
|
|
|
|
AccessToken string |
|
|
|
|
|
Scopes string |
|
|
|
|
|
id string |
|
|
|
|
|
userID string |
|
|
|
|
|
realmID string |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func (s *githubSession) UserID() string { |
|
|
func (s *githubSession) UserID() string { |
|
@ -76,6 +78,7 @@ func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) int |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) { |
|
|
func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) { |
|
|
|
|
|
// parse out params from the request
|
|
|
code := req.URL.Query().Get("code") |
|
|
code := req.URL.Query().Get("code") |
|
|
state := req.URL.Query().Get("state") |
|
|
state := req.URL.Query().Get("state") |
|
|
logger := log.WithFields(log.Fields{ |
|
|
logger := log.WithFields(log.Fields{ |
|
@ -83,19 +86,49 @@ func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request |
|
|
}) |
|
|
}) |
|
|
logger.WithField("code", code).Print("GithubRealm: OnReceiveRedirect") |
|
|
logger.WithField("code", code).Print("GithubRealm: OnReceiveRedirect") |
|
|
if code == "" || state == "" { |
|
|
if code == "" || state == "" { |
|
|
w.WriteHeader(400) |
|
|
|
|
|
w.Write([]byte("code and state are required")) |
|
|
|
|
|
|
|
|
failWith(logger, w, 400, "code and state are required", nil) |
|
|
return |
|
|
return |
|
|
} |
|
|
} |
|
|
// load the session (we keyed off the state param)
|
|
|
// load the session (we keyed off the state param)
|
|
|
session, err := database.GetServiceDB().LoadAuthSessionByID(r.ID(), state) |
|
|
session, err := database.GetServiceDB().LoadAuthSessionByID(r.ID(), state) |
|
|
if err != nil { |
|
|
if err != nil { |
|
|
logger.WithError(err).Print("Failed to load session") |
|
|
|
|
|
w.WriteHeader(400) |
|
|
|
|
|
w.Write([]byte("Provided ?state= param is not recognised.")) // most likely cause
|
|
|
|
|
|
|
|
|
// most likely cause
|
|
|
|
|
|
failWith(logger, w, 400, "Provided ?state= param is not recognised.", err) |
|
|
return |
|
|
return |
|
|
} |
|
|
} |
|
|
logger.WithField("user_id", session.UserID()).Print("Mapped redirect to user") |
|
|
|
|
|
|
|
|
ghSession := session.(*githubSession) |
|
|
|
|
|
logger.WithField("user_id", ghSession.UserID()).Print("Mapped redirect to user") |
|
|
|
|
|
|
|
|
|
|
|
// exchange code for access_token
|
|
|
|
|
|
res, err := http.PostForm("https://github.com/login/oauth/access_token", |
|
|
|
|
|
url.Values{"client_id": {r.ClientID}, "client_secret": {r.ClientSecret}, "code": {code}}) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
failWith(logger, w, 502, "Failed to exchange code for token", err) |
|
|
|
|
|
return |
|
|
|
|
|
} |
|
|
|
|
|
defer res.Body.Close() |
|
|
|
|
|
body, err := ioutil.ReadAll(res.Body) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
failWith(logger, w, 502, "Failed to read token response", err) |
|
|
|
|
|
return |
|
|
|
|
|
} |
|
|
|
|
|
vals, err := url.ParseQuery(string(body)) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
failWith(logger, w, 502, "Failed to parse token response", err) |
|
|
|
|
|
return |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// update database and return
|
|
|
|
|
|
ghSession.AccessToken = vals.Get("access_token") |
|
|
|
|
|
ghSession.Scopes = vals.Get("scope") |
|
|
|
|
|
logger.WithField("scope", ghSession.Scopes).Print("Scopes granted.") |
|
|
|
|
|
_, err = database.GetServiceDB().StoreAuthSession(ghSession) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
failWith(logger, w, 500, "Failed to persist session", err) |
|
|
|
|
|
return |
|
|
|
|
|
} |
|
|
|
|
|
w.WriteHeader(200) |
|
|
|
|
|
w.Write([]byte("OK!")) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession { |
|
|
func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession { |
|
@ -106,6 +139,12 @@ func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func failWith(logger *log.Entry, w http.ResponseWriter, code int, msg string, err error) { |
|
|
|
|
|
logger.WithError(err).Print(msg) |
|
|
|
|
|
w.WriteHeader(code) |
|
|
|
|
|
w.Write([]byte(msg)) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// Generate a cryptographically secure pseudorandom string with the given number of bytes (length).
|
|
|
// Generate a cryptographically secure pseudorandom string with the given number of bytes (length).
|
|
|
// Returns a hex string of the bytes.
|
|
|
// Returns a hex string of the bytes.
|
|
|
func randomString(length int) (string, error) { |
|
|
func randomString(length int) (string, error) { |
|
|