From 71be73a721b619255a407f7a3762684de20e1978 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 5 Aug 2016 11:38:34 +0100 Subject: [PATCH] Add ID field to AuthSessions so redirects can lookup based off this key Required for Github OAuth redirect requests and just is generally useful to have. Add UNIQUE constraints on realm/user and realm/id to prevent multiple users getting the same ID. --- .../matrix-org/go-neb/database/db.go | 19 ++++-- .../matrix-org/go-neb/database/schema.go | 61 +++++++++++++++---- .../matrix-org/go-neb/realms/github/github.go | 25 ++++++-- .../matrix-org/go-neb/types/types.go | 3 +- 4 files changed, 85 insertions(+), 23 deletions(-) diff --git a/src/github.com/matrix-org/go-neb/database/db.go b/src/github.com/matrix-org/go-neb/database/db.go index de4741d..7e8fea2 100644 --- a/src/github.com/matrix-org/go-neb/database/db.go +++ b/src/github.com/matrix-org/go-neb/database/db.go @@ -206,7 +206,7 @@ func (d *ServiceDB) StoreAuthRealm(realm types.AuthRealm) (old types.AuthRealm, // The previous session, if any, is returned. func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthSession, err error) { err = runTransaction(d.db, func(txn *sql.Tx) error { - old, err = selectAuthSessionTxn(txn, session.RealmID(), session.UserID()) + old, err = selectAuthSessionByUserTxn(txn, session.RealmID(), session.UserID()) if err == sql.ErrNoRows { return insertAuthSessionTxn(txn, time.Now(), session) } else if err != nil { @@ -218,12 +218,23 @@ func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthS return } -// LoadAuthSessionForUser loads an AuthSession from the database based on the given +// LoadAuthSessionByUser loads an AuthSession from the database based on the given // realm and user ID. // Returns sql.ErrNoRows if the session isn't in the database. -func (d *ServiceDB) LoadAuthSessionForUser(realmID, userID string) (session types.AuthSession, err error) { +func (d *ServiceDB) LoadAuthSessionByUser(realmID, userID string) (session types.AuthSession, err error) { err = runTransaction(d.db, func(txn *sql.Tx) error { - session, err = selectAuthSessionTxn(txn, realmID, userID) + session, err = selectAuthSessionByUserTxn(txn, realmID, userID) + return err + }) + return +} + +// LoadAuthSessionByID loads an AuthSession from the database based on the given +// realm and session ID. +// Returns sql.ErrNoRows if the session isn't in the database. +func (d *ServiceDB) LoadAuthSessionByID(realmID, sessionID string) (session types.AuthSession, err error) { + err = runTransaction(d.db, func(txn *sql.Tx) error { + session, err = selectAuthSessionByIDTxn(txn, realmID, sessionID) return err }) return diff --git a/src/github.com/matrix-org/go-neb/database/schema.go b/src/github.com/matrix-org/go-neb/database/schema.go index 3ec01e6..64f5812 100644 --- a/src/github.com/matrix-org/go-neb/database/schema.go +++ b/src/github.com/matrix-org/go-neb/database/schema.go @@ -45,12 +45,14 @@ CREATE TABLE IF NOT EXISTS auth_realms ( ); CREATE TABLE IF NOT EXISTS auth_sessions ( + id TEXT NOT NULL, realm_id TEXT NOT NULL, user_id TEXT NOT NULL, session_json TEXT NOT NULL, time_added_ms BIGINT NOT NULL, time_updated_ms BIGINT NOT NULL, - UNIQUE(realm_id, user_id) + UNIQUE(realm_id, user_id), + UNIQUE(realm_id, id) ); ` @@ -280,8 +282,8 @@ func updateRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error { const insertAuthSessionSQL = ` INSERT INTO auth_sessions( - realm_id, user_id, session_json, time_added_ms, time_updated_ms -) VALUES ($1, $2, $3, $4, $5) + id, realm_id, user_id, session_json, time_added_ms, time_updated_ms +) VALUES ($1, $2, $3, $4, $5, $6) ` func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error { @@ -292,23 +294,56 @@ func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) t := now.UnixNano() / 1000000 _, err = txn.Exec( insertAuthSessionSQL, - session.RealmID(), session.UserID(), sessionJSON, t, t, + session.ID(), session.RealmID(), session.UserID(), sessionJSON, t, t, ) return err } -const selectAuthSessionSQL = ` -SELECT realm_type, realm_json, session_json FROM auth_sessions +const selectAuthSessionByUserSQL = ` +SELECT id, realm_type, realm_json, session_json FROM auth_sessions JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id WHERE auth_sessions.realm_id = $1 AND auth_sessions.user_id = $2 ` -func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) { +func selectAuthSessionByUserTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) { + var id string + var realmType string + var realmJSON []byte + var sessionJSON []byte + row := txn.QueryRow(selectAuthSessionByUserSQL, realmID, userID) + if err := row.Scan(&id, &realmType, &realmJSON, &sessionJSON); err != nil { + return nil, err + } + realm := types.CreateAuthRealm(realmID, realmType) + if realm == nil { + return nil, fmt.Errorf("Cannot create realm of type %s", realmType) + } + if err := json.Unmarshal(realmJSON, realm); err != nil { + return nil, err + } + session := realm.AuthSession(id, userID, realmID) + if session == nil { + return nil, fmt.Errorf("Cannot create session for given realm") + } + if err := json.Unmarshal(sessionJSON, session); err != nil { + return nil, err + } + return session, nil +} + +const selectAuthSessionByIDSQL = ` +SELECT user_id, realm_type, realm_json, session_json FROM auth_sessions + JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id + WHERE auth_sessions.realm_id = $1 AND auth_sessions.id = $2 +` + +func selectAuthSessionByIDTxn(txn *sql.Tx, realmID, id string) (types.AuthSession, error) { + var userID string var realmType string var realmJSON []byte var sessionJSON []byte - row := txn.QueryRow(selectAuthSessionSQL, realmID, userID) - if err := row.Scan(&realmType, &realmJSON, &sessionJSON); err != nil { + row := txn.QueryRow(selectAuthSessionByIDSQL, realmID, id) + if err := row.Scan(&userID, &realmType, &realmJSON, &sessionJSON); err != nil { return nil, err } realm := types.CreateAuthRealm(realmID, realmType) @@ -318,7 +353,7 @@ func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSessio if err := json.Unmarshal(realmJSON, realm); err != nil { return nil, err } - session := realm.AuthSession(userID, realmID) + session := realm.AuthSession(id, userID, realmID) if session == nil { return nil, fmt.Errorf("Cannot create session for given realm") } @@ -329,8 +364,8 @@ func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSessio } const updateAuthSessionSQL = ` -UPDATE auth_sessions SET session_json=$1, time_updated_ms=$2 - WHERE realm_id=$3 AND user_id=$4 +UPDATE auth_sessions SET id=$1, session_json=$2, time_updated_ms=$3 + WHERE realm_id=$4 AND user_id=$5 ` func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error { @@ -340,7 +375,7 @@ func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) } t := now.UnixNano() / 1000000 _, err = txn.Exec( - updateAuthSessionSQL, sessionJSON, t, + updateAuthSessionSQL, session.ID(), sessionJSON, t, session.RealmID(), session.UserID(), ) return err diff --git a/src/github.com/matrix-org/go-neb/realms/github/github.go b/src/github.com/matrix-org/go-neb/realms/github/github.go index b763768..2ef951a 100644 --- a/src/github.com/matrix-org/go-neb/realms/github/github.go +++ b/src/github.com/matrix-org/go-neb/realms/github/github.go @@ -20,6 +20,7 @@ type githubRealm struct { type githubSession struct { State string + id string userID string realmID string } @@ -32,6 +33,10 @@ func (s *githubSession) RealmID() string { return s.realmID } +func (s *githubSession) ID() string { + return s.id +} + func (r *githubRealm) ID() string { return r.id } @@ -55,7 +60,7 @@ func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) int q.Set("redirect_uri", r.RedirectBaseURI+"/realms/redirects/"+r.ID()) u.RawQuery = q.Encode() session := &githubSession{ - State: state, + id: state, // key off the state for redirects userID: userID, realmID: r.ID(), } @@ -73,19 +78,29 @@ func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) int func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) { code := req.URL.Query().Get("code") state := req.URL.Query().Get("state") - log.WithFields(log.Fields{ - "code": code, + logger := log.WithFields(log.Fields{ "state": state, - }).Print("GithubRealm: OnReceiveRedirect") + }) + logger.WithField("code", code).Print("GithubRealm: OnReceiveRedirect") if code == "" || state == "" { w.WriteHeader(400) w.Write([]byte("code and state are required")) return } + // load the session (we keyed off the state param) + session, err := database.GetServiceDB().LoadAuthSessionByID(r.ID(), state) + if err != nil { + logger.WithError(err).Print("Failed to load session") + w.WriteHeader(400) + w.Write([]byte("Provided ?state= param is not recognised.")) // most likely cause + return + } + logger.WithField("user_id", session.UserID()).Print("Mapped redirect to user") } -func (r *githubRealm) AuthSession(userID, realmID string) types.AuthSession { +func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession { return &githubSession{ + id: id, userID: userID, realmID: realmID, } diff --git a/src/github.com/matrix-org/go-neb/types/types.go b/src/github.com/matrix-org/go-neb/types/types.go index f72e110..fe6c453 100644 --- a/src/github.com/matrix-org/go-neb/types/types.go +++ b/src/github.com/matrix-org/go-neb/types/types.go @@ -60,7 +60,7 @@ type AuthRealm interface { ID() string Type() string OnReceiveRedirect(w http.ResponseWriter, req *http.Request) - AuthSession(userID, realmID string) AuthSession + AuthSession(id, userID, realmID string) AuthSession RequestAuthSession(userID string, config json.RawMessage) interface{} } @@ -84,6 +84,7 @@ func CreateAuthRealm(realmID, realmType string) AuthRealm { // AuthSession represents a single authentication session between a user and // an auth realm. type AuthSession interface { + ID() string UserID() string RealmID() string }