diff --git a/src/github.com/matrix-org/go-neb/database/db.go b/src/github.com/matrix-org/go-neb/database/db.go index de4741d..7e8fea2 100644 --- a/src/github.com/matrix-org/go-neb/database/db.go +++ b/src/github.com/matrix-org/go-neb/database/db.go @@ -206,7 +206,7 @@ func (d *ServiceDB) StoreAuthRealm(realm types.AuthRealm) (old types.AuthRealm, // The previous session, if any, is returned. func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthSession, err error) { err = runTransaction(d.db, func(txn *sql.Tx) error { - old, err = selectAuthSessionTxn(txn, session.RealmID(), session.UserID()) + old, err = selectAuthSessionByUserTxn(txn, session.RealmID(), session.UserID()) if err == sql.ErrNoRows { return insertAuthSessionTxn(txn, time.Now(), session) } else if err != nil { @@ -218,12 +218,23 @@ func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthS return } -// LoadAuthSessionForUser loads an AuthSession from the database based on the given +// LoadAuthSessionByUser loads an AuthSession from the database based on the given // realm and user ID. // Returns sql.ErrNoRows if the session isn't in the database. -func (d *ServiceDB) LoadAuthSessionForUser(realmID, userID string) (session types.AuthSession, err error) { +func (d *ServiceDB) LoadAuthSessionByUser(realmID, userID string) (session types.AuthSession, err error) { err = runTransaction(d.db, func(txn *sql.Tx) error { - session, err = selectAuthSessionTxn(txn, realmID, userID) + session, err = selectAuthSessionByUserTxn(txn, realmID, userID) + return err + }) + return +} + +// LoadAuthSessionByID loads an AuthSession from the database based on the given +// realm and session ID. +// Returns sql.ErrNoRows if the session isn't in the database. +func (d *ServiceDB) LoadAuthSessionByID(realmID, sessionID string) (session types.AuthSession, err error) { + err = runTransaction(d.db, func(txn *sql.Tx) error { + session, err = selectAuthSessionByIDTxn(txn, realmID, sessionID) return err }) return diff --git a/src/github.com/matrix-org/go-neb/database/schema.go b/src/github.com/matrix-org/go-neb/database/schema.go index 3ec01e6..64f5812 100644 --- a/src/github.com/matrix-org/go-neb/database/schema.go +++ b/src/github.com/matrix-org/go-neb/database/schema.go @@ -45,12 +45,14 @@ CREATE TABLE IF NOT EXISTS auth_realms ( ); CREATE TABLE IF NOT EXISTS auth_sessions ( + id TEXT NOT NULL, realm_id TEXT NOT NULL, user_id TEXT NOT NULL, session_json TEXT NOT NULL, time_added_ms BIGINT NOT NULL, time_updated_ms BIGINT NOT NULL, - UNIQUE(realm_id, user_id) + UNIQUE(realm_id, user_id), + UNIQUE(realm_id, id) ); ` @@ -280,8 +282,8 @@ func updateRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error { const insertAuthSessionSQL = ` INSERT INTO auth_sessions( - realm_id, user_id, session_json, time_added_ms, time_updated_ms -) VALUES ($1, $2, $3, $4, $5) + id, realm_id, user_id, session_json, time_added_ms, time_updated_ms +) VALUES ($1, $2, $3, $4, $5, $6) ` func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error { @@ -292,23 +294,56 @@ func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) t := now.UnixNano() / 1000000 _, err = txn.Exec( insertAuthSessionSQL, - session.RealmID(), session.UserID(), sessionJSON, t, t, + session.ID(), session.RealmID(), session.UserID(), sessionJSON, t, t, ) return err } -const selectAuthSessionSQL = ` -SELECT realm_type, realm_json, session_json FROM auth_sessions +const selectAuthSessionByUserSQL = ` +SELECT id, realm_type, realm_json, session_json FROM auth_sessions JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id WHERE auth_sessions.realm_id = $1 AND auth_sessions.user_id = $2 ` -func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) { +func selectAuthSessionByUserTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) { + var id string + var realmType string + var realmJSON []byte + var sessionJSON []byte + row := txn.QueryRow(selectAuthSessionByUserSQL, realmID, userID) + if err := row.Scan(&id, &realmType, &realmJSON, &sessionJSON); err != nil { + return nil, err + } + realm := types.CreateAuthRealm(realmID, realmType) + if realm == nil { + return nil, fmt.Errorf("Cannot create realm of type %s", realmType) + } + if err := json.Unmarshal(realmJSON, realm); err != nil { + return nil, err + } + session := realm.AuthSession(id, userID, realmID) + if session == nil { + return nil, fmt.Errorf("Cannot create session for given realm") + } + if err := json.Unmarshal(sessionJSON, session); err != nil { + return nil, err + } + return session, nil +} + +const selectAuthSessionByIDSQL = ` +SELECT user_id, realm_type, realm_json, session_json FROM auth_sessions + JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id + WHERE auth_sessions.realm_id = $1 AND auth_sessions.id = $2 +` + +func selectAuthSessionByIDTxn(txn *sql.Tx, realmID, id string) (types.AuthSession, error) { + var userID string var realmType string var realmJSON []byte var sessionJSON []byte - row := txn.QueryRow(selectAuthSessionSQL, realmID, userID) - if err := row.Scan(&realmType, &realmJSON, &sessionJSON); err != nil { + row := txn.QueryRow(selectAuthSessionByIDSQL, realmID, id) + if err := row.Scan(&userID, &realmType, &realmJSON, &sessionJSON); err != nil { return nil, err } realm := types.CreateAuthRealm(realmID, realmType) @@ -318,7 +353,7 @@ func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSessio if err := json.Unmarshal(realmJSON, realm); err != nil { return nil, err } - session := realm.AuthSession(userID, realmID) + session := realm.AuthSession(id, userID, realmID) if session == nil { return nil, fmt.Errorf("Cannot create session for given realm") } @@ -329,8 +364,8 @@ func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSessio } const updateAuthSessionSQL = ` -UPDATE auth_sessions SET session_json=$1, time_updated_ms=$2 - WHERE realm_id=$3 AND user_id=$4 +UPDATE auth_sessions SET id=$1, session_json=$2, time_updated_ms=$3 + WHERE realm_id=$4 AND user_id=$5 ` func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error { @@ -340,7 +375,7 @@ func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) } t := now.UnixNano() / 1000000 _, err = txn.Exec( - updateAuthSessionSQL, sessionJSON, t, + updateAuthSessionSQL, session.ID(), sessionJSON, t, session.RealmID(), session.UserID(), ) return err diff --git a/src/github.com/matrix-org/go-neb/realms/github/github.go b/src/github.com/matrix-org/go-neb/realms/github/github.go index b763768..2ef951a 100644 --- a/src/github.com/matrix-org/go-neb/realms/github/github.go +++ b/src/github.com/matrix-org/go-neb/realms/github/github.go @@ -20,6 +20,7 @@ type githubRealm struct { type githubSession struct { State string + id string userID string realmID string } @@ -32,6 +33,10 @@ func (s *githubSession) RealmID() string { return s.realmID } +func (s *githubSession) ID() string { + return s.id +} + func (r *githubRealm) ID() string { return r.id } @@ -55,7 +60,7 @@ func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) int q.Set("redirect_uri", r.RedirectBaseURI+"/realms/redirects/"+r.ID()) u.RawQuery = q.Encode() session := &githubSession{ - State: state, + id: state, // key off the state for redirects userID: userID, realmID: r.ID(), } @@ -73,19 +78,29 @@ func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) int func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) { code := req.URL.Query().Get("code") state := req.URL.Query().Get("state") - log.WithFields(log.Fields{ - "code": code, + logger := log.WithFields(log.Fields{ "state": state, - }).Print("GithubRealm: OnReceiveRedirect") + }) + logger.WithField("code", code).Print("GithubRealm: OnReceiveRedirect") if code == "" || state == "" { w.WriteHeader(400) w.Write([]byte("code and state are required")) return } + // load the session (we keyed off the state param) + session, err := database.GetServiceDB().LoadAuthSessionByID(r.ID(), state) + if err != nil { + logger.WithError(err).Print("Failed to load session") + w.WriteHeader(400) + w.Write([]byte("Provided ?state= param is not recognised.")) // most likely cause + return + } + logger.WithField("user_id", session.UserID()).Print("Mapped redirect to user") } -func (r *githubRealm) AuthSession(userID, realmID string) types.AuthSession { +func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession { return &githubSession{ + id: id, userID: userID, realmID: realmID, } diff --git a/src/github.com/matrix-org/go-neb/types/types.go b/src/github.com/matrix-org/go-neb/types/types.go index f72e110..fe6c453 100644 --- a/src/github.com/matrix-org/go-neb/types/types.go +++ b/src/github.com/matrix-org/go-neb/types/types.go @@ -60,7 +60,7 @@ type AuthRealm interface { ID() string Type() string OnReceiveRedirect(w http.ResponseWriter, req *http.Request) - AuthSession(userID, realmID string) AuthSession + AuthSession(id, userID, realmID string) AuthSession RequestAuthSession(userID string, config json.RawMessage) interface{} } @@ -84,6 +84,7 @@ func CreateAuthRealm(realmID, realmType string) AuthRealm { // AuthSession represents a single authentication session between a user and // an auth realm. type AuthSession interface { + ID() string UserID() string RealmID() string }