From 514aab3c61e981693f9a05f50b54cd7043ac13a7 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 12 Aug 2016 09:42:20 +0100 Subject: [PATCH] Create realms with JSON by default --- src/github.com/matrix-org/go-neb/api.go | 10 ++---- .../matrix-org/go-neb/database/schema.go | 32 +++++-------------- .../matrix-org/go-neb/types/types.go | 12 ++++--- 3 files changed, 19 insertions(+), 35 deletions(-) diff --git a/src/github.com/matrix-org/go-neb/api.go b/src/github.com/matrix-org/go-neb/api.go index 6904a85..baa6875 100644 --- a/src/github.com/matrix-org/go-neb/api.go +++ b/src/github.com/matrix-org/go-neb/api.go @@ -96,16 +96,12 @@ func (h *configureAuthRealmHandler) OnIncomingRequest(req *http.Request) (interf return nil, &errors.HTTPError{nil, `Must supply a "ID", a "Type" and a "Config"`, 400} } - realm := types.CreateAuthRealm(body.ID, body.Type) - if realm == nil { - return nil, &errors.HTTPError{nil, "Unknown realm type", 400} - } - - if err := json.Unmarshal(body.Config, realm); err != nil { + realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config) + if err != nil { return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} } - if err := realm.Register(); err != nil { + if err = realm.Register(); err != nil { return nil, &errors.HTTPError{err, "Error registering auth realm", 400} } diff --git a/src/github.com/matrix-org/go-neb/database/schema.go b/src/github.com/matrix-org/go-neb/database/schema.go index c2046d9..8cd58c7 100644 --- a/src/github.com/matrix-org/go-neb/database/schema.go +++ b/src/github.com/matrix-org/go-neb/database/schema.go @@ -252,14 +252,7 @@ func selectRealmTxn(txn *sql.Tx, realmID string) (types.AuthRealm, error) { if err := row.Scan(&realmType, &realmJSON); err != nil { return nil, err } - realm := types.CreateAuthRealm(realmID, realmType) - if realm == nil { - return nil, fmt.Errorf("Cannot create realm of type %s", realmType) - } - if err := json.Unmarshal(realmJSON, realm); err != nil { - return nil, err - } - return realm, nil + return types.CreateAuthRealm(realmID, realmType, realmJSON) } const selectRealmsByTypeSQL = ` @@ -273,17 +266,14 @@ func selectRealmsByTypeTxn(txn *sql.Tx, realmType string) (realms []types.AuthRe } defer rows.Close() for rows.Next() { + var realm types.AuthRealm var realmID string var realmJSON []byte if err = rows.Scan(&realmID, &realmJSON); err != nil { return } - realm := types.CreateAuthRealm(realmID, realmType) - if realm == nil { - err = fmt.Errorf("Cannot create realm %s of type %s", realmID, realmType) - return - } - if err = json.Unmarshal(realmJSON, realm); err != nil { + realm, err = types.CreateAuthRealm(realmID, realmType, realmJSON) + if err != nil { return } realms = append(realms, realm) @@ -343,11 +333,8 @@ func selectAuthSessionByUserTxn(txn *sql.Tx, realmID, userID string) (types.Auth if err := row.Scan(&id, &realmType, &realmJSON, &sessionJSON); err != nil { return nil, err } - realm := types.CreateAuthRealm(realmID, realmType) - if realm == nil { - return nil, fmt.Errorf("Cannot create realm of type %s", realmType) - } - if err := json.Unmarshal(realmJSON, realm); err != nil { + realm, err := types.CreateAuthRealm(realmID, realmType, realmJSON) + if err != nil { return nil, err } session := realm.AuthSession(id, userID, realmID) @@ -375,11 +362,8 @@ func selectAuthSessionByIDTxn(txn *sql.Tx, realmID, id string) (types.AuthSessio if err := row.Scan(&userID, &realmType, &realmJSON, &sessionJSON); err != nil { return nil, err } - realm := types.CreateAuthRealm(realmID, realmType) - if realm == nil { - return nil, fmt.Errorf("Cannot create realm of type %s", realmType) - } - if err := json.Unmarshal(realmJSON, realm); err != nil { + realm, err := types.CreateAuthRealm(realmID, realmType, realmJSON) + if err != nil { return nil, err } session := realm.AuthSession(id, userID, realmID) diff --git a/src/github.com/matrix-org/go-neb/types/types.go b/src/github.com/matrix-org/go-neb/types/types.go index da9a00e..49ca0ea 100644 --- a/src/github.com/matrix-org/go-neb/types/types.go +++ b/src/github.com/matrix-org/go-neb/types/types.go @@ -95,14 +95,18 @@ func RegisterAuthRealm(factory func(string, string) AuthRealm) { } // CreateAuthRealm creates an AuthRealm of the given type and realm ID. -// Returns nil if the realm couldn't be created. -func CreateAuthRealm(realmID, realmType string) AuthRealm { +// Returns an error if the realm couldn't be created or the JSON cannot be unmarshalled. +func CreateAuthRealm(realmID, realmType string, realmJSON []byte) (AuthRealm, error) { f := realmsByType[realmType] if f == nil { - return nil + return nil, errors.New("Unknown realm type: " + realmType) } redirectURL := baseURL + "realms/redirects/" + realmID - return f(realmID, redirectURL) + r := f(realmID, redirectURL) + if err := json.Unmarshal(realmJSON, r); err != nil { + return nil, err + } + return r, nil } // AuthSession represents a single authentication session between a user and