You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
70 lines
1.5 KiB
70 lines
1.5 KiB
package request_id
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
)
|
|
|
|
const AmzRequestIDHeader = "x-amz-request-id"
|
|
|
|
type contextKey struct{}
|
|
|
|
func Set(ctx context.Context, id string) context.Context {
|
|
return context.WithValue(ctx, contextKey{}, id)
|
|
}
|
|
|
|
func Get(ctx context.Context) string {
|
|
if ctx == nil {
|
|
return ""
|
|
}
|
|
id, _ := ctx.Value(contextKey{}).(string)
|
|
return id
|
|
}
|
|
|
|
func InjectToRequest(ctx context.Context, req *http.Request) {
|
|
if req != nil {
|
|
req.Header.Set(AmzRequestIDHeader, Get(ctx))
|
|
}
|
|
}
|
|
|
|
func New() string {
|
|
var buf [4]byte
|
|
rand.Read(buf[:])
|
|
return fmt.Sprintf("%X%08X", time.Now().UTC().UnixNano(), buf)
|
|
}
|
|
|
|
// GetFromRequest returns the server-generated request ID from the context.
|
|
func GetFromRequest(r *http.Request) string {
|
|
if r == nil {
|
|
return ""
|
|
}
|
|
return Get(r.Context())
|
|
}
|
|
|
|
// Ensure guarantees a server-generated request ID exists in the context.
|
|
// It always generates a new ID if one is not already present in the context,
|
|
// ignoring any client-sent x-amz-request-id header to prevent spoofing.
|
|
func Ensure(r *http.Request) (*http.Request, string) {
|
|
if r == nil {
|
|
return nil, ""
|
|
}
|
|
if id := Get(r.Context()); id != "" {
|
|
return r, id
|
|
}
|
|
id := New()
|
|
r = r.WithContext(Set(r.Context(), id))
|
|
return r, id
|
|
}
|
|
|
|
func Middleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
r, reqID := Ensure(r)
|
|
if w.Header().Get(AmzRequestIDHeader) == "" {
|
|
w.Header().Set(AmzRequestIDHeader, reqID)
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|