|
@ -2,6 +2,7 @@ package main |
|
|
|
|
|
|
|
|
import ( |
|
|
import ( |
|
|
"database/sql" |
|
|
"database/sql" |
|
|
|
|
|
"encoding/base64" |
|
|
"encoding/json" |
|
|
"encoding/json" |
|
|
log "github.com/Sirupsen/logrus" |
|
|
log "github.com/Sirupsen/logrus" |
|
|
"github.com/matrix-org/go-neb/clients" |
|
|
"github.com/matrix-org/go-neb/clients" |
|
@ -62,8 +63,18 @@ type realmRedirectHandler struct { |
|
|
|
|
|
|
|
|
func (rh *realmRedirectHandler) handle(w http.ResponseWriter, req *http.Request) { |
|
|
func (rh *realmRedirectHandler) handle(w http.ResponseWriter, req *http.Request) { |
|
|
segments := strings.Split(req.URL.Path, "/") |
|
|
segments := strings.Split(req.URL.Path, "/") |
|
|
// last path segment is the realm ID which we will pass the incoming request to
|
|
|
|
|
|
realmID := segments[len(segments)-1] |
|
|
|
|
|
|
|
|
// last path segment is the base64d realm ID which we will pass the incoming request to
|
|
|
|
|
|
base64realmID := segments[len(segments)-1] |
|
|
|
|
|
bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID) |
|
|
|
|
|
realmID := string(bytesRealmID) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
log.WithError(err).WithField("base64_realm_id", base64realmID).Print( |
|
|
|
|
|
"Not a b64 encoded string", |
|
|
|
|
|
) |
|
|
|
|
|
w.WriteHeader(400) |
|
|
|
|
|
return |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
realm, err := rh.db.LoadAuthRealm(realmID) |
|
|
realm, err := rh.db.LoadAuthRealm(realmID) |
|
|
if err != nil { |
|
|
if err != nil { |
|
|
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") |
|
|
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") |
|
@ -126,8 +137,19 @@ type webhookHandler struct { |
|
|
|
|
|
|
|
|
func (wh *webhookHandler) handle(w http.ResponseWriter, req *http.Request) { |
|
|
func (wh *webhookHandler) handle(w http.ResponseWriter, req *http.Request) { |
|
|
segments := strings.Split(req.URL.Path, "/") |
|
|
segments := strings.Split(req.URL.Path, "/") |
|
|
// last path segment is the service ID which we will pass the incoming request to
|
|
|
|
|
|
srvID := segments[len(segments)-1] |
|
|
|
|
|
|
|
|
// last path segment is the service ID which we will pass the incoming request to,
|
|
|
|
|
|
// but we've base64d it.
|
|
|
|
|
|
base64srvID := segments[len(segments)-1] |
|
|
|
|
|
bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID) |
|
|
|
|
|
srvID := string(bytesSrvID) |
|
|
|
|
|
if err != nil { |
|
|
|
|
|
log.WithError(err).WithField("base64_service_id", base64srvID).Print( |
|
|
|
|
|
"Not a b64 encoded string", |
|
|
|
|
|
) |
|
|
|
|
|
w.WriteHeader(400) |
|
|
|
|
|
return |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
service, err := wh.db.LoadService(srvID) |
|
|
service, err := wh.db.LoadService(srvID) |
|
|
if err != nil { |
|
|
if err != nil { |
|
|
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service") |
|
|
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service") |
|
|