Browse Source

Add ID field to AuthSessions so redirects can lookup based off this key

Required for Github OAuth redirect requests and just is generally useful to
have. Add UNIQUE constraints on realm/user and realm/id to prevent multiple
users getting the same ID.
kegan/github-auth
Kegan Dougal 9 years ago
parent
commit
71be73a721
  1. 19
      src/github.com/matrix-org/go-neb/database/db.go
  2. 61
      src/github.com/matrix-org/go-neb/database/schema.go
  3. 25
      src/github.com/matrix-org/go-neb/realms/github/github.go
  4. 3
      src/github.com/matrix-org/go-neb/types/types.go

19
src/github.com/matrix-org/go-neb/database/db.go

@ -206,7 +206,7 @@ func (d *ServiceDB) StoreAuthRealm(realm types.AuthRealm) (old types.AuthRealm,
// The previous session, if any, is returned. // The previous session, if any, is returned.
func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthSession, err error) { func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthSession, err error) {
err = runTransaction(d.db, func(txn *sql.Tx) error { err = runTransaction(d.db, func(txn *sql.Tx) error {
old, err = selectAuthSessionTxn(txn, session.RealmID(), session.UserID())
old, err = selectAuthSessionByUserTxn(txn, session.RealmID(), session.UserID())
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return insertAuthSessionTxn(txn, time.Now(), session) return insertAuthSessionTxn(txn, time.Now(), session)
} else if err != nil { } else if err != nil {
@ -218,12 +218,23 @@ func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthS
return return
} }
// LoadAuthSessionForUser loads an AuthSession from the database based on the given
// LoadAuthSessionByUser loads an AuthSession from the database based on the given
// realm and user ID. // realm and user ID.
// Returns sql.ErrNoRows if the session isn't in the database. // Returns sql.ErrNoRows if the session isn't in the database.
func (d *ServiceDB) LoadAuthSessionForUser(realmID, userID string) (session types.AuthSession, err error) {
func (d *ServiceDB) LoadAuthSessionByUser(realmID, userID string) (session types.AuthSession, err error) {
err = runTransaction(d.db, func(txn *sql.Tx) error { err = runTransaction(d.db, func(txn *sql.Tx) error {
session, err = selectAuthSessionTxn(txn, realmID, userID)
session, err = selectAuthSessionByUserTxn(txn, realmID, userID)
return err
})
return
}
// LoadAuthSessionByID loads an AuthSession from the database based on the given
// realm and session ID.
// Returns sql.ErrNoRows if the session isn't in the database.
func (d *ServiceDB) LoadAuthSessionByID(realmID, sessionID string) (session types.AuthSession, err error) {
err = runTransaction(d.db, func(txn *sql.Tx) error {
session, err = selectAuthSessionByIDTxn(txn, realmID, sessionID)
return err return err
}) })
return return

61
src/github.com/matrix-org/go-neb/database/schema.go

@ -45,12 +45,14 @@ CREATE TABLE IF NOT EXISTS auth_realms (
); );
CREATE TABLE IF NOT EXISTS auth_sessions ( CREATE TABLE IF NOT EXISTS auth_sessions (
id TEXT NOT NULL,
realm_id TEXT NOT NULL, realm_id TEXT NOT NULL,
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
session_json TEXT NOT NULL, session_json TEXT NOT NULL,
time_added_ms BIGINT NOT NULL, time_added_ms BIGINT NOT NULL,
time_updated_ms BIGINT NOT NULL, time_updated_ms BIGINT NOT NULL,
UNIQUE(realm_id, user_id)
UNIQUE(realm_id, user_id),
UNIQUE(realm_id, id)
); );
` `
@ -280,8 +282,8 @@ func updateRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error {
const insertAuthSessionSQL = ` const insertAuthSessionSQL = `
INSERT INTO auth_sessions( INSERT INTO auth_sessions(
realm_id, user_id, session_json, time_added_ms, time_updated_ms
) VALUES ($1, $2, $3, $4, $5)
id, realm_id, user_id, session_json, time_added_ms, time_updated_ms
) VALUES ($1, $2, $3, $4, $5, $6)
` `
func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error { func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error {
@ -292,23 +294,56 @@ func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession)
t := now.UnixNano() / 1000000 t := now.UnixNano() / 1000000
_, err = txn.Exec( _, err = txn.Exec(
insertAuthSessionSQL, insertAuthSessionSQL,
session.RealmID(), session.UserID(), sessionJSON, t, t,
session.ID(), session.RealmID(), session.UserID(), sessionJSON, t, t,
) )
return err return err
} }
const selectAuthSessionSQL = `
SELECT realm_type, realm_json, session_json FROM auth_sessions
const selectAuthSessionByUserSQL = `
SELECT id, realm_type, realm_json, session_json FROM auth_sessions
JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id
WHERE auth_sessions.realm_id = $1 AND auth_sessions.user_id = $2 WHERE auth_sessions.realm_id = $1 AND auth_sessions.user_id = $2
` `
func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) {
func selectAuthSessionByUserTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) {
var id string
var realmType string
var realmJSON []byte
var sessionJSON []byte
row := txn.QueryRow(selectAuthSessionByUserSQL, realmID, userID)
if err := row.Scan(&id, &realmType, &realmJSON, &sessionJSON); err != nil {
return nil, err
}
realm := types.CreateAuthRealm(realmID, realmType)
if realm == nil {
return nil, fmt.Errorf("Cannot create realm of type %s", realmType)
}
if err := json.Unmarshal(realmJSON, realm); err != nil {
return nil, err
}
session := realm.AuthSession(id, userID, realmID)
if session == nil {
return nil, fmt.Errorf("Cannot create session for given realm")
}
if err := json.Unmarshal(sessionJSON, session); err != nil {
return nil, err
}
return session, nil
}
const selectAuthSessionByIDSQL = `
SELECT user_id, realm_type, realm_json, session_json FROM auth_sessions
JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id
WHERE auth_sessions.realm_id = $1 AND auth_sessions.id = $2
`
func selectAuthSessionByIDTxn(txn *sql.Tx, realmID, id string) (types.AuthSession, error) {
var userID string
var realmType string var realmType string
var realmJSON []byte var realmJSON []byte
var sessionJSON []byte var sessionJSON []byte
row := txn.QueryRow(selectAuthSessionSQL, realmID, userID)
if err := row.Scan(&realmType, &realmJSON, &sessionJSON); err != nil {
row := txn.QueryRow(selectAuthSessionByIDSQL, realmID, id)
if err := row.Scan(&userID, &realmType, &realmJSON, &sessionJSON); err != nil {
return nil, err return nil, err
} }
realm := types.CreateAuthRealm(realmID, realmType) realm := types.CreateAuthRealm(realmID, realmType)
@ -318,7 +353,7 @@ func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSessio
if err := json.Unmarshal(realmJSON, realm); err != nil { if err := json.Unmarshal(realmJSON, realm); err != nil {
return nil, err return nil, err
} }
session := realm.AuthSession(userID, realmID)
session := realm.AuthSession(id, userID, realmID)
if session == nil { if session == nil {
return nil, fmt.Errorf("Cannot create session for given realm") return nil, fmt.Errorf("Cannot create session for given realm")
} }
@ -329,8 +364,8 @@ func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSessio
} }
const updateAuthSessionSQL = ` const updateAuthSessionSQL = `
UPDATE auth_sessions SET session_json=$1, time_updated_ms=$2
WHERE realm_id=$3 AND user_id=$4
UPDATE auth_sessions SET id=$1, session_json=$2, time_updated_ms=$3
WHERE realm_id=$4 AND user_id=$5
` `
func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error { func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error {
@ -340,7 +375,7 @@ func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession)
} }
t := now.UnixNano() / 1000000 t := now.UnixNano() / 1000000
_, err = txn.Exec( _, err = txn.Exec(
updateAuthSessionSQL, sessionJSON, t,
updateAuthSessionSQL, session.ID(), sessionJSON, t,
session.RealmID(), session.UserID(), session.RealmID(), session.UserID(),
) )
return err return err

25
src/github.com/matrix-org/go-neb/realms/github/github.go

@ -20,6 +20,7 @@ type githubRealm struct {
type githubSession struct { type githubSession struct {
State string State string
id string
userID string userID string
realmID string realmID string
} }
@ -32,6 +33,10 @@ func (s *githubSession) RealmID() string {
return s.realmID return s.realmID
} }
func (s *githubSession) ID() string {
return s.id
}
func (r *githubRealm) ID() string { func (r *githubRealm) ID() string {
return r.id return r.id
} }
@ -55,7 +60,7 @@ func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) int
q.Set("redirect_uri", r.RedirectBaseURI+"/realms/redirects/"+r.ID()) q.Set("redirect_uri", r.RedirectBaseURI+"/realms/redirects/"+r.ID())
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
session := &githubSession{ session := &githubSession{
State: state,
id: state, // key off the state for redirects
userID: userID, userID: userID,
realmID: r.ID(), realmID: r.ID(),
} }
@ -73,19 +78,29 @@ func (r *githubRealm) RequestAuthSession(userID string, req json.RawMessage) int
func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) { func (r *githubRealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) {
code := req.URL.Query().Get("code") code := req.URL.Query().Get("code")
state := req.URL.Query().Get("state") state := req.URL.Query().Get("state")
log.WithFields(log.Fields{
"code": code,
logger := log.WithFields(log.Fields{
"state": state, "state": state,
}).Print("GithubRealm: OnReceiveRedirect")
})
logger.WithField("code", code).Print("GithubRealm: OnReceiveRedirect")
if code == "" || state == "" { if code == "" || state == "" {
w.WriteHeader(400) w.WriteHeader(400)
w.Write([]byte("code and state are required")) w.Write([]byte("code and state are required"))
return return
} }
// load the session (we keyed off the state param)
session, err := database.GetServiceDB().LoadAuthSessionByID(r.ID(), state)
if err != nil {
logger.WithError(err).Print("Failed to load session")
w.WriteHeader(400)
w.Write([]byte("Provided ?state= param is not recognised.")) // most likely cause
return
}
logger.WithField("user_id", session.UserID()).Print("Mapped redirect to user")
} }
func (r *githubRealm) AuthSession(userID, realmID string) types.AuthSession {
func (r *githubRealm) AuthSession(id, userID, realmID string) types.AuthSession {
return &githubSession{ return &githubSession{
id: id,
userID: userID, userID: userID,
realmID: realmID, realmID: realmID,
} }

3
src/github.com/matrix-org/go-neb/types/types.go

@ -60,7 +60,7 @@ type AuthRealm interface {
ID() string ID() string
Type() string Type() string
OnReceiveRedirect(w http.ResponseWriter, req *http.Request) OnReceiveRedirect(w http.ResponseWriter, req *http.Request)
AuthSession(userID, realmID string) AuthSession
AuthSession(id, userID, realmID string) AuthSession
RequestAuthSession(userID string, config json.RawMessage) interface{} RequestAuthSession(userID string, config json.RawMessage) interface{}
} }
@ -84,6 +84,7 @@ func CreateAuthRealm(realmID, realmType string) AuthRealm {
// AuthSession represents a single authentication session between a user and // AuthSession represents a single authentication session between a user and
// an auth realm. // an auth realm.
type AuthSession interface { type AuthSession interface {
ID() string
UserID() string UserID() string
RealmID() string RealmID() string
} }
Loading…
Cancel
Save