diff --git a/src/github.com/matrix-org/go-neb/api.go b/src/github.com/matrix-org/go-neb/api.go index cc419b2..084ea43 100644 --- a/src/github.com/matrix-org/go-neb/api.go +++ b/src/github.com/matrix-org/go-neb/api.go @@ -51,6 +51,23 @@ func (h *requestAuthSessionHandler) OnIncomingRequest(req *http.Request) (interf return response, nil } +type realmRedirectHandler struct { + db *database.ServiceDB +} + +func (rh *realmRedirectHandler) handle(w http.ResponseWriter, req *http.Request) { + segments := strings.Split(req.URL.Path, "/") + // last path segment is the realm ID which we will pass the incoming request to + realmID := segments[len(segments)-1] + realm, err := rh.db.LoadAuthRealm(realmID) + if err != nil { + log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") + w.WriteHeader(404) + return + } + realm.OnReceiveRedirect(w, req) +} + type configureAuthRealmHandler struct { db *database.ServiceDB } diff --git a/src/github.com/matrix-org/go-neb/database/db.go b/src/github.com/matrix-org/go-neb/database/db.go index 9ade230..7e8fea2 100644 --- a/src/github.com/matrix-org/go-neb/database/db.go +++ b/src/github.com/matrix-org/go-neb/database/db.go @@ -206,7 +206,7 @@ func (d *ServiceDB) StoreAuthRealm(realm types.AuthRealm) (old types.AuthRealm, // The previous session, if any, is returned. func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthSession, err error) { err = runTransaction(d.db, func(txn *sql.Tx) error { - old, err = selectAuthSessionTxn(txn, session.RealmID(), session.UserID()) + old, err = selectAuthSessionByUserTxn(txn, session.RealmID(), session.UserID()) if err == sql.ErrNoRows { return insertAuthSessionTxn(txn, time.Now(), session) } else if err != nil { @@ -218,11 +218,23 @@ func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthS return } -// LoadAuthSession loads an AuthSession from the database. +// LoadAuthSessionByUser loads an AuthSession from the database based on the given +// realm and user ID. // Returns sql.ErrNoRows if the session isn't in the database. -func (d *ServiceDB) LoadAuthSession(realmID, userID string) (session types.AuthSession, err error) { +func (d *ServiceDB) LoadAuthSessionByUser(realmID, userID string) (session types.AuthSession, err error) { err = runTransaction(d.db, func(txn *sql.Tx) error { - session, err = selectAuthSessionTxn(txn, realmID, userID) + session, err = selectAuthSessionByUserTxn(txn, realmID, userID) + return err + }) + return +} + +// LoadAuthSessionByID loads an AuthSession from the database based on the given +// realm and session ID. +// Returns sql.ErrNoRows if the session isn't in the database. +func (d *ServiceDB) LoadAuthSessionByID(realmID, sessionID string) (session types.AuthSession, err error) { + err = runTransaction(d.db, func(txn *sql.Tx) error { + session, err = selectAuthSessionByIDTxn(txn, realmID, sessionID) return err }) return diff --git a/src/github.com/matrix-org/go-neb/database/schema.go b/src/github.com/matrix-org/go-neb/database/schema.go index 3ec01e6..8b16c43 100644 --- a/src/github.com/matrix-org/go-neb/database/schema.go +++ b/src/github.com/matrix-org/go-neb/database/schema.go @@ -45,12 +45,14 @@ CREATE TABLE IF NOT EXISTS auth_realms ( ); CREATE TABLE IF NOT EXISTS auth_sessions ( + session_id TEXT NOT NULL, realm_id TEXT NOT NULL, user_id TEXT NOT NULL, session_json TEXT NOT NULL, time_added_ms BIGINT NOT NULL, time_updated_ms BIGINT NOT NULL, - UNIQUE(realm_id, user_id) + UNIQUE(realm_id, user_id), + UNIQUE(realm_id, session_id) ); ` @@ -280,8 +282,8 @@ func updateRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error { const insertAuthSessionSQL = ` INSERT INTO auth_sessions( - realm_id, user_id, session_json, time_added_ms, time_updated_ms -) VALUES ($1, $2, $3, $4, $5) + session_id, realm_id, user_id, session_json, time_added_ms, time_updated_ms +) VALUES ($1, $2, $3, $4, $5, $6) ` func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error { @@ -292,23 +294,56 @@ func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) t := now.UnixNano() / 1000000 _, err = txn.Exec( insertAuthSessionSQL, - session.RealmID(), session.UserID(), sessionJSON, t, t, + session.ID(), session.RealmID(), session.UserID(), sessionJSON, t, t, ) return err } -const selectAuthSessionSQL = ` -SELECT realm_type, realm_json, session_json FROM auth_sessions +const selectAuthSessionByUserSQL = ` +SELECT session_id, realm_type, realm_json, session_json FROM auth_sessions JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id WHERE auth_sessions.realm_id = $1 AND auth_sessions.user_id = $2 ` -func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) { +func selectAuthSessionByUserTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) { + var id string + var realmType string + var realmJSON []byte + var sessionJSON []byte + row := txn.QueryRow(selectAuthSessionByUserSQL, realmID, userID) + if err := row.Scan(&id, &realmType, &realmJSON, &sessionJSON); err != nil { + return nil, err + } + realm := types.CreateAuthRealm(realmID, realmType) + if realm == nil { + return nil, fmt.Errorf("Cannot create realm of type %s", realmType) + } + if err := json.Unmarshal(realmJSON, realm); err != nil { + return nil, err + } + session := realm.AuthSession(id, userID, realmID) + if session == nil { + return nil, fmt.Errorf("Cannot create session for given realm") + } + if err := json.Unmarshal(sessionJSON, session); err != nil { + return nil, err + } + return session, nil +} + +const selectAuthSessionByIDSQL = ` +SELECT user_id, realm_type, realm_json, session_json FROM auth_sessions + JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id + WHERE auth_sessions.realm_id = $1 AND auth_sessions.session_id = $2 +` + +func selectAuthSessionByIDTxn(txn *sql.Tx, realmID, id string) (types.AuthSession, error) { + var userID string var realmType string var realmJSON []byte var sessionJSON []byte - row := txn.QueryRow(selectAuthSessionSQL, realmID, userID) - if err := row.Scan(&realmType, &realmJSON, &sessionJSON); err != nil { + row := txn.QueryRow(selectAuthSessionByIDSQL, realmID, id) + if err := row.Scan(&userID, &realmType, &realmJSON, &sessionJSON); err != nil { return nil, err } realm := types.CreateAuthRealm(realmID, realmType) @@ -318,7 +353,7 @@ func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSessio if err := json.Unmarshal(realmJSON, realm); err != nil { return nil, err } - session := realm.AuthSession(userID, realmID) + session := realm.AuthSession(id, userID, realmID) if session == nil { return nil, fmt.Errorf("Cannot create session for given realm") } @@ -329,8 +364,8 @@ func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSessio } const updateAuthSessionSQL = ` -UPDATE auth_sessions SET session_json=$1, time_updated_ms=$2 - WHERE realm_id=$3 AND user_id=$4 +UPDATE auth_sessions SET session_id=$1, session_json=$2, time_updated_ms=$3 + WHERE realm_id=$4 AND user_id=$5 ` func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error { @@ -340,7 +375,7 @@ func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) } t := now.UnixNano() / 1000000 _, err = txn.Exec( - updateAuthSessionSQL, sessionJSON, t, + updateAuthSessionSQL, session.ID(), sessionJSON, t, session.RealmID(), session.UserID(), ) return err diff --git a/src/github.com/matrix-org/go-neb/goneb.go b/src/github.com/matrix-org/go-neb/goneb.go index ca066a9..692d2e2 100644 --- a/src/github.com/matrix-org/go-neb/goneb.go +++ b/src/github.com/matrix-org/go-neb/goneb.go @@ -37,6 +37,8 @@ func main() { http.Handle("/admin/requestAuthSession", server.MakeJSONAPI(&requestAuthSessionHandler{db: db})) wh := &webhookHandler{db: db, clients: clients} http.HandleFunc("/services/hooks/", wh.handle) + rh := &realmRedirectHandler{db: db} + http.HandleFunc("/realms/redirects/", rh.handle) http.ListenAndServe(bindAddress, nil) } diff --git a/src/github.com/matrix-org/go-neb/realms/github/github.go b/src/github.com/matrix-org/go-neb/realms/github/github.go index 552e059..b928282 100644 --- a/src/github.com/matrix-org/go-neb/realms/github/github.go +++ b/src/github.com/matrix-org/go-neb/realms/github/github.go @@ -1,10 +1,14 @@ package realms import ( + "crypto/rand" + "encoding/hex" "encoding/json" log "github.com/Sirupsen/logrus" "github.com/matrix-org/go-neb/database" "github.com/matrix-org/go-neb/types" + "io/ioutil" + "net/http" "net/url" ) @@ -12,13 +16,15 @@ type githubRealm struct { id string ClientSecret string ClientID string - WebhookEndpoint string + RedirectBaseURI string } type githubSession struct { - State string - userID string - realmID string + AccessToken string + Scopes string + id string + userID string + realmID string } func (s *githubSession) UserID() string { @@ -29,6 +35,10 @@ func (s *githubSession) RealmID() string { return s.realmID } +func (s *githubSession) ID() string { + return s.id +} + func (r *githubRealm) ID() string { return r.id } @@ -38,18 +48,25 @@ func (r *githubRealm) Type() string { } func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) interface{} { + state, err := randomString(10) + if err != nil { + log.WithError(err).Print("Failed to generate state param") + return nil + } u, _ := url.Parse("https://github.com/login/oauth/authorize") q := u.Query() q.Set("client_id", r.ClientID) q.Set("client_secret", r.ClientSecret) - // TODO: state, scope + q.Set("state", state) + // TODO: Path is from goneb.go - we should probably factor it out. + q.Set("redirect_uri", r.RedirectBaseURI+"/realms/redirects/"+r.ID()) u.RawQuery = q.Encode() session := &githubSession{ - State: "TODO", + id: state, // key off the state for redirects userID: userID, realmID: r.ID(), } - _, err := database.GetServiceDB().StoreAuthSession(session) + _, err = database.GetServiceDB().StoreAuthSession(session) if err != nil { log.WithError(err).Print("Failed to store new auth session") return nil @@ -60,13 +77,89 @@ func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) int }{u.String()} } -func (r *githubRealm) AuthSession(userID, realmID string) types.AuthSession { +func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) { + // parse out params from the request + code := req.URL.Query().Get("code") + state := req.URL.Query().Get("state") + logger := log.WithFields(log.Fields{ + "state": state, + }) + logger.WithField("code", code).Print("GithubRealm: OnReceiveRedirect") + if code == "" || state == "" { + failWith(logger, w, 400, "code and state are required", nil) + return + } + // load the session (we keyed off the state param) + session, err := database.GetServiceDB().LoadAuthSessionByID(r.ID(), state) + if err != nil { + // most likely cause + failWith(logger, w, 400, "Provided ?state= param is not recognised.", err) + return + } + ghSession, ok := session.(*githubSession) + if !ok { + failWith(logger, w, 500, "Unexpected session found.", nil) + return + } + logger.WithField("user_id", ghSession.UserID()).Print("Mapped redirect to user") + + // exchange code for access_token + res, err := http.PostForm("https://github.com/login/oauth/access_token", + url.Values{"client_id": {r.ClientID}, "client_secret": {r.ClientSecret}, "code": {code}}) + if err != nil { + failWith(logger, w, 502, "Failed to exchange code for token", err) + return + } + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + if err != nil { + failWith(logger, w, 502, "Failed to read token response", err) + return + } + vals, err := url.ParseQuery(string(body)) + if err != nil { + failWith(logger, w, 502, "Failed to parse token response", err) + return + } + + // update database and return + ghSession.AccessToken = vals.Get("access_token") + ghSession.Scopes = vals.Get("scope") + logger.WithField("scope", ghSession.Scopes).Print("Scopes granted.") + _, err = database.GetServiceDB().StoreAuthSession(ghSession) + if err != nil { + failWith(logger, w, 500, "Failed to persist session", err) + return + } + w.WriteHeader(200) + w.Write([]byte("OK!")) +} + +func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession { return &githubSession{ + id: id, userID: userID, realmID: realmID, } } +func failWith(logger *log.Entry, w http.ResponseWriter, code int, msg string, err error) { + logger.WithError(err).Print(msg) + w.WriteHeader(code) + w.Write([]byte(msg)) +} + +// Generate a cryptographically secure pseudorandom string with the given number of bytes (length). +// Returns a hex string of the bytes. +func randomString(length int) (string, error) { + b := make([]byte, length) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + func init() { types.RegisterAuthRealm(func(realmID string) types.AuthRealm { return &githubRealm{id: realmID} diff --git a/src/github.com/matrix-org/go-neb/types/types.go b/src/github.com/matrix-org/go-neb/types/types.go index 47c801c..fe6c453 100644 --- a/src/github.com/matrix-org/go-neb/types/types.go +++ b/src/github.com/matrix-org/go-neb/types/types.go @@ -59,7 +59,8 @@ func CreateService(serviceID, serviceType string) Service { type AuthRealm interface { ID() string Type() string - AuthSession(userID, realmID string) AuthSession + OnReceiveRedirect(w http.ResponseWriter, req *http.Request) + AuthSession(id, userID, realmID string) AuthSession RequestAuthSession(userID string, config json.RawMessage) interface{} } @@ -83,6 +84,7 @@ func CreateAuthRealm(realmID, realmType string) AuthRealm { // AuthSession represents a single authentication session between a user and // an auth realm. type AuthSession interface { + ID() string UserID() string RealmID() string }