diff --git a/src/github.com/matrix-org/go-neb/database/db.go b/src/github.com/matrix-org/go-neb/database/db.go index ece6d23..ef8b09f 100644 --- a/src/github.com/matrix-org/go-neb/database/db.go +++ b/src/github.com/matrix-org/go-neb/database/db.go @@ -95,6 +95,51 @@ func (d *ServiceDB) LoadServicesInRoom(serviceUserID, roomID string) (services [ return } +// LoadThirdPartyAuthsForUser loads all the third-party credentials that the given userID +// has linked to the given Service. Returns an empty list if there are no credentials. +func (d *ServiceDB) LoadThirdPartyAuthsForUser(srv types.Service, userID string) (tpas []ThirdPartyAuth, err error) { + err = runTransaction(d.db, func(txn *sql.Tx) error { + tpas, err = selectThirdPartyAuthsForUserTxn(txn, srv.ServiceType(), userID) + if err != nil { + return err + } + return nil + }) + return +} + +// StoreThirdPartyAuth stores the ThirdPartyAuth for the given Service. Updates the +// time added/updated values. +// If the auth already exists then it will be updated, otherwise a new auth +// will be inserted. The previous auth is returned. +func (d *ServiceDB) StoreThirdPartyAuth(tpa ThirdPartyAuth) (old ThirdPartyAuth, err error) { + err = runTransaction(d.db, func(txn *sql.Tx) error { + var olds []ThirdPartyAuth + var hasOld bool + olds, err = selectThirdPartyAuthsForUserTxn(txn, tpa.ServiceType, tpa.UserID) + for _, o := range olds { + if o.UserID == tpa.UserID && o.Resource == tpa.Resource { + old = o + hasOld = true + break + } + } + now := time.Now().UnixNano() / 1000000 + + if err != nil { + return err + } else if hasOld { + tpa.TimeUpdatedMs = now + return updateThirdPartyAuthTxn(txn, tpa) + } else { + tpa.TimeAddedMs = now + tpa.TimeUpdatedMs = now + return insertThirdPartyAuthTxn(txn, tpa) + } + }) + return +} + // StoreService stores a service into the database either by inserting a new // service or updating an existing service. Returns the old service if there // was one. diff --git a/src/github.com/matrix-org/go-neb/database/schema.go b/src/github.com/matrix-org/go-neb/database/schema.go index f85a6dd..78cf4d3 100644 --- a/src/github.com/matrix-org/go-neb/database/schema.go +++ b/src/github.com/matrix-org/go-neb/database/schema.go @@ -212,6 +212,7 @@ func selectRoomServicesTxn(txn *sql.Tx, serviceUserID, roomID string) (serviceID return } +// ThirdPartyAuth represents a third_party_auth data row. type ThirdPartyAuth struct { // The ID of the matrix user who has authed with the third party UserID string @@ -227,9 +228,9 @@ type ThirdPartyAuth struct { // ServiceType knows how to parse this data. AuthJSON []byte // When the row was initially inserted. - TimeAddedMs int + TimeAddedMs int64 // When the row was last updated. - TimeUpdatedMs int + TimeUpdatedMs int64 } const selectThirdPartyAuthSQL = ` @@ -237,8 +238,8 @@ SELECT resource, auth_json, time_added_ms, time_updated_ms FROM third_party_auth WHERE user_id=$1 AND service_type=$2 ` -func selectThirdPartyAuthsForUserTxn(txn *sql.Tx, service types.Service, userID string) (auths []ThirdPartyAuth, err error) { - rows, err := txn.Query(selectThirdPartyAuthSQL, userID, service.ServiceType()) +func selectThirdPartyAuthsForUserTxn(txn *sql.Tx, serviceType, userID string) (auths []ThirdPartyAuth, err error) { + rows, err := txn.Query(selectThirdPartyAuthSQL, userID, serviceType) if err != nil { return } @@ -249,7 +250,7 @@ func selectThirdPartyAuthsForUserTxn(txn *sql.Tx, service types.Service, userID return } tpa.UserID = userID - tpa.ServiceType = service.ServiceType() + tpa.ServiceType = serviceType auths = append(auths, tpa) } return