From 514d59e4d559569c4e467fd255445a2da93f2a30 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 2 Aug 2016 17:48:02 +0100 Subject: [PATCH] Remove Service ID from ThirdPartyAuth; query off resource instead. This de-couples ThirdPartyAuth from Services so we can do auth without having to instantiate Services. --- .../matrix-org/go-neb/database/db.go | 32 ++++++---------- .../matrix-org/go-neb/database/schema.go | 37 ++++++------------- 2 files changed, 23 insertions(+), 46 deletions(-) diff --git a/src/github.com/matrix-org/go-neb/database/db.go b/src/github.com/matrix-org/go-neb/database/db.go index ef8b09f..c7b84e8 100644 --- a/src/github.com/matrix-org/go-neb/database/db.go +++ b/src/github.com/matrix-org/go-neb/database/db.go @@ -95,11 +95,12 @@ func (d *ServiceDB) LoadServicesInRoom(serviceUserID, roomID string) (services [ return } -// LoadThirdPartyAuthsForUser loads all the third-party credentials that the given userID -// has linked to the given Service. Returns an empty list if there are no credentials. -func (d *ServiceDB) LoadThirdPartyAuthsForUser(srv types.Service, userID string) (tpas []ThirdPartyAuth, err error) { +// LoadThirdPartyAuth loads third-party credentials that the given userID +// has linked to the given resource. Returns sql.ErrNoRows if there are no +// credentials for the given resource/user combination. +func (d *ServiceDB) LoadThirdPartyAuth(resource, userID string) (tpa ThirdPartyAuth, err error) { err = runTransaction(d.db, func(txn *sql.Tx) error { - tpas, err = selectThirdPartyAuthsForUserTxn(txn, srv.ServiceType(), userID) + tpa, err = selectThirdPartyAuthTxn(txn, resource, userID) if err != nil { return err } @@ -114,27 +115,18 @@ func (d *ServiceDB) LoadThirdPartyAuthsForUser(srv types.Service, userID string) // will be inserted. The previous auth is returned. func (d *ServiceDB) StoreThirdPartyAuth(tpa ThirdPartyAuth) (old ThirdPartyAuth, err error) { err = runTransaction(d.db, func(txn *sql.Tx) error { - var olds []ThirdPartyAuth - var hasOld bool - olds, err = selectThirdPartyAuthsForUserTxn(txn, tpa.ServiceType, tpa.UserID) - for _, o := range olds { - if o.UserID == tpa.UserID && o.Resource == tpa.Resource { - old = o - hasOld = true - break - } - } + old, err = selectThirdPartyAuthTxn(txn, tpa.Resource, tpa.UserID) now := time.Now().UnixNano() / 1000000 - if err != nil { - return err - } else if hasOld { - tpa.TimeUpdatedMs = now - return updateThirdPartyAuthTxn(txn, tpa) - } else { + if err == sql.ErrNoRows { tpa.TimeAddedMs = now tpa.TimeUpdatedMs = now return insertThirdPartyAuthTxn(txn, tpa) + } else if err != nil { + return err + } else { + tpa.TimeUpdatedMs = now + return updateThirdPartyAuthTxn(txn, tpa) } }) return diff --git a/src/github.com/matrix-org/go-neb/database/schema.go b/src/github.com/matrix-org/go-neb/database/schema.go index 78cf4d3..187d411 100644 --- a/src/github.com/matrix-org/go-neb/database/schema.go +++ b/src/github.com/matrix-org/go-neb/database/schema.go @@ -37,7 +37,6 @@ CREATE TABLE IF NOT EXISTS matrix_clients ( CREATE TABLE IF NOT EXISTS third_party_auth ( user_id TEXT NOT NULL, - service_type TEXT NOT NULL, resource TEXT NOT NULL, auth_json TEXT NOT NULL, time_added_ms BIGINT NOT NULL, @@ -216,16 +215,12 @@ func selectRoomServicesTxn(txn *sql.Tx, serviceUserID, roomID string) (serviceID type ThirdPartyAuth struct { // The ID of the matrix user who has authed with the third party UserID string - // The type of third party. This determines which code gets loaded to - // handle parsing of the AuthJSON. - ServiceType string // The location of the third party resource e.g. "github.com". // This is mainly relevant for decentralised services like JIRA which // may have many different locations (e.g. "matrix.org/jira") for the // same ServiceType ("jira"). Resource string - // An opaque JSON blob of stored auth data. Only the service defined in - // ServiceType knows how to parse this data. + // An opaque JSON blob of stored auth data. AuthJSON []byte // When the row was initially inserted. TimeAddedMs int64 @@ -234,36 +229,26 @@ type ThirdPartyAuth struct { } const selectThirdPartyAuthSQL = ` -SELECT resource, auth_json, time_added_ms, time_updated_ms FROM third_party_auth -WHERE user_id=$1 AND service_type=$2 +SELECT auth_json, time_added_ms, time_updated_ms FROM third_party_auth +WHERE user_id=$1 AND resource=$2 ` -func selectThirdPartyAuthsForUserTxn(txn *sql.Tx, serviceType, userID string) (auths []ThirdPartyAuth, err error) { - rows, err := txn.Query(selectThirdPartyAuthSQL, userID, serviceType) - if err != nil { - return - } - defer rows.Close() - for rows.Next() { - var tpa ThirdPartyAuth - if err = rows.Scan(&tpa.Resource, &tpa.AuthJSON, &tpa.TimeAddedMs, &tpa.TimeUpdatedMs); err != nil { - return - } - tpa.UserID = userID - tpa.ServiceType = serviceType - auths = append(auths, tpa) - } +func selectThirdPartyAuthTxn(txn *sql.Tx, resource, userID string) (tpa ThirdPartyAuth, err error) { + tpa.Resource = resource + tpa.UserID = userID + err = txn.QueryRow(selectThirdPartyAuthSQL, userID, resource).Scan( + &tpa.AuthJSON, &tpa.TimeAddedMs, &tpa.TimeUpdatedMs) return } const insertThirdPartyAuthSQL = ` INSERT INTO third_party_auth( - user_id, service_type, resource, auth_json, time_added_ms, time_updated_ms -) VALUES($1, $2, $3, $4, $5, $6) + user_id, resource, auth_json, time_added_ms, time_updated_ms +) VALUES($1, $2, $3, $4, $5) ` func insertThirdPartyAuthTxn(txn *sql.Tx, tpa ThirdPartyAuth) (err error) { - _, err = txn.Exec(insertThirdPartyAuthSQL, tpa.UserID, tpa.ServiceType, tpa.Resource, + _, err = txn.Exec(insertThirdPartyAuthSQL, tpa.UserID, tpa.Resource, tpa.AuthJSON, tpa.TimeAddedMs, tpa.TimeUpdatedMs) return }