diff --git a/src/github.com/matrix-org/go-neb/api.go b/src/github.com/matrix-org/go-neb/api.go index 8a4097a..3088c36 100644 --- a/src/github.com/matrix-org/go-neb/api.go +++ b/src/github.com/matrix-org/go-neb/api.go @@ -17,6 +17,49 @@ func (*heartbeatHandler) OnIncomingRequest(req *http.Request) (interface{}, *err return &struct{}{}, nil } +type configureAuthSessionHandler struct { + db *database.ServiceDB +} + +func (h *configureAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { + if req.Method != "POST" { + return nil, &errors.HTTPError{nil, "Unsupported Method", 405} + } + var body struct { + RealmID string + UserID string + Config json.RawMessage + } + if err := json.NewDecoder(req.Body).Decode(&body); err != nil { + return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} + } + + if body.UserID == "" || body.RealmID == "" || body.Config == nil { + return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400} + } + + realm, err := h.db.LoadAuthRealm(body.RealmID) + if err != nil { + return nil, &errors.HTTPError{err, "Unknown RealmID", 400} + } + + session := realm.AuthSession(body.UserID, body.Config) + if session == nil { + return nil, &errors.HTTPError{nil, "Failed to create auth session", 500} + } + + old, err := h.db.StoreAuthSession(session) + if err != nil { + return nil, &errors.HTTPError{err, "Failed to store auth session", 500} + } + + return &struct { + RealmType string + OldSession types.AuthSession + Session types.AuthSession + }{realm.Type(), old, session}, nil +} + type configureAuthRealmHandler struct { db *database.ServiceDB } diff --git a/src/github.com/matrix-org/go-neb/database/db.go b/src/github.com/matrix-org/go-neb/database/db.go index 36bc44d..73d1857 100644 --- a/src/github.com/matrix-org/go-neb/database/db.go +++ b/src/github.com/matrix-org/go-neb/database/db.go @@ -186,6 +186,33 @@ func (d *ServiceDB) StoreAuthRealm(realm types.AuthRealm) (old types.AuthRealm, return } +// StoreAuthSession stores the given AuthSession, clobbering based on the tupe of +// user ID and realm ID. This function updates the time added/updated values. +// The previous session, if any, is returned. +func (d *ServiceDB) StoreAuthSession(session types.AuthSession) (old types.AuthSession, err error) { + err = runTransaction(d.db, func(txn *sql.Tx) error { + old, err = selectAuthSessionTxn(txn, session.RealmID(), session.UserID()) + if err == sql.ErrNoRows { + return insertAuthSessionTxn(txn, time.Now(), session) + } else if err != nil { + return err + } else { + return updateAuthSessionTxn(txn, time.Now(), session) + } + }) + return +} + +// LoadAuthSession loads an AuthSession from the database. +// Returns sql.ErrNoRows if the session isn't in the database. +func (d *ServiceDB) LoadAuthSession(realmID, userID string) (session types.AuthSession, err error) { + err = runTransaction(d.db, func(txn *sql.Tx) error { + session, err = selectAuthSessionTxn(txn, realmID, userID) + return err + }) + return +} + func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { txn, err := db.Begin() if err != nil { diff --git a/src/github.com/matrix-org/go-neb/database/schema.go b/src/github.com/matrix-org/go-neb/database/schema.go index 6f78d5b..0dab250 100644 --- a/src/github.com/matrix-org/go-neb/database/schema.go +++ b/src/github.com/matrix-org/go-neb/database/schema.go @@ -43,6 +43,15 @@ CREATE TABLE IF NOT EXISTS auth_realms ( time_updated_ms BIGINT NOT NULL, UNIQUE(realm_id) ); + +CREATE TABLE IF NOT EXISTS auth_sessions ( + realm_id TEXT NOT NULL, + user_id TEXT NOT NULL, + session_json TEXT NOT NULL, + time_added_ms BIGINT NOT NULL, + time_updated_ms BIGINT NOT NULL, + UNIQUE(realm_id, user_id) +); ` const selectServiceUserIDsSQL = ` @@ -268,3 +277,64 @@ func updateRealmTxn(txn *sql.Tx, now time.Time, realm types.AuthRealm) error { ) return err } + +const insertAuthSessionSQL = ` +INSERT INTO auth_sessions( + realm_id, user_id, session_json, time_added_ms, time_updated_ms +) VALUES ($1, $2, $3, $4, $5) +` + +func insertAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error { + sessionJSON, err := json.Marshal(session) + if err != nil { + return err + } + t := now.UnixNano() / 1000000 + _, err = txn.Exec( + insertAuthSessionSQL, + session.RealmID(), session.UserID(), sessionJSON, t, t, + ) + return err +} + +const selectAuthSessionSQL = ` +SELECT realm_type, session_json FROM auth_sessions + JOIN auth_realms ON auth_sessions.realm_id = auth_realms.realm_id + WHERE auth_sessions.realm_id = $1 AND auth_sessions.user_id = $2 +` + +func selectAuthSessionTxn(txn *sql.Tx, realmID, userID string) (types.AuthSession, error) { + var realmType string + var sessionJSON []byte + row := txn.QueryRow(selectAuthSessionSQL, realmID, userID) + if err := row.Scan(&realmType, &sessionJSON); err != nil { + return nil, err + } + realm := types.CreateAuthRealm(realmID, realmType) + if realm == nil { + return nil, fmt.Errorf("Cannot create realm of type %s", realmType) + } + session := realm.AuthSession(userID, json.RawMessage(sessionJSON)) + if session == nil { + return nil, fmt.Errorf("Cannot create session for given realm") + } + return session, nil +} + +const updateAuthSessionSQL = ` +UPDATE auth_sessions SET session_json=$1, time_updated_ms=$2 + WHERE realm_id=$3 AND user_id=$4 +` + +func updateAuthSessionTxn(txn *sql.Tx, now time.Time, session types.AuthSession) error { + sessionJSON, err := json.Marshal(session) + if err != nil { + return err + } + t := now.UnixNano() / 1000000 + _, err = txn.Exec( + updateAuthSessionSQL, sessionJSON, t, + session.RealmID(), session.UserID(), + ) + return err +} diff --git a/src/github.com/matrix-org/go-neb/goneb.go b/src/github.com/matrix-org/go-neb/goneb.go index 4a66a7d..ae97c1b 100644 --- a/src/github.com/matrix-org/go-neb/goneb.go +++ b/src/github.com/matrix-org/go-neb/goneb.go @@ -33,6 +33,7 @@ func main() { http.Handle("/admin/configureClient", server.MakeJSONAPI(&configureClientHandler{db: db, clients: clients})) http.Handle("/admin/configureService", server.MakeJSONAPI(&configureServiceHandler{db: db, clients: clients})) http.Handle("/admin/configureAuthRealm", server.MakeJSONAPI(&configureAuthRealmHandler{db: db})) + http.Handle("/admin/configureAuthSession", server.MakeJSONAPI(&configureAuthSessionHandler{db: db})) wh := &webhookHandler{db: db, clients: clients} http.HandleFunc("/services/hooks/", wh.handle) diff --git a/src/github.com/matrix-org/go-neb/realms/github/github.go b/src/github.com/matrix-org/go-neb/realms/github/github.go index cb59f09..2bb3fcc 100644 --- a/src/github.com/matrix-org/go-neb/realms/github/github.go +++ b/src/github.com/matrix-org/go-neb/realms/github/github.go @@ -1,7 +1,9 @@ package realms import ( + "encoding/json" "github.com/matrix-org/go-neb/types" + "net/url" ) type githubRealm struct { @@ -11,6 +13,20 @@ type githubRealm struct { WebhookEndpoint string } +type githubSession struct { + URL string + userID string + realmID string +} + +func (s *githubSession) UserID() string { + return s.userID +} + +func (s *githubSession) RealmID() string { + return s.realmID +} + func (r *githubRealm) ID() string { return r.id } @@ -19,6 +35,20 @@ func (r *githubRealm) Type() string { return "github" } +func (r *githubRealm) AuthSession(userID string, config json.RawMessage) types.AuthSession { + u, _ := url.Parse("https://github.com/login/oauth/authorize") + q := u.Query() + q.Set("client_id", r.ClientID) + q.Set("client_secret", r.ClientSecret) + // TODO: state, scope + u.RawQuery = q.Encode() + return &githubSession{ + URL: u.String(), + userID: userID, + realmID: r.ID(), + } +} + func init() { types.RegisterAuthRealm(func(realmID string) types.AuthRealm { return &githubRealm{id: realmID} diff --git a/src/github.com/matrix-org/go-neb/types/types.go b/src/github.com/matrix-org/go-neb/types/types.go index 87c519a..97edddc 100644 --- a/src/github.com/matrix-org/go-neb/types/types.go +++ b/src/github.com/matrix-org/go-neb/types/types.go @@ -1,6 +1,7 @@ package types import ( + "encoding/json" "errors" "github.com/matrix-org/go-neb/matrix" "github.com/matrix-org/go-neb/plugin" @@ -58,6 +59,7 @@ func CreateService(serviceID, serviceType string) Service { type AuthRealm interface { ID() string Type() string + AuthSession(userID string, config json.RawMessage) AuthSession } var realmsByType = map[string]func(string) AuthRealm{} @@ -76,3 +78,10 @@ func CreateAuthRealm(realmID, realmType string) AuthRealm { } return f(realmID) } + +// AuthSession represents a single authentication session between a user and +// an auth realm. +type AuthSession interface { + UserID() string + RealmID() string +}