diff --git a/src/github.com/matrix-org/go-neb/api.go b/src/github.com/matrix-org/go-neb/api.go index 6904a85..146a820 100644 --- a/src/github.com/matrix-org/go-neb/api.go +++ b/src/github.com/matrix-org/go-neb/api.go @@ -96,16 +96,12 @@ func (h *configureAuthRealmHandler) OnIncomingRequest(req *http.Request) (interf return nil, &errors.HTTPError{nil, `Must supply a "ID", a "Type" and a "Config"`, 400} } - realm := types.CreateAuthRealm(body.ID, body.Type) - if realm == nil { - return nil, &errors.HTTPError{nil, "Unknown realm type", 400} - } - - if err := json.Unmarshal(body.Config, realm); err != nil { + realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config) + if err != nil { return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} } - if err := realm.Register(); err != nil { + if err = realm.Register(); err != nil { return nil, &errors.HTTPError{err, "Error registering auth realm", 400} } @@ -200,16 +196,12 @@ func (s *configureServiceHandler) OnIncomingRequest(req *http.Request) (interfac return nil, &errors.HTTPError{nil, `Must supply a "ID", a "Type" and a "Config"`, 400} } - service := types.CreateService(body.ID, body.Type) - if service == nil { - return nil, &errors.HTTPError{nil, "Unknown service type", 400} - } - - if err := json.Unmarshal(body.Config, service); err != nil { + service, err := types.CreateService(body.ID, body.Type, body.Config) + if err != nil { return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} } - err := service.Register() + err = service.Register() if err != nil { return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500} } diff --git a/src/github.com/matrix-org/go-neb/database/schema.go b/src/github.com/matrix-org/go-neb/database/schema.go index c2046d9..0eff3f1 100644 --- a/src/github.com/matrix-org/go-neb/database/schema.go +++ b/src/github.com/matrix-org/go-neb/database/schema.go @@ -135,14 +135,7 @@ func selectServiceTxn(txn *sql.Tx, serviceID string) (types.Service, error) { if err := row.Scan(&serviceType, &serviceJSON); err != nil { return nil, err } - service := types.CreateService(serviceID, serviceType) - if service == nil { - return nil, fmt.Errorf("Cannot create services of type %s", serviceType) - } - if err := json.Unmarshal(serviceJSON, service); err != nil { - return nil, err - } - return service, nil + return types.CreateService(serviceID, serviceType, serviceJSON) } const updateServiceSQL = ` @@ -252,14 +245,7 @@ func selectRealmTxn(txn *sql.Tx, realmID string) (types.AuthRealm, error) { if err := row.Scan(&realmType, &realmJSON); err != nil { return nil, err } - realm := types.CreateAuthRealm(realmID, realmType) - if realm == nil { - return nil, fmt.Errorf("Cannot create realm of type %s", realmType) - } - if err := json.Unmarshal(realmJSON, realm); err != nil { - return nil, err - } - return realm, nil + return types.CreateAuthRealm(realmID, realmType, realmJSON) } const selectRealmsByTypeSQL = ` @@ -273,17 +259,14 @@ func selectRealmsByTypeTxn(txn *sql.Tx, realmType string) (realms []types.AuthRe } defer rows.Close() for rows.Next() { + var realm types.AuthRealm var realmID string var realmJSON []byte if err = rows.Scan(&realmID, &realmJSON); err != nil { return } - realm := types.CreateAuthRealm(realmID, realmType) - if realm == nil { - err = fmt.Errorf("Cannot create realm %s of type %s", realmID, realmType) - return - } - if err = json.Unmarshal(realmJSON, realm); err != nil { + realm, err = types.CreateAuthRealm(realmID, realmType, realmJSON) + if err != nil { return } realms = append(realms, realm) @@ -343,11 +326,8 @@ func selectAuthSessionByUserTxn(txn *sql.Tx, realmID, userID string) (types.Auth if err := row.Scan(&id, &realmType, &realmJSON, &sessionJSON); err != nil { return nil, err } - realm := types.CreateAuthRealm(realmID, realmType) - if realm == nil { - return nil, fmt.Errorf("Cannot create realm of type %s", realmType) - } - if err := json.Unmarshal(realmJSON, realm); err != nil { + realm, err := types.CreateAuthRealm(realmID, realmType, realmJSON) + if err != nil { return nil, err } session := realm.AuthSession(id, userID, realmID) @@ -375,11 +355,8 @@ func selectAuthSessionByIDTxn(txn *sql.Tx, realmID, id string) (types.AuthSessio if err := row.Scan(&userID, &realmType, &realmJSON, &sessionJSON); err != nil { return nil, err } - realm := types.CreateAuthRealm(realmID, realmType) - if realm == nil { - return nil, fmt.Errorf("Cannot create realm of type %s", realmType) - } - if err := json.Unmarshal(realmJSON, realm); err != nil { + realm, err := types.CreateAuthRealm(realmID, realmType, realmJSON) + if err != nil { return nil, err } session := realm.AuthSession(id, userID, realmID) diff --git a/src/github.com/matrix-org/go-neb/realms/github/github.go b/src/github.com/matrix-org/go-neb/realms/github/github.go index d6e3ef0..7d624e4 100644 --- a/src/github.com/matrix-org/go-neb/realms/github/github.go +++ b/src/github.com/matrix-org/go-neb/realms/github/github.go @@ -53,6 +53,10 @@ func (r *githubRealm) Type() string { return "github" } +func (r *githubRealm) Init() error { + return nil +} + func (r *githubRealm) Register() error { return nil } diff --git a/src/github.com/matrix-org/go-neb/realms/jira/jira.go b/src/github.com/matrix-org/go-neb/realms/jira/jira.go index c28f9cf..3ef777f 100644 --- a/src/github.com/matrix-org/go-neb/realms/jira/jira.go +++ b/src/github.com/matrix-org/go-neb/realms/jira/jira.go @@ -71,6 +71,22 @@ func (r *JIRARealm) Type() string { return "jira" } +// Init initialises the private key for this JIRA realm. +func (r *JIRARealm) Init() error { + if err := r.parsePrivateKey(); err != nil { + log.WithError(err).Print("Failed to parse private key") + return err + } + // Parse the messy input URL into a canonicalised form. + ju, err := urls.ParseJIRAURL(r.JIRAEndpoint) + if err != nil { + log.WithError(err).Print("Failed to parse JIRA endpoint") + return err + } + r.JIRAEndpoint = ju.Base + return nil +} + // Register is called when this realm is being created from an external entity func (r *JIRARealm) Register() error { if r.ConsumerName == "" || r.ConsumerKey == "" || r.ConsumerSecret == "" || r.PrivateKeyPEM == "" { @@ -80,10 +96,6 @@ func (r *JIRARealm) Register() error { return errors.New("JIRAEndpoint must be specified") } - if err := r.ensureInited(); err != nil { - return err - } - // Check to see if JIRA endpoint is valid by pinging an endpoint cli, err := r.JIRAClient("", true) if err != nil { @@ -107,10 +119,6 @@ func (r *JIRARealm) Register() error { // RequestAuthSession is called by a user wishing to auth with this JIRA realm func (r *JIRARealm) RequestAuthSession(userID string, req json.RawMessage) interface{} { logger := log.WithField("jira_url", r.JIRAEndpoint) - if err := r.ensureInited(); err != nil { - logger.WithError(err).Print("Failed to init realm") - return nil - } authConfig := r.oauth1Config(r.JIRAEndpoint) reqToken, reqSec, err := authConfig.RequestToken() if err != nil { @@ -143,10 +151,6 @@ func (r *JIRARealm) RequestAuthSession(userID string, req json.RawMessage) inter // OnReceiveRedirect is called when JIRA installations redirect back to NEB func (r *JIRARealm) OnReceiveRedirect(w http.ResponseWriter, req *http.Request) { logger := log.WithField("jira_url", r.JIRAEndpoint) - if err := r.ensureInited(); err != nil { - failWith(logger, w, 500, "Failed to initialise realm", err) - return - } requestToken, verifier, err := oauth1.ParseAuthorizationCallback(req) if err != nil { @@ -203,9 +207,6 @@ func (r *JIRARealm) AuthSession(id, userID, realmID string) types.AuthSession { // unauthenticated client will be used, which may not be able to see the complete list // of projects. func (r *JIRARealm) ProjectKeyExists(userID, projectKey string) (bool, error) { - if err := r.ensureInited(); err != nil { - return false, err - } cli, err := r.JIRAClient(userID, true) if err != nil { return false, err @@ -274,21 +275,6 @@ func (r *JIRARealm) JIRAClient(userID string, allowUnauth bool) (*jira.Client, e return jira.NewClient(httpClient, r.JIRAEndpoint) } -func (r *JIRARealm) ensureInited() error { - if err := r.parsePrivateKey(); err != nil { - log.WithError(err).Print("Failed to parse private key") - return err - } - // Parse the messy input URL into a canonicalised form. - ju, err := urls.ParseJIRAURL(r.JIRAEndpoint) - if err != nil { - log.WithError(err).Print("Failed to parse JIRA endpoint") - return err - } - r.JIRAEndpoint = ju.Base - return nil -} - func (r *JIRARealm) parsePrivateKey() error { if r.privateKey != nil { return nil diff --git a/src/github.com/matrix-org/go-neb/types/types.go b/src/github.com/matrix-org/go-neb/types/types.go index da9a00e..14febc5 100644 --- a/src/github.com/matrix-org/go-neb/types/types.go +++ b/src/github.com/matrix-org/go-neb/types/types.go @@ -66,14 +66,18 @@ func RegisterService(factory func(string, string) Service) { } // CreateService creates a Service of the given type and serviceID. -// Returns nil if the Service couldn't be created. -func CreateService(serviceID, serviceType string) Service { +// Returns an error if the Service couldn't be created. +func CreateService(serviceID, serviceType string, serviceJSON []byte) (Service, error) { f := servicesByType[serviceType] if f == nil { - return nil + return nil, errors.New("Unknown service type: " + serviceType) } webhookEndpointURL := baseURL + "services/hooks/" + serviceID - return f(serviceID, webhookEndpointURL) + service := f(serviceID, webhookEndpointURL) + if err := json.Unmarshal(serviceJSON, service); err != nil { + return nil, err + } + return service, nil } // AuthRealm represents a place where a user can authenticate themselves. @@ -81,6 +85,7 @@ func CreateService(serviceID, serviceType string) Service { type AuthRealm interface { ID() string Type() string + Init() error Register() error OnReceiveRedirect(w http.ResponseWriter, req *http.Request) AuthSession(id, userID, realmID string) AuthSession @@ -95,14 +100,21 @@ func RegisterAuthRealm(factory func(string, string) AuthRealm) { } // CreateAuthRealm creates an AuthRealm of the given type and realm ID. -// Returns nil if the realm couldn't be created. -func CreateAuthRealm(realmID, realmType string) AuthRealm { +// Returns an error if the realm couldn't be created or the JSON cannot be unmarshalled. +func CreateAuthRealm(realmID, realmType string, realmJSON []byte) (AuthRealm, error) { f := realmsByType[realmType] if f == nil { - return nil + return nil, errors.New("Unknown realm type: " + realmType) } redirectURL := baseURL + "realms/redirects/" + realmID - return f(realmID, redirectURL) + r := f(realmID, redirectURL) + if err := json.Unmarshal(realmJSON, r); err != nil { + return nil, err + } + if err := r.Init(); err != nil { + return nil, err + } + return r, nil } // AuthSession represents a single authentication session between a user and