mirror of https://github.com/matrix-org/go-neb.git
Kegan Dougal
9 years ago
71 changed files with 7198 additions and 0 deletions
-
27vendor/manifest
-
156vendor/src/golang.org/x/net/context/context.go
-
577vendor/src/golang.org/x/net/context/context_test.go
-
74vendor/src/golang.org/x/net/context/ctxhttp/ctxhttp.go
-
28vendor/src/golang.org/x/net/context/ctxhttp/ctxhttp_17_test.go
-
147vendor/src/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go
-
79vendor/src/golang.org/x/net/context/ctxhttp/ctxhttp_pre17_test.go
-
105vendor/src/golang.org/x/net/context/ctxhttp/ctxhttp_test.go
-
72vendor/src/golang.org/x/net/context/go17.go
-
300vendor/src/golang.org/x/net/context/pre_go17.go
-
26vendor/src/golang.org/x/net/context/withtimeout_test.go
-
3vendor/src/golang.org/x/oauth2/AUTHORS
-
31vendor/src/golang.org/x/oauth2/CONTRIBUTING.md
-
3vendor/src/golang.org/x/oauth2/CONTRIBUTORS
-
27vendor/src/golang.org/x/oauth2/LICENSE
-
64vendor/src/golang.org/x/oauth2/README.md
-
16vendor/src/golang.org/x/oauth2/bitbucket/bitbucket.go
-
25vendor/src/golang.org/x/oauth2/client_appengine.go
-
112vendor/src/golang.org/x/oauth2/clientcredentials/clientcredentials.go
-
96vendor/src/golang.org/x/oauth2/clientcredentials/clientcredentials_test.go
-
45vendor/src/golang.org/x/oauth2/example_test.go
-
16vendor/src/golang.org/x/oauth2/facebook/facebook.go
-
16vendor/src/golang.org/x/oauth2/fitbit/fitbit.go
-
16vendor/src/golang.org/x/oauth2/github/github.go
-
86vendor/src/golang.org/x/oauth2/google/appengine.go
-
13vendor/src/golang.org/x/oauth2/google/appengine_hook.go
-
14vendor/src/golang.org/x/oauth2/google/appenginevm_hook.go
-
155vendor/src/golang.org/x/oauth2/google/default.go
-
150vendor/src/golang.org/x/oauth2/google/example_test.go
-
149vendor/src/golang.org/x/oauth2/google/google.go
-
97vendor/src/golang.org/x/oauth2/google/google_test.go
-
74vendor/src/golang.org/x/oauth2/google/jwt.go
-
91vendor/src/golang.org/x/oauth2/google/jwt_test.go
-
168vendor/src/golang.org/x/oauth2/google/sdk.go
-
46vendor/src/golang.org/x/oauth2/google/sdk_test.go
-
122vendor/src/golang.org/x/oauth2/google/testdata/gcloud/credentials
-
2vendor/src/golang.org/x/oauth2/google/testdata/gcloud/properties
-
60vendor/src/golang.org/x/oauth2/hipchat/hipchat.go
-
76vendor/src/golang.org/x/oauth2/internal/oauth2.go
-
62vendor/src/golang.org/x/oauth2/internal/oauth2_test.go
-
225vendor/src/golang.org/x/oauth2/internal/token.go
-
36vendor/src/golang.org/x/oauth2/internal/token_test.go
-
69vendor/src/golang.org/x/oauth2/internal/transport.go
-
38vendor/src/golang.org/x/oauth2/internal/transport_test.go
-
174vendor/src/golang.org/x/oauth2/jws/jws.go
-
46vendor/src/golang.org/x/oauth2/jws/jws_test.go
-
31vendor/src/golang.org/x/oauth2/jwt/example_test.go
-
157vendor/src/golang.org/x/oauth2/jwt/jwt.go
-
134vendor/src/golang.org/x/oauth2/jwt/jwt_test.go
-
16vendor/src/golang.org/x/oauth2/linkedin/linkedin.go
-
16vendor/src/golang.org/x/oauth2/microsoft/microsoft.go
-
337vendor/src/golang.org/x/oauth2/oauth2.go
-
458vendor/src/golang.org/x/oauth2/oauth2_test.go
-
16vendor/src/golang.org/x/oauth2/odnoklassniki/odnoklassniki.go
-
22vendor/src/golang.org/x/oauth2/paypal/paypal.go
-
16vendor/src/golang.org/x/oauth2/slack/slack.go
-
158vendor/src/golang.org/x/oauth2/token.go
-
72vendor/src/golang.org/x/oauth2/token_test.go
-
132vendor/src/golang.org/x/oauth2/transport.go
-
108vendor/src/golang.org/x/oauth2/transport_test.go
-
16vendor/src/golang.org/x/oauth2/vk/vk.go
-
438vendor/src/google.golang.org/cloud/compute/metadata/metadata.go
-
48vendor/src/google.golang.org/cloud/compute/metadata/metadata_test.go
-
257vendor/src/google.golang.org/cloud/internal/bundler/bundler.go
-
209vendor/src/google.golang.org/cloud/internal/bundler/bundler_test.go
-
128vendor/src/google.golang.org/cloud/internal/cloud.go
-
60vendor/src/google.golang.org/cloud/internal/testutil/context.go
-
186vendor/src/google.golang.org/cloud/internal/testutil/iterators.go
-
73vendor/src/google.golang.org/cloud/internal/testutil/server.go
-
35vendor/src/google.golang.org/cloud/internal/testutil/server_test.go
-
61vendor/src/google.golang.org/cloud/internal/transport/dial.go
@ -0,0 +1,156 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package context defines the Context type, which carries deadlines,
|
||||
|
// cancelation signals, and other request-scoped values across API boundaries
|
||||
|
// and between processes.
|
||||
|
//
|
||||
|
// Incoming requests to a server should create a Context, and outgoing calls to
|
||||
|
// servers should accept a Context. The chain of function calls between must
|
||||
|
// propagate the Context, optionally replacing it with a modified copy created
|
||||
|
// using WithDeadline, WithTimeout, WithCancel, or WithValue.
|
||||
|
//
|
||||
|
// Programs that use Contexts should follow these rules to keep interfaces
|
||||
|
// consistent across packages and enable static analysis tools to check context
|
||||
|
// propagation:
|
||||
|
//
|
||||
|
// Do not store Contexts inside a struct type; instead, pass a Context
|
||||
|
// explicitly to each function that needs it. The Context should be the first
|
||||
|
// parameter, typically named ctx:
|
||||
|
//
|
||||
|
// func DoSomething(ctx context.Context, arg Arg) error {
|
||||
|
// // ... use ctx ...
|
||||
|
// }
|
||||
|
//
|
||||
|
// Do not pass a nil Context, even if a function permits it. Pass context.TODO
|
||||
|
// if you are unsure about which Context to use.
|
||||
|
//
|
||||
|
// Use context Values only for request-scoped data that transits processes and
|
||||
|
// APIs, not for passing optional parameters to functions.
|
||||
|
//
|
||||
|
// The same Context may be passed to functions running in different goroutines;
|
||||
|
// Contexts are safe for simultaneous use by multiple goroutines.
|
||||
|
//
|
||||
|
// See http://blog.golang.org/context for example code for a server that uses
|
||||
|
// Contexts.
|
||||
|
package context // import "golang.org/x/net/context"
|
||||
|
|
||||
|
import "time" |
||||
|
|
||||
|
// A Context carries a deadline, a cancelation signal, and other values across
|
||||
|
// API boundaries.
|
||||
|
//
|
||||
|
// Context's methods may be called by multiple goroutines simultaneously.
|
||||
|
type Context interface { |
||||
|
// Deadline returns the time when work done on behalf of this context
|
||||
|
// should be canceled. Deadline returns ok==false when no deadline is
|
||||
|
// set. Successive calls to Deadline return the same results.
|
||||
|
Deadline() (deadline time.Time, ok bool) |
||||
|
|
||||
|
// Done returns a channel that's closed when work done on behalf of this
|
||||
|
// context should be canceled. Done may return nil if this context can
|
||||
|
// never be canceled. Successive calls to Done return the same value.
|
||||
|
//
|
||||
|
// WithCancel arranges for Done to be closed when cancel is called;
|
||||
|
// WithDeadline arranges for Done to be closed when the deadline
|
||||
|
// expires; WithTimeout arranges for Done to be closed when the timeout
|
||||
|
// elapses.
|
||||
|
//
|
||||
|
// Done is provided for use in select statements:
|
||||
|
//
|
||||
|
// // Stream generates values with DoSomething and sends them to out
|
||||
|
// // until DoSomething returns an error or ctx.Done is closed.
|
||||
|
// func Stream(ctx context.Context, out chan<- Value) error {
|
||||
|
// for {
|
||||
|
// v, err := DoSomething(ctx)
|
||||
|
// if err != nil {
|
||||
|
// return err
|
||||
|
// }
|
||||
|
// select {
|
||||
|
// case <-ctx.Done():
|
||||
|
// return ctx.Err()
|
||||
|
// case out <- v:
|
||||
|
// }
|
||||
|
// }
|
||||
|
// }
|
||||
|
//
|
||||
|
// See http://blog.golang.org/pipelines for more examples of how to use
|
||||
|
// a Done channel for cancelation.
|
||||
|
Done() <-chan struct{} |
||||
|
|
||||
|
// Err returns a non-nil error value after Done is closed. Err returns
|
||||
|
// Canceled if the context was canceled or DeadlineExceeded if the
|
||||
|
// context's deadline passed. No other values for Err are defined.
|
||||
|
// After Done is closed, successive calls to Err return the same value.
|
||||
|
Err() error |
||||
|
|
||||
|
// Value returns the value associated with this context for key, or nil
|
||||
|
// if no value is associated with key. Successive calls to Value with
|
||||
|
// the same key returns the same result.
|
||||
|
//
|
||||
|
// Use context values only for request-scoped data that transits
|
||||
|
// processes and API boundaries, not for passing optional parameters to
|
||||
|
// functions.
|
||||
|
//
|
||||
|
// A key identifies a specific value in a Context. Functions that wish
|
||||
|
// to store values in Context typically allocate a key in a global
|
||||
|
// variable then use that key as the argument to context.WithValue and
|
||||
|
// Context.Value. A key can be any type that supports equality;
|
||||
|
// packages should define keys as an unexported type to avoid
|
||||
|
// collisions.
|
||||
|
//
|
||||
|
// Packages that define a Context key should provide type-safe accessors
|
||||
|
// for the values stores using that key:
|
||||
|
//
|
||||
|
// // Package user defines a User type that's stored in Contexts.
|
||||
|
// package user
|
||||
|
//
|
||||
|
// import "golang.org/x/net/context"
|
||||
|
//
|
||||
|
// // User is the type of value stored in the Contexts.
|
||||
|
// type User struct {...}
|
||||
|
//
|
||||
|
// // key is an unexported type for keys defined in this package.
|
||||
|
// // This prevents collisions with keys defined in other packages.
|
||||
|
// type key int
|
||||
|
//
|
||||
|
// // userKey is the key for user.User values in Contexts. It is
|
||||
|
// // unexported; clients use user.NewContext and user.FromContext
|
||||
|
// // instead of using this key directly.
|
||||
|
// var userKey key = 0
|
||||
|
//
|
||||
|
// // NewContext returns a new Context that carries value u.
|
||||
|
// func NewContext(ctx context.Context, u *User) context.Context {
|
||||
|
// return context.WithValue(ctx, userKey, u)
|
||||
|
// }
|
||||
|
//
|
||||
|
// // FromContext returns the User value stored in ctx, if any.
|
||||
|
// func FromContext(ctx context.Context) (*User, bool) {
|
||||
|
// u, ok := ctx.Value(userKey).(*User)
|
||||
|
// return u, ok
|
||||
|
// }
|
||||
|
Value(key interface{}) interface{} |
||||
|
} |
||||
|
|
||||
|
// Background returns a non-nil, empty Context. It is never canceled, has no
|
||||
|
// values, and has no deadline. It is typically used by the main function,
|
||||
|
// initialization, and tests, and as the top-level Context for incoming
|
||||
|
// requests.
|
||||
|
func Background() Context { |
||||
|
return background |
||||
|
} |
||||
|
|
||||
|
// TODO returns a non-nil, empty Context. Code should use context.TODO when
|
||||
|
// it's unclear which Context to use or it is not yet available (because the
|
||||
|
// surrounding function has not yet been extended to accept a Context
|
||||
|
// parameter). TODO is recognized by static analysis tools that determine
|
||||
|
// whether Contexts are propagated correctly in a program.
|
||||
|
func TODO() Context { |
||||
|
return todo |
||||
|
} |
||||
|
|
||||
|
// A CancelFunc tells an operation to abandon its work.
|
||||
|
// A CancelFunc does not wait for the work to stop.
|
||||
|
// After the first call, subsequent calls to a CancelFunc do nothing.
|
||||
|
type CancelFunc func() |
@ -0,0 +1,577 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build !go1.7
|
||||
|
|
||||
|
package context |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"math/rand" |
||||
|
"runtime" |
||||
|
"strings" |
||||
|
"sync" |
||||
|
"testing" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
// otherContext is a Context that's not one of the types defined in context.go.
|
||||
|
// This lets us test code paths that differ based on the underlying type of the
|
||||
|
// Context.
|
||||
|
type otherContext struct { |
||||
|
Context |
||||
|
} |
||||
|
|
||||
|
func TestBackground(t *testing.T) { |
||||
|
c := Background() |
||||
|
if c == nil { |
||||
|
t.Fatalf("Background returned nil") |
||||
|
} |
||||
|
select { |
||||
|
case x := <-c.Done(): |
||||
|
t.Errorf("<-c.Done() == %v want nothing (it should block)", x) |
||||
|
default: |
||||
|
} |
||||
|
if got, want := fmt.Sprint(c), "context.Background"; got != want { |
||||
|
t.Errorf("Background().String() = %q want %q", got, want) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestTODO(t *testing.T) { |
||||
|
c := TODO() |
||||
|
if c == nil { |
||||
|
t.Fatalf("TODO returned nil") |
||||
|
} |
||||
|
select { |
||||
|
case x := <-c.Done(): |
||||
|
t.Errorf("<-c.Done() == %v want nothing (it should block)", x) |
||||
|
default: |
||||
|
} |
||||
|
if got, want := fmt.Sprint(c), "context.TODO"; got != want { |
||||
|
t.Errorf("TODO().String() = %q want %q", got, want) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestWithCancel(t *testing.T) { |
||||
|
c1, cancel := WithCancel(Background()) |
||||
|
|
||||
|
if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want { |
||||
|
t.Errorf("c1.String() = %q want %q", got, want) |
||||
|
} |
||||
|
|
||||
|
o := otherContext{c1} |
||||
|
c2, _ := WithCancel(o) |
||||
|
contexts := []Context{c1, o, c2} |
||||
|
|
||||
|
for i, c := range contexts { |
||||
|
if d := c.Done(); d == nil { |
||||
|
t.Errorf("c[%d].Done() == %v want non-nil", i, d) |
||||
|
} |
||||
|
if e := c.Err(); e != nil { |
||||
|
t.Errorf("c[%d].Err() == %v want nil", i, e) |
||||
|
} |
||||
|
|
||||
|
select { |
||||
|
case x := <-c.Done(): |
||||
|
t.Errorf("<-c.Done() == %v want nothing (it should block)", x) |
||||
|
default: |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
cancel() |
||||
|
time.Sleep(100 * time.Millisecond) // let cancelation propagate
|
||||
|
|
||||
|
for i, c := range contexts { |
||||
|
select { |
||||
|
case <-c.Done(): |
||||
|
default: |
||||
|
t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i) |
||||
|
} |
||||
|
if e := c.Err(); e != Canceled { |
||||
|
t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestParentFinishesChild(t *testing.T) { |
||||
|
// Context tree:
|
||||
|
// parent -> cancelChild
|
||||
|
// parent -> valueChild -> timerChild
|
||||
|
parent, cancel := WithCancel(Background()) |
||||
|
cancelChild, stop := WithCancel(parent) |
||||
|
defer stop() |
||||
|
valueChild := WithValue(parent, "key", "value") |
||||
|
timerChild, stop := WithTimeout(valueChild, 10000*time.Hour) |
||||
|
defer stop() |
||||
|
|
||||
|
select { |
||||
|
case x := <-parent.Done(): |
||||
|
t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) |
||||
|
case x := <-cancelChild.Done(): |
||||
|
t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x) |
||||
|
case x := <-timerChild.Done(): |
||||
|
t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x) |
||||
|
case x := <-valueChild.Done(): |
||||
|
t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x) |
||||
|
default: |
||||
|
} |
||||
|
|
||||
|
// The parent's children should contain the two cancelable children.
|
||||
|
pc := parent.(*cancelCtx) |
||||
|
cc := cancelChild.(*cancelCtx) |
||||
|
tc := timerChild.(*timerCtx) |
||||
|
pc.mu.Lock() |
||||
|
if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] { |
||||
|
t.Errorf("bad linkage: pc.children = %v, want %v and %v", |
||||
|
pc.children, cc, tc) |
||||
|
} |
||||
|
pc.mu.Unlock() |
||||
|
|
||||
|
if p, ok := parentCancelCtx(cc.Context); !ok || p != pc { |
||||
|
t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc) |
||||
|
} |
||||
|
if p, ok := parentCancelCtx(tc.Context); !ok || p != pc { |
||||
|
t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc) |
||||
|
} |
||||
|
|
||||
|
cancel() |
||||
|
|
||||
|
pc.mu.Lock() |
||||
|
if len(pc.children) != 0 { |
||||
|
t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children) |
||||
|
} |
||||
|
pc.mu.Unlock() |
||||
|
|
||||
|
// parent and children should all be finished.
|
||||
|
check := func(ctx Context, name string) { |
||||
|
select { |
||||
|
case <-ctx.Done(): |
||||
|
default: |
||||
|
t.Errorf("<-%s.Done() blocked, but shouldn't have", name) |
||||
|
} |
||||
|
if e := ctx.Err(); e != Canceled { |
||||
|
t.Errorf("%s.Err() == %v want %v", name, e, Canceled) |
||||
|
} |
||||
|
} |
||||
|
check(parent, "parent") |
||||
|
check(cancelChild, "cancelChild") |
||||
|
check(valueChild, "valueChild") |
||||
|
check(timerChild, "timerChild") |
||||
|
|
||||
|
// WithCancel should return a canceled context on a canceled parent.
|
||||
|
precanceledChild := WithValue(parent, "key", "value") |
||||
|
select { |
||||
|
case <-precanceledChild.Done(): |
||||
|
default: |
||||
|
t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have") |
||||
|
} |
||||
|
if e := precanceledChild.Err(); e != Canceled { |
||||
|
t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestChildFinishesFirst(t *testing.T) { |
||||
|
cancelable, stop := WithCancel(Background()) |
||||
|
defer stop() |
||||
|
for _, parent := range []Context{Background(), cancelable} { |
||||
|
child, cancel := WithCancel(parent) |
||||
|
|
||||
|
select { |
||||
|
case x := <-parent.Done(): |
||||
|
t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) |
||||
|
case x := <-child.Done(): |
||||
|
t.Errorf("<-child.Done() == %v want nothing (it should block)", x) |
||||
|
default: |
||||
|
} |
||||
|
|
||||
|
cc := child.(*cancelCtx) |
||||
|
pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background()
|
||||
|
if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) { |
||||
|
t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok) |
||||
|
} |
||||
|
|
||||
|
if pcok { |
||||
|
pc.mu.Lock() |
||||
|
if len(pc.children) != 1 || !pc.children[cc] { |
||||
|
t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc) |
||||
|
} |
||||
|
pc.mu.Unlock() |
||||
|
} |
||||
|
|
||||
|
cancel() |
||||
|
|
||||
|
if pcok { |
||||
|
pc.mu.Lock() |
||||
|
if len(pc.children) != 0 { |
||||
|
t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children) |
||||
|
} |
||||
|
pc.mu.Unlock() |
||||
|
} |
||||
|
|
||||
|
// child should be finished.
|
||||
|
select { |
||||
|
case <-child.Done(): |
||||
|
default: |
||||
|
t.Errorf("<-child.Done() blocked, but shouldn't have") |
||||
|
} |
||||
|
if e := child.Err(); e != Canceled { |
||||
|
t.Errorf("child.Err() == %v want %v", e, Canceled) |
||||
|
} |
||||
|
|
||||
|
// parent should not be finished.
|
||||
|
select { |
||||
|
case x := <-parent.Done(): |
||||
|
t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) |
||||
|
default: |
||||
|
} |
||||
|
if e := parent.Err(); e != nil { |
||||
|
t.Errorf("parent.Err() == %v want nil", e) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testDeadline(c Context, wait time.Duration, t *testing.T) { |
||||
|
select { |
||||
|
case <-time.After(wait): |
||||
|
t.Fatalf("context should have timed out") |
||||
|
case <-c.Done(): |
||||
|
} |
||||
|
if e := c.Err(); e != DeadlineExceeded { |
||||
|
t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestDeadline(t *testing.T) { |
||||
|
c, _ := WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) |
||||
|
if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { |
||||
|
t.Errorf("c.String() = %q want prefix %q", got, prefix) |
||||
|
} |
||||
|
testDeadline(c, 200*time.Millisecond, t) |
||||
|
|
||||
|
c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) |
||||
|
o := otherContext{c} |
||||
|
testDeadline(o, 200*time.Millisecond, t) |
||||
|
|
||||
|
c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) |
||||
|
o = otherContext{c} |
||||
|
c, _ = WithDeadline(o, time.Now().Add(300*time.Millisecond)) |
||||
|
testDeadline(c, 200*time.Millisecond, t) |
||||
|
} |
||||
|
|
||||
|
func TestTimeout(t *testing.T) { |
||||
|
c, _ := WithTimeout(Background(), 100*time.Millisecond) |
||||
|
if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { |
||||
|
t.Errorf("c.String() = %q want prefix %q", got, prefix) |
||||
|
} |
||||
|
testDeadline(c, 200*time.Millisecond, t) |
||||
|
|
||||
|
c, _ = WithTimeout(Background(), 100*time.Millisecond) |
||||
|
o := otherContext{c} |
||||
|
testDeadline(o, 200*time.Millisecond, t) |
||||
|
|
||||
|
c, _ = WithTimeout(Background(), 100*time.Millisecond) |
||||
|
o = otherContext{c} |
||||
|
c, _ = WithTimeout(o, 300*time.Millisecond) |
||||
|
testDeadline(c, 200*time.Millisecond, t) |
||||
|
} |
||||
|
|
||||
|
func TestCanceledTimeout(t *testing.T) { |
||||
|
c, _ := WithTimeout(Background(), 200*time.Millisecond) |
||||
|
o := otherContext{c} |
||||
|
c, cancel := WithTimeout(o, 400*time.Millisecond) |
||||
|
cancel() |
||||
|
time.Sleep(100 * time.Millisecond) // let cancelation propagate
|
||||
|
select { |
||||
|
case <-c.Done(): |
||||
|
default: |
||||
|
t.Errorf("<-c.Done() blocked, but shouldn't have") |
||||
|
} |
||||
|
if e := c.Err(); e != Canceled { |
||||
|
t.Errorf("c.Err() == %v want %v", e, Canceled) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
type key1 int |
||||
|
type key2 int |
||||
|
|
||||
|
var k1 = key1(1) |
||||
|
var k2 = key2(1) // same int as k1, different type
|
||||
|
var k3 = key2(3) // same type as k2, different int
|
||||
|
|
||||
|
func TestValues(t *testing.T) { |
||||
|
check := func(c Context, nm, v1, v2, v3 string) { |
||||
|
if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 { |
||||
|
t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0) |
||||
|
} |
||||
|
if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 { |
||||
|
t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0) |
||||
|
} |
||||
|
if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 { |
||||
|
t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
c0 := Background() |
||||
|
check(c0, "c0", "", "", "") |
||||
|
|
||||
|
c1 := WithValue(Background(), k1, "c1k1") |
||||
|
check(c1, "c1", "c1k1", "", "") |
||||
|
|
||||
|
if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want { |
||||
|
t.Errorf("c.String() = %q want %q", got, want) |
||||
|
} |
||||
|
|
||||
|
c2 := WithValue(c1, k2, "c2k2") |
||||
|
check(c2, "c2", "c1k1", "c2k2", "") |
||||
|
|
||||
|
c3 := WithValue(c2, k3, "c3k3") |
||||
|
check(c3, "c2", "c1k1", "c2k2", "c3k3") |
||||
|
|
||||
|
c4 := WithValue(c3, k1, nil) |
||||
|
check(c4, "c4", "", "c2k2", "c3k3") |
||||
|
|
||||
|
o0 := otherContext{Background()} |
||||
|
check(o0, "o0", "", "", "") |
||||
|
|
||||
|
o1 := otherContext{WithValue(Background(), k1, "c1k1")} |
||||
|
check(o1, "o1", "c1k1", "", "") |
||||
|
|
||||
|
o2 := WithValue(o1, k2, "o2k2") |
||||
|
check(o2, "o2", "c1k1", "o2k2", "") |
||||
|
|
||||
|
o3 := otherContext{c4} |
||||
|
check(o3, "o3", "", "c2k2", "c3k3") |
||||
|
|
||||
|
o4 := WithValue(o3, k3, nil) |
||||
|
check(o4, "o4", "", "c2k2", "") |
||||
|
} |
||||
|
|
||||
|
func TestAllocs(t *testing.T) { |
||||
|
bg := Background() |
||||
|
for _, test := range []struct { |
||||
|
desc string |
||||
|
f func() |
||||
|
limit float64 |
||||
|
gccgoLimit float64 |
||||
|
}{ |
||||
|
{ |
||||
|
desc: "Background()", |
||||
|
f: func() { Background() }, |
||||
|
limit: 0, |
||||
|
gccgoLimit: 0, |
||||
|
}, |
||||
|
{ |
||||
|
desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1), |
||||
|
f: func() { |
||||
|
c := WithValue(bg, k1, nil) |
||||
|
c.Value(k1) |
||||
|
}, |
||||
|
limit: 3, |
||||
|
gccgoLimit: 3, |
||||
|
}, |
||||
|
{ |
||||
|
desc: "WithTimeout(bg, 15*time.Millisecond)", |
||||
|
f: func() { |
||||
|
c, _ := WithTimeout(bg, 15*time.Millisecond) |
||||
|
<-c.Done() |
||||
|
}, |
||||
|
limit: 8, |
||||
|
gccgoLimit: 16, |
||||
|
}, |
||||
|
{ |
||||
|
desc: "WithCancel(bg)", |
||||
|
f: func() { |
||||
|
c, cancel := WithCancel(bg) |
||||
|
cancel() |
||||
|
<-c.Done() |
||||
|
}, |
||||
|
limit: 5, |
||||
|
gccgoLimit: 8, |
||||
|
}, |
||||
|
{ |
||||
|
desc: "WithTimeout(bg, 100*time.Millisecond)", |
||||
|
f: func() { |
||||
|
c, cancel := WithTimeout(bg, 100*time.Millisecond) |
||||
|
cancel() |
||||
|
<-c.Done() |
||||
|
}, |
||||
|
limit: 8, |
||||
|
gccgoLimit: 25, |
||||
|
}, |
||||
|
} { |
||||
|
limit := test.limit |
||||
|
if runtime.Compiler == "gccgo" { |
||||
|
// gccgo does not yet do escape analysis.
|
||||
|
// TODO(iant): Remove this when gccgo does do escape analysis.
|
||||
|
limit = test.gccgoLimit |
||||
|
} |
||||
|
if n := testing.AllocsPerRun(100, test.f); n > limit { |
||||
|
t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit)) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestSimultaneousCancels(t *testing.T) { |
||||
|
root, cancel := WithCancel(Background()) |
||||
|
m := map[Context]CancelFunc{root: cancel} |
||||
|
q := []Context{root} |
||||
|
// Create a tree of contexts.
|
||||
|
for len(q) != 0 && len(m) < 100 { |
||||
|
parent := q[0] |
||||
|
q = q[1:] |
||||
|
for i := 0; i < 4; i++ { |
||||
|
ctx, cancel := WithCancel(parent) |
||||
|
m[ctx] = cancel |
||||
|
q = append(q, ctx) |
||||
|
} |
||||
|
} |
||||
|
// Start all the cancels in a random order.
|
||||
|
var wg sync.WaitGroup |
||||
|
wg.Add(len(m)) |
||||
|
for _, cancel := range m { |
||||
|
go func(cancel CancelFunc) { |
||||
|
cancel() |
||||
|
wg.Done() |
||||
|
}(cancel) |
||||
|
} |
||||
|
// Wait on all the contexts in a random order.
|
||||
|
for ctx := range m { |
||||
|
select { |
||||
|
case <-ctx.Done(): |
||||
|
case <-time.After(1 * time.Second): |
||||
|
buf := make([]byte, 10<<10) |
||||
|
n := runtime.Stack(buf, true) |
||||
|
t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n]) |
||||
|
} |
||||
|
} |
||||
|
// Wait for all the cancel functions to return.
|
||||
|
done := make(chan struct{}) |
||||
|
go func() { |
||||
|
wg.Wait() |
||||
|
close(done) |
||||
|
}() |
||||
|
select { |
||||
|
case <-done: |
||||
|
case <-time.After(1 * time.Second): |
||||
|
buf := make([]byte, 10<<10) |
||||
|
n := runtime.Stack(buf, true) |
||||
|
t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n]) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestInterlockedCancels(t *testing.T) { |
||||
|
parent, cancelParent := WithCancel(Background()) |
||||
|
child, cancelChild := WithCancel(parent) |
||||
|
go func() { |
||||
|
parent.Done() |
||||
|
cancelChild() |
||||
|
}() |
||||
|
cancelParent() |
||||
|
select { |
||||
|
case <-child.Done(): |
||||
|
case <-time.After(1 * time.Second): |
||||
|
buf := make([]byte, 10<<10) |
||||
|
n := runtime.Stack(buf, true) |
||||
|
t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n]) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestLayersCancel(t *testing.T) { |
||||
|
testLayers(t, time.Now().UnixNano(), false) |
||||
|
} |
||||
|
|
||||
|
func TestLayersTimeout(t *testing.T) { |
||||
|
testLayers(t, time.Now().UnixNano(), true) |
||||
|
} |
||||
|
|
||||
|
func testLayers(t *testing.T, seed int64, testTimeout bool) { |
||||
|
rand.Seed(seed) |
||||
|
errorf := func(format string, a ...interface{}) { |
||||
|
t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...) |
||||
|
} |
||||
|
const ( |
||||
|
timeout = 200 * time.Millisecond |
||||
|
minLayers = 30 |
||||
|
) |
||||
|
type value int |
||||
|
var ( |
||||
|
vals []*value |
||||
|
cancels []CancelFunc |
||||
|
numTimers int |
||||
|
ctx = Background() |
||||
|
) |
||||
|
for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ { |
||||
|
switch rand.Intn(3) { |
||||
|
case 0: |
||||
|
v := new(value) |
||||
|
ctx = WithValue(ctx, v, v) |
||||
|
vals = append(vals, v) |
||||
|
case 1: |
||||
|
var cancel CancelFunc |
||||
|
ctx, cancel = WithCancel(ctx) |
||||
|
cancels = append(cancels, cancel) |
||||
|
case 2: |
||||
|
var cancel CancelFunc |
||||
|
ctx, cancel = WithTimeout(ctx, timeout) |
||||
|
cancels = append(cancels, cancel) |
||||
|
numTimers++ |
||||
|
} |
||||
|
} |
||||
|
checkValues := func(when string) { |
||||
|
for _, key := range vals { |
||||
|
if val := ctx.Value(key).(*value); key != val { |
||||
|
errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
select { |
||||
|
case <-ctx.Done(): |
||||
|
errorf("ctx should not be canceled yet") |
||||
|
default: |
||||
|
} |
||||
|
if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) { |
||||
|
t.Errorf("ctx.String() = %q want prefix %q", s, prefix) |
||||
|
} |
||||
|
t.Log(ctx) |
||||
|
checkValues("before cancel") |
||||
|
if testTimeout { |
||||
|
select { |
||||
|
case <-ctx.Done(): |
||||
|
case <-time.After(timeout + 100*time.Millisecond): |
||||
|
errorf("ctx should have timed out") |
||||
|
} |
||||
|
checkValues("after timeout") |
||||
|
} else { |
||||
|
cancel := cancels[rand.Intn(len(cancels))] |
||||
|
cancel() |
||||
|
select { |
||||
|
case <-ctx.Done(): |
||||
|
default: |
||||
|
errorf("ctx should be canceled") |
||||
|
} |
||||
|
checkValues("after cancel") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestCancelRemoves(t *testing.T) { |
||||
|
checkChildren := func(when string, ctx Context, want int) { |
||||
|
if got := len(ctx.(*cancelCtx).children); got != want { |
||||
|
t.Errorf("%s: context has %d children, want %d", when, got, want) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
ctx, _ := WithCancel(Background()) |
||||
|
checkChildren("after creation", ctx, 0) |
||||
|
_, cancel := WithCancel(ctx) |
||||
|
checkChildren("with WithCancel child ", ctx, 1) |
||||
|
cancel() |
||||
|
checkChildren("after cancelling WithCancel child", ctx, 0) |
||||
|
|
||||
|
ctx, _ = WithCancel(Background()) |
||||
|
checkChildren("after creation", ctx, 0) |
||||
|
_, cancel = WithTimeout(ctx, 60*time.Minute) |
||||
|
checkChildren("with WithTimeout child ", ctx, 1) |
||||
|
cancel() |
||||
|
checkChildren("after cancelling WithTimeout child", ctx, 0) |
||||
|
} |
@ -0,0 +1,74 @@ |
|||||
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build go1.7
|
||||
|
|
||||
|
// Package ctxhttp provides helper functions for performing context-aware HTTP requests.
|
||||
|
package ctxhttp // import "golang.org/x/net/context/ctxhttp"
|
||||
|
|
||||
|
import ( |
||||
|
"io" |
||||
|
"net/http" |
||||
|
"net/url" |
||||
|
"strings" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
) |
||||
|
|
||||
|
// Do sends an HTTP request with the provided http.Client and returns
|
||||
|
// an HTTP response.
|
||||
|
//
|
||||
|
// If the client is nil, http.DefaultClient is used.
|
||||
|
//
|
||||
|
// The provided ctx must be non-nil. If it is canceled or times out,
|
||||
|
// ctx.Err() will be returned.
|
||||
|
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { |
||||
|
if client == nil { |
||||
|
client = http.DefaultClient |
||||
|
} |
||||
|
resp, err := client.Do(req.WithContext(ctx)) |
||||
|
// If we got an error, and the context has been canceled,
|
||||
|
// the context's error is probably more useful.
|
||||
|
if err != nil { |
||||
|
select { |
||||
|
case <-ctx.Done(): |
||||
|
err = ctx.Err() |
||||
|
default: |
||||
|
} |
||||
|
} |
||||
|
return resp, err |
||||
|
} |
||||
|
|
||||
|
// Get issues a GET request via the Do function.
|
||||
|
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { |
||||
|
req, err := http.NewRequest("GET", url, nil) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return Do(ctx, client, req) |
||||
|
} |
||||
|
|
||||
|
// Head issues a HEAD request via the Do function.
|
||||
|
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { |
||||
|
req, err := http.NewRequest("HEAD", url, nil) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return Do(ctx, client, req) |
||||
|
} |
||||
|
|
||||
|
// Post issues a POST request via the Do function.
|
||||
|
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { |
||||
|
req, err := http.NewRequest("POST", url, body) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
req.Header.Set("Content-Type", bodyType) |
||||
|
return Do(ctx, client, req) |
||||
|
} |
||||
|
|
||||
|
// PostForm issues a POST request via the Do function.
|
||||
|
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { |
||||
|
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) |
||||
|
} |
@ -0,0 +1,28 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build !plan9,go1.7
|
||||
|
|
||||
|
package ctxhttp |
||||
|
|
||||
|
import ( |
||||
|
"io" |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"testing" |
||||
|
|
||||
|
"context" |
||||
|
) |
||||
|
|
||||
|
func TestGo17Context(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
io.WriteString(w, "ok") |
||||
|
})) |
||||
|
ctx := context.Background() |
||||
|
resp, err := Get(ctx, http.DefaultClient, ts.URL) |
||||
|
if resp == nil || err != nil { |
||||
|
t.Fatalf("error received from client: %v %v", err, resp) |
||||
|
} |
||||
|
resp.Body.Close() |
||||
|
} |
@ -0,0 +1,147 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build !go1.7
|
||||
|
|
||||
|
package ctxhttp // import "golang.org/x/net/context/ctxhttp"
|
||||
|
|
||||
|
import ( |
||||
|
"io" |
||||
|
"net/http" |
||||
|
"net/url" |
||||
|
"strings" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
) |
||||
|
|
||||
|
func nop() {} |
||||
|
|
||||
|
var ( |
||||
|
testHookContextDoneBeforeHeaders = nop |
||||
|
testHookDoReturned = nop |
||||
|
testHookDidBodyClose = nop |
||||
|
) |
||||
|
|
||||
|
// Do sends an HTTP request with the provided http.Client and returns an HTTP response.
|
||||
|
// If the client is nil, http.DefaultClient is used.
|
||||
|
// If the context is canceled or times out, ctx.Err() will be returned.
|
||||
|
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { |
||||
|
if client == nil { |
||||
|
client = http.DefaultClient |
||||
|
} |
||||
|
|
||||
|
// TODO(djd): Respect any existing value of req.Cancel.
|
||||
|
cancel := make(chan struct{}) |
||||
|
req.Cancel = cancel |
||||
|
|
||||
|
type responseAndError struct { |
||||
|
resp *http.Response |
||||
|
err error |
||||
|
} |
||||
|
result := make(chan responseAndError, 1) |
||||
|
|
||||
|
// Make local copies of test hooks closed over by goroutines below.
|
||||
|
// Prevents data races in tests.
|
||||
|
testHookDoReturned := testHookDoReturned |
||||
|
testHookDidBodyClose := testHookDidBodyClose |
||||
|
|
||||
|
go func() { |
||||
|
resp, err := client.Do(req) |
||||
|
testHookDoReturned() |
||||
|
result <- responseAndError{resp, err} |
||||
|
}() |
||||
|
|
||||
|
var resp *http.Response |
||||
|
|
||||
|
select { |
||||
|
case <-ctx.Done(): |
||||
|
testHookContextDoneBeforeHeaders() |
||||
|
close(cancel) |
||||
|
// Clean up after the goroutine calling client.Do:
|
||||
|
go func() { |
||||
|
if r := <-result; r.resp != nil { |
||||
|
testHookDidBodyClose() |
||||
|
r.resp.Body.Close() |
||||
|
} |
||||
|
}() |
||||
|
return nil, ctx.Err() |
||||
|
case r := <-result: |
||||
|
var err error |
||||
|
resp, err = r.resp, r.err |
||||
|
if err != nil { |
||||
|
return resp, err |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
c := make(chan struct{}) |
||||
|
go func() { |
||||
|
select { |
||||
|
case <-ctx.Done(): |
||||
|
close(cancel) |
||||
|
case <-c: |
||||
|
// The response's Body is closed.
|
||||
|
} |
||||
|
}() |
||||
|
resp.Body = ¬ifyingReader{resp.Body, c} |
||||
|
|
||||
|
return resp, nil |
||||
|
} |
||||
|
|
||||
|
// Get issues a GET request via the Do function.
|
||||
|
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { |
||||
|
req, err := http.NewRequest("GET", url, nil) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return Do(ctx, client, req) |
||||
|
} |
||||
|
|
||||
|
// Head issues a HEAD request via the Do function.
|
||||
|
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { |
||||
|
req, err := http.NewRequest("HEAD", url, nil) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return Do(ctx, client, req) |
||||
|
} |
||||
|
|
||||
|
// Post issues a POST request via the Do function.
|
||||
|
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { |
||||
|
req, err := http.NewRequest("POST", url, body) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
req.Header.Set("Content-Type", bodyType) |
||||
|
return Do(ctx, client, req) |
||||
|
} |
||||
|
|
||||
|
// PostForm issues a POST request via the Do function.
|
||||
|
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { |
||||
|
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) |
||||
|
} |
||||
|
|
||||
|
// notifyingReader is an io.ReadCloser that closes the notify channel after
|
||||
|
// Close is called or a Read fails on the underlying ReadCloser.
|
||||
|
type notifyingReader struct { |
||||
|
io.ReadCloser |
||||
|
notify chan<- struct{} |
||||
|
} |
||||
|
|
||||
|
func (r *notifyingReader) Read(p []byte) (int, error) { |
||||
|
n, err := r.ReadCloser.Read(p) |
||||
|
if err != nil && r.notify != nil { |
||||
|
close(r.notify) |
||||
|
r.notify = nil |
||||
|
} |
||||
|
return n, err |
||||
|
} |
||||
|
|
||||
|
func (r *notifyingReader) Close() error { |
||||
|
err := r.ReadCloser.Close() |
||||
|
if r.notify != nil { |
||||
|
close(r.notify) |
||||
|
r.notify = nil |
||||
|
} |
||||
|
return err |
||||
|
} |
@ -0,0 +1,79 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build !plan9,!go1.7
|
||||
|
|
||||
|
package ctxhttp |
||||
|
|
||||
|
import ( |
||||
|
"net" |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"sync" |
||||
|
"testing" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
) |
||||
|
|
||||
|
// golang.org/issue/14065
|
||||
|
func TestClosesResponseBodyOnCancel(t *testing.T) { |
||||
|
defer func() { testHookContextDoneBeforeHeaders = nop }() |
||||
|
defer func() { testHookDoReturned = nop }() |
||||
|
defer func() { testHookDidBodyClose = nop }() |
||||
|
|
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) |
||||
|
defer ts.Close() |
||||
|
|
||||
|
ctx, cancel := context.WithCancel(context.Background()) |
||||
|
|
||||
|
// closed when Do enters select case <-ctx.Done()
|
||||
|
enteredDonePath := make(chan struct{}) |
||||
|
|
||||
|
testHookContextDoneBeforeHeaders = func() { |
||||
|
close(enteredDonePath) |
||||
|
} |
||||
|
|
||||
|
testHookDoReturned = func() { |
||||
|
// We now have the result (the Flush'd headers) at least,
|
||||
|
// so we can cancel the request.
|
||||
|
cancel() |
||||
|
|
||||
|
// But block the client.Do goroutine from sending
|
||||
|
// until Do enters into the <-ctx.Done() path, since
|
||||
|
// otherwise if both channels are readable, select
|
||||
|
// picks a random one.
|
||||
|
<-enteredDonePath |
||||
|
} |
||||
|
|
||||
|
sawBodyClose := make(chan struct{}) |
||||
|
testHookDidBodyClose = func() { close(sawBodyClose) } |
||||
|
|
||||
|
tr := &http.Transport{} |
||||
|
defer tr.CloseIdleConnections() |
||||
|
c := &http.Client{Transport: tr} |
||||
|
req, _ := http.NewRequest("GET", ts.URL, nil) |
||||
|
_, doErr := Do(ctx, c, req) |
||||
|
|
||||
|
select { |
||||
|
case <-sawBodyClose: |
||||
|
case <-time.After(5 * time.Second): |
||||
|
t.Fatal("timeout waiting for body to close") |
||||
|
} |
||||
|
|
||||
|
if doErr != ctx.Err() { |
||||
|
t.Errorf("Do error = %v; want %v", doErr, ctx.Err()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
type noteCloseConn struct { |
||||
|
net.Conn |
||||
|
onceClose sync.Once |
||||
|
closefn func() |
||||
|
} |
||||
|
|
||||
|
func (c *noteCloseConn) Close() error { |
||||
|
c.onceClose.Do(c.closefn) |
||||
|
return c.Conn.Close() |
||||
|
} |
@ -0,0 +1,105 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build !plan9
|
||||
|
|
||||
|
package ctxhttp |
||||
|
|
||||
|
import ( |
||||
|
"io" |
||||
|
"io/ioutil" |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"testing" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
) |
||||
|
|
||||
|
const ( |
||||
|
requestDuration = 100 * time.Millisecond |
||||
|
requestBody = "ok" |
||||
|
) |
||||
|
|
||||
|
func okHandler(w http.ResponseWriter, r *http.Request) { |
||||
|
time.Sleep(requestDuration) |
||||
|
io.WriteString(w, requestBody) |
||||
|
} |
||||
|
|
||||
|
func TestNoTimeout(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(okHandler)) |
||||
|
defer ts.Close() |
||||
|
|
||||
|
ctx := context.Background() |
||||
|
res, err := Get(ctx, nil, ts.URL) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
defer res.Body.Close() |
||||
|
slurp, err := ioutil.ReadAll(res.Body) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if string(slurp) != requestBody { |
||||
|
t.Errorf("body = %q; want %q", slurp, requestBody) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestCancelBeforeHeaders(t *testing.T) { |
||||
|
ctx, cancel := context.WithCancel(context.Background()) |
||||
|
|
||||
|
blockServer := make(chan struct{}) |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
cancel() |
||||
|
<-blockServer |
||||
|
io.WriteString(w, requestBody) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
defer close(blockServer) |
||||
|
|
||||
|
res, err := Get(ctx, nil, ts.URL) |
||||
|
if err == nil { |
||||
|
res.Body.Close() |
||||
|
t.Fatal("Get returned unexpected nil error") |
||||
|
} |
||||
|
if err != context.Canceled { |
||||
|
t.Errorf("err = %v; want %v", err, context.Canceled) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestCancelAfterHangingRequest(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
w.WriteHeader(http.StatusOK) |
||||
|
w.(http.Flusher).Flush() |
||||
|
<-w.(http.CloseNotifier).CloseNotify() |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
|
||||
|
ctx, cancel := context.WithCancel(context.Background()) |
||||
|
resp, err := Get(ctx, nil, ts.URL) |
||||
|
if err != nil { |
||||
|
t.Fatalf("unexpected error in Get: %v", err) |
||||
|
} |
||||
|
|
||||
|
// Cancel befer reading the body.
|
||||
|
// Reading Request.Body should fail, since the request was
|
||||
|
// canceled before anything was written.
|
||||
|
cancel() |
||||
|
|
||||
|
done := make(chan struct{}) |
||||
|
|
||||
|
go func() { |
||||
|
b, err := ioutil.ReadAll(resp.Body) |
||||
|
if len(b) != 0 || err == nil { |
||||
|
t.Errorf(`Read got (%q, %v); want ("", error)`, b, err) |
||||
|
} |
||||
|
close(done) |
||||
|
}() |
||||
|
|
||||
|
select { |
||||
|
case <-time.After(1 * time.Second): |
||||
|
t.Errorf("Test timed out") |
||||
|
case <-done: |
||||
|
} |
||||
|
} |
@ -0,0 +1,72 @@ |
|||||
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build go1.7
|
||||
|
|
||||
|
package context |
||||
|
|
||||
|
import ( |
||||
|
"context" // standard library's context, as of Go 1.7
|
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
var ( |
||||
|
todo = context.TODO() |
||||
|
background = context.Background() |
||||
|
) |
||||
|
|
||||
|
// Canceled is the error returned by Context.Err when the context is canceled.
|
||||
|
var Canceled = context.Canceled |
||||
|
|
||||
|
// DeadlineExceeded is the error returned by Context.Err when the context's
|
||||
|
// deadline passes.
|
||||
|
var DeadlineExceeded = context.DeadlineExceeded |
||||
|
|
||||
|
// WithCancel returns a copy of parent with a new Done channel. The returned
|
||||
|
// context's Done channel is closed when the returned cancel function is called
|
||||
|
// or when the parent context's Done channel is closed, whichever happens first.
|
||||
|
//
|
||||
|
// Canceling this context releases resources associated with it, so code should
|
||||
|
// call cancel as soon as the operations running in this Context complete.
|
||||
|
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { |
||||
|
ctx, f := context.WithCancel(parent) |
||||
|
return ctx, CancelFunc(f) |
||||
|
} |
||||
|
|
||||
|
// WithDeadline returns a copy of the parent context with the deadline adjusted
|
||||
|
// to be no later than d. If the parent's deadline is already earlier than d,
|
||||
|
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
|
||||
|
// context's Done channel is closed when the deadline expires, when the returned
|
||||
|
// cancel function is called, or when the parent context's Done channel is
|
||||
|
// closed, whichever happens first.
|
||||
|
//
|
||||
|
// Canceling this context releases resources associated with it, so code should
|
||||
|
// call cancel as soon as the operations running in this Context complete.
|
||||
|
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { |
||||
|
ctx, f := context.WithDeadline(parent, deadline) |
||||
|
return ctx, CancelFunc(f) |
||||
|
} |
||||
|
|
||||
|
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
|
||||
|
//
|
||||
|
// Canceling this context releases resources associated with it, so code should
|
||||
|
// call cancel as soon as the operations running in this Context complete:
|
||||
|
//
|
||||
|
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
|
||||
|
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
|
// defer cancel() // releases resources if slowOperation completes before timeout elapses
|
||||
|
// return slowOperation(ctx)
|
||||
|
// }
|
||||
|
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { |
||||
|
return WithDeadline(parent, time.Now().Add(timeout)) |
||||
|
} |
||||
|
|
||||
|
// WithValue returns a copy of parent in which the value associated with key is
|
||||
|
// val.
|
||||
|
//
|
||||
|
// Use context Values only for request-scoped data that transits processes and
|
||||
|
// APIs, not for passing optional parameters to functions.
|
||||
|
func WithValue(parent Context, key interface{}, val interface{}) Context { |
||||
|
return context.WithValue(parent, key, val) |
||||
|
} |
@ -0,0 +1,300 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build !go1.7
|
||||
|
|
||||
|
package context |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"sync" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
|
||||
|
// struct{}, since vars of this type must have distinct addresses.
|
||||
|
type emptyCtx int |
||||
|
|
||||
|
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
func (*emptyCtx) Done() <-chan struct{} { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (*emptyCtx) Err() error { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (*emptyCtx) Value(key interface{}) interface{} { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func (e *emptyCtx) String() string { |
||||
|
switch e { |
||||
|
case background: |
||||
|
return "context.Background" |
||||
|
case todo: |
||||
|
return "context.TODO" |
||||
|
} |
||||
|
return "unknown empty Context" |
||||
|
} |
||||
|
|
||||
|
var ( |
||||
|
background = new(emptyCtx) |
||||
|
todo = new(emptyCtx) |
||||
|
) |
||||
|
|
||||
|
// Canceled is the error returned by Context.Err when the context is canceled.
|
||||
|
var Canceled = errors.New("context canceled") |
||||
|
|
||||
|
// DeadlineExceeded is the error returned by Context.Err when the context's
|
||||
|
// deadline passes.
|
||||
|
var DeadlineExceeded = errors.New("context deadline exceeded") |
||||
|
|
||||
|
// WithCancel returns a copy of parent with a new Done channel. The returned
|
||||
|
// context's Done channel is closed when the returned cancel function is called
|
||||
|
// or when the parent context's Done channel is closed, whichever happens first.
|
||||
|
//
|
||||
|
// Canceling this context releases resources associated with it, so code should
|
||||
|
// call cancel as soon as the operations running in this Context complete.
|
||||
|
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { |
||||
|
c := newCancelCtx(parent) |
||||
|
propagateCancel(parent, c) |
||||
|
return c, func() { c.cancel(true, Canceled) } |
||||
|
} |
||||
|
|
||||
|
// newCancelCtx returns an initialized cancelCtx.
|
||||
|
func newCancelCtx(parent Context) *cancelCtx { |
||||
|
return &cancelCtx{ |
||||
|
Context: parent, |
||||
|
done: make(chan struct{}), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// propagateCancel arranges for child to be canceled when parent is.
|
||||
|
func propagateCancel(parent Context, child canceler) { |
||||
|
if parent.Done() == nil { |
||||
|
return // parent is never canceled
|
||||
|
} |
||||
|
if p, ok := parentCancelCtx(parent); ok { |
||||
|
p.mu.Lock() |
||||
|
if p.err != nil { |
||||
|
// parent has already been canceled
|
||||
|
child.cancel(false, p.err) |
||||
|
} else { |
||||
|
if p.children == nil { |
||||
|
p.children = make(map[canceler]bool) |
||||
|
} |
||||
|
p.children[child] = true |
||||
|
} |
||||
|
p.mu.Unlock() |
||||
|
} else { |
||||
|
go func() { |
||||
|
select { |
||||
|
case <-parent.Done(): |
||||
|
child.cancel(false, parent.Err()) |
||||
|
case <-child.Done(): |
||||
|
} |
||||
|
}() |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// parentCancelCtx follows a chain of parent references until it finds a
|
||||
|
// *cancelCtx. This function understands how each of the concrete types in this
|
||||
|
// package represents its parent.
|
||||
|
func parentCancelCtx(parent Context) (*cancelCtx, bool) { |
||||
|
for { |
||||
|
switch c := parent.(type) { |
||||
|
case *cancelCtx: |
||||
|
return c, true |
||||
|
case *timerCtx: |
||||
|
return c.cancelCtx, true |
||||
|
case *valueCtx: |
||||
|
parent = c.Context |
||||
|
default: |
||||
|
return nil, false |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// removeChild removes a context from its parent.
|
||||
|
func removeChild(parent Context, child canceler) { |
||||
|
p, ok := parentCancelCtx(parent) |
||||
|
if !ok { |
||||
|
return |
||||
|
} |
||||
|
p.mu.Lock() |
||||
|
if p.children != nil { |
||||
|
delete(p.children, child) |
||||
|
} |
||||
|
p.mu.Unlock() |
||||
|
} |
||||
|
|
||||
|
// A canceler is a context type that can be canceled directly. The
|
||||
|
// implementations are *cancelCtx and *timerCtx.
|
||||
|
type canceler interface { |
||||
|
cancel(removeFromParent bool, err error) |
||||
|
Done() <-chan struct{} |
||||
|
} |
||||
|
|
||||
|
// A cancelCtx can be canceled. When canceled, it also cancels any children
|
||||
|
// that implement canceler.
|
||||
|
type cancelCtx struct { |
||||
|
Context |
||||
|
|
||||
|
done chan struct{} // closed by the first cancel call.
|
||||
|
|
||||
|
mu sync.Mutex |
||||
|
children map[canceler]bool // set to nil by the first cancel call
|
||||
|
err error // set to non-nil by the first cancel call
|
||||
|
} |
||||
|
|
||||
|
func (c *cancelCtx) Done() <-chan struct{} { |
||||
|
return c.done |
||||
|
} |
||||
|
|
||||
|
func (c *cancelCtx) Err() error { |
||||
|
c.mu.Lock() |
||||
|
defer c.mu.Unlock() |
||||
|
return c.err |
||||
|
} |
||||
|
|
||||
|
func (c *cancelCtx) String() string { |
||||
|
return fmt.Sprintf("%v.WithCancel", c.Context) |
||||
|
} |
||||
|
|
||||
|
// cancel closes c.done, cancels each of c's children, and, if
|
||||
|
// removeFromParent is true, removes c from its parent's children.
|
||||
|
func (c *cancelCtx) cancel(removeFromParent bool, err error) { |
||||
|
if err == nil { |
||||
|
panic("context: internal error: missing cancel error") |
||||
|
} |
||||
|
c.mu.Lock() |
||||
|
if c.err != nil { |
||||
|
c.mu.Unlock() |
||||
|
return // already canceled
|
||||
|
} |
||||
|
c.err = err |
||||
|
close(c.done) |
||||
|
for child := range c.children { |
||||
|
// NOTE: acquiring the child's lock while holding parent's lock.
|
||||
|
child.cancel(false, err) |
||||
|
} |
||||
|
c.children = nil |
||||
|
c.mu.Unlock() |
||||
|
|
||||
|
if removeFromParent { |
||||
|
removeChild(c.Context, c) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// WithDeadline returns a copy of the parent context with the deadline adjusted
|
||||
|
// to be no later than d. If the parent's deadline is already earlier than d,
|
||||
|
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
|
||||
|
// context's Done channel is closed when the deadline expires, when the returned
|
||||
|
// cancel function is called, or when the parent context's Done channel is
|
||||
|
// closed, whichever happens first.
|
||||
|
//
|
||||
|
// Canceling this context releases resources associated with it, so code should
|
||||
|
// call cancel as soon as the operations running in this Context complete.
|
||||
|
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { |
||||
|
if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { |
||||
|
// The current deadline is already sooner than the new one.
|
||||
|
return WithCancel(parent) |
||||
|
} |
||||
|
c := &timerCtx{ |
||||
|
cancelCtx: newCancelCtx(parent), |
||||
|
deadline: deadline, |
||||
|
} |
||||
|
propagateCancel(parent, c) |
||||
|
d := deadline.Sub(time.Now()) |
||||
|
if d <= 0 { |
||||
|
c.cancel(true, DeadlineExceeded) // deadline has already passed
|
||||
|
return c, func() { c.cancel(true, Canceled) } |
||||
|
} |
||||
|
c.mu.Lock() |
||||
|
defer c.mu.Unlock() |
||||
|
if c.err == nil { |
||||
|
c.timer = time.AfterFunc(d, func() { |
||||
|
c.cancel(true, DeadlineExceeded) |
||||
|
}) |
||||
|
} |
||||
|
return c, func() { c.cancel(true, Canceled) } |
||||
|
} |
||||
|
|
||||
|
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
|
||||
|
// implement Done and Err. It implements cancel by stopping its timer then
|
||||
|
// delegating to cancelCtx.cancel.
|
||||
|
type timerCtx struct { |
||||
|
*cancelCtx |
||||
|
timer *time.Timer // Under cancelCtx.mu.
|
||||
|
|
||||
|
deadline time.Time |
||||
|
} |
||||
|
|
||||
|
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { |
||||
|
return c.deadline, true |
||||
|
} |
||||
|
|
||||
|
func (c *timerCtx) String() string { |
||||
|
return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) |
||||
|
} |
||||
|
|
||||
|
func (c *timerCtx) cancel(removeFromParent bool, err error) { |
||||
|
c.cancelCtx.cancel(false, err) |
||||
|
if removeFromParent { |
||||
|
// Remove this timerCtx from its parent cancelCtx's children.
|
||||
|
removeChild(c.cancelCtx.Context, c) |
||||
|
} |
||||
|
c.mu.Lock() |
||||
|
if c.timer != nil { |
||||
|
c.timer.Stop() |
||||
|
c.timer = nil |
||||
|
} |
||||
|
c.mu.Unlock() |
||||
|
} |
||||
|
|
||||
|
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
|
||||
|
//
|
||||
|
// Canceling this context releases resources associated with it, so code should
|
||||
|
// call cancel as soon as the operations running in this Context complete:
|
||||
|
//
|
||||
|
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
|
||||
|
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
|
// defer cancel() // releases resources if slowOperation completes before timeout elapses
|
||||
|
// return slowOperation(ctx)
|
||||
|
// }
|
||||
|
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { |
||||
|
return WithDeadline(parent, time.Now().Add(timeout)) |
||||
|
} |
||||
|
|
||||
|
// WithValue returns a copy of parent in which the value associated with key is
|
||||
|
// val.
|
||||
|
//
|
||||
|
// Use context Values only for request-scoped data that transits processes and
|
||||
|
// APIs, not for passing optional parameters to functions.
|
||||
|
func WithValue(parent Context, key interface{}, val interface{}) Context { |
||||
|
return &valueCtx{parent, key, val} |
||||
|
} |
||||
|
|
||||
|
// A valueCtx carries a key-value pair. It implements Value for that key and
|
||||
|
// delegates all other calls to the embedded Context.
|
||||
|
type valueCtx struct { |
||||
|
Context |
||||
|
key, val interface{} |
||||
|
} |
||||
|
|
||||
|
func (c *valueCtx) String() string { |
||||
|
return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) |
||||
|
} |
||||
|
|
||||
|
func (c *valueCtx) Value(key interface{}) interface{} { |
||||
|
if c.key == key { |
||||
|
return c.val |
||||
|
} |
||||
|
return c.Context.Value(key) |
||||
|
} |
@ -0,0 +1,26 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package context_test |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
) |
||||
|
|
||||
|
func ExampleWithTimeout() { |
||||
|
// Pass a context with a timeout to tell a blocking function that it
|
||||
|
// should abandon its work after the timeout elapses.
|
||||
|
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) |
||||
|
select { |
||||
|
case <-time.After(200 * time.Millisecond): |
||||
|
fmt.Println("overslept") |
||||
|
case <-ctx.Done(): |
||||
|
fmt.Println(ctx.Err()) // prints "context deadline exceeded"
|
||||
|
} |
||||
|
// Output:
|
||||
|
// context deadline exceeded
|
||||
|
} |
@ -0,0 +1,3 @@ |
|||||
|
# This source code refers to The Go Authors for copyright purposes. |
||||
|
# The master list of authors is in the main Go distribution, |
||||
|
# visible at http://tip.golang.org/AUTHORS. |
@ -0,0 +1,31 @@ |
|||||
|
# Contributing to Go |
||||
|
|
||||
|
Go is an open source project. |
||||
|
|
||||
|
It is the work of hundreds of contributors. We appreciate your help! |
||||
|
|
||||
|
|
||||
|
## Filing issues |
||||
|
|
||||
|
When [filing an issue](https://github.com/golang/oauth2/issues), make sure to answer these five questions: |
||||
|
|
||||
|
1. What version of Go are you using (`go version`)? |
||||
|
2. What operating system and processor architecture are you using? |
||||
|
3. What did you do? |
||||
|
4. What did you expect to see? |
||||
|
5. What did you see instead? |
||||
|
|
||||
|
General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker. |
||||
|
The gophers there will answer or ask you to file an issue if you've tripped over a bug. |
||||
|
|
||||
|
## Contributing code |
||||
|
|
||||
|
Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html) |
||||
|
before sending patches. |
||||
|
|
||||
|
**We do not accept GitHub pull requests** |
||||
|
(we use [Gerrit](https://code.google.com/p/gerrit/) instead for code review). |
||||
|
|
||||
|
Unless otherwise noted, the Go source files are distributed under |
||||
|
the BSD-style license found in the LICENSE file. |
||||
|
|
@ -0,0 +1,3 @@ |
|||||
|
# This source code was written by the Go contributors. |
||||
|
# The master list of contributors is in the main Go distribution, |
||||
|
# visible at http://tip.golang.org/CONTRIBUTORS. |
@ -0,0 +1,27 @@ |
|||||
|
Copyright (c) 2009 The oauth2 Authors. All rights reserved. |
||||
|
|
||||
|
Redistribution and use in source and binary forms, with or without |
||||
|
modification, are permitted provided that the following conditions are |
||||
|
met: |
||||
|
|
||||
|
* Redistributions of source code must retain the above copyright |
||||
|
notice, this list of conditions and the following disclaimer. |
||||
|
* Redistributions in binary form must reproduce the above |
||||
|
copyright notice, this list of conditions and the following disclaimer |
||||
|
in the documentation and/or other materials provided with the |
||||
|
distribution. |
||||
|
* Neither the name of Google Inc. nor the names of its |
||||
|
contributors may be used to endorse or promote products derived from |
||||
|
this software without specific prior written permission. |
||||
|
|
||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
@ -0,0 +1,64 @@ |
|||||
|
# OAuth2 for Go |
||||
|
|
||||
|
[![Build Status](https://travis-ci.org/golang/oauth2.svg?branch=master)](https://travis-ci.org/golang/oauth2) |
||||
|
|
||||
|
oauth2 package contains a client implementation for OAuth 2.0 spec. |
||||
|
|
||||
|
## Installation |
||||
|
|
||||
|
~~~~ |
||||
|
go get golang.org/x/oauth2 |
||||
|
~~~~ |
||||
|
|
||||
|
See godoc for further documentation and examples. |
||||
|
|
||||
|
* [godoc.org/golang.org/x/oauth2](http://godoc.org/golang.org/x/oauth2) |
||||
|
* [godoc.org/golang.org/x/oauth2/google](http://godoc.org/golang.org/x/oauth2/google) |
||||
|
|
||||
|
|
||||
|
## App Engine |
||||
|
|
||||
|
In change 96e89be (March 2015) we removed the `oauth2.Context2` type in favor |
||||
|
of the [`context.Context`](https://golang.org/x/net/context#Context) type from |
||||
|
the `golang.org/x/net/context` package |
||||
|
|
||||
|
This means its no longer possible to use the "Classic App Engine" |
||||
|
`appengine.Context` type with the `oauth2` package. (You're using |
||||
|
Classic App Engine if you import the package `"appengine"`.) |
||||
|
|
||||
|
To work around this, you may use the new `"google.golang.org/appengine"` |
||||
|
package. This package has almost the same API as the `"appengine"` package, |
||||
|
but it can be fetched with `go get` and used on "Managed VMs" and well as |
||||
|
Classic App Engine. |
||||
|
|
||||
|
See the [new `appengine` package's readme](https://github.com/golang/appengine#updating-a-go-app-engine-app) |
||||
|
for information on updating your app. |
||||
|
|
||||
|
If you don't want to update your entire app to use the new App Engine packages, |
||||
|
you may use both sets of packages in parallel, using only the new packages |
||||
|
with the `oauth2` package. |
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/google" |
||||
|
newappengine "google.golang.org/appengine" |
||||
|
newurlfetch "google.golang.org/appengine/urlfetch" |
||||
|
|
||||
|
"appengine" |
||||
|
) |
||||
|
|
||||
|
func handler(w http.ResponseWriter, r *http.Request) { |
||||
|
var c appengine.Context = appengine.NewContext(r) |
||||
|
c.Infof("Logging a message with the old package") |
||||
|
|
||||
|
var ctx context.Context = newappengine.NewContext(r) |
||||
|
client := &http.Client{ |
||||
|
Transport: &oauth2.Transport{ |
||||
|
Source: google.AppEngineTokenSource(ctx, "scope"), |
||||
|
Base: &newurlfetch.Transport{Context: ctx}, |
||||
|
}, |
||||
|
} |
||||
|
client.Get("...") |
||||
|
} |
||||
|
|
@ -0,0 +1,16 @@ |
|||||
|
// Copyright 2015 The oauth2 Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package bitbucket provides constants for using OAuth2 to access Bitbucket.
|
||||
|
package bitbucket |
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is Bitbucket's OAuth 2.0 endpoint.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://bitbucket.org/site/oauth2/authorize", |
||||
|
TokenURL: "https://bitbucket.org/site/oauth2/access_token", |
||||
|
} |
@ -0,0 +1,25 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build appengine
|
||||
|
|
||||
|
// App Engine hooks.
|
||||
|
|
||||
|
package oauth2 |
||||
|
|
||||
|
import ( |
||||
|
"net/http" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/oauth2/internal" |
||||
|
"google.golang.org/appengine/urlfetch" |
||||
|
) |
||||
|
|
||||
|
func init() { |
||||
|
internal.RegisterContextClientFunc(contextClientAppEngine) |
||||
|
} |
||||
|
|
||||
|
func contextClientAppEngine(ctx context.Context) (*http.Client, error) { |
||||
|
return urlfetch.Client(ctx), nil |
||||
|
} |
@ -0,0 +1,112 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package clientcredentials implements the OAuth2.0 "client credentials" token flow,
|
||||
|
// also known as the "two-legged OAuth 2.0".
|
||||
|
//
|
||||
|
// This should be used when the client is acting on its own behalf or when the client
|
||||
|
// is the resource owner. It may also be used when requesting access to protected
|
||||
|
// resources based on an authorization previously arranged with the authorization
|
||||
|
// server.
|
||||
|
//
|
||||
|
// See http://tools.ietf.org/html/draft-ietf-oauth-v2-31#section-4.4
|
||||
|
package clientcredentials // import "golang.org/x/oauth2/clientcredentials"
|
||||
|
|
||||
|
import ( |
||||
|
"net/http" |
||||
|
"net/url" |
||||
|
"strings" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/internal" |
||||
|
) |
||||
|
|
||||
|
// tokenFromInternal maps an *internal.Token struct into
|
||||
|
// an *oauth2.Token struct.
|
||||
|
func tokenFromInternal(t *internal.Token) *oauth2.Token { |
||||
|
if t == nil { |
||||
|
return nil |
||||
|
} |
||||
|
tk := &oauth2.Token{ |
||||
|
AccessToken: t.AccessToken, |
||||
|
TokenType: t.TokenType, |
||||
|
RefreshToken: t.RefreshToken, |
||||
|
Expiry: t.Expiry, |
||||
|
} |
||||
|
return tk.WithExtra(t.Raw) |
||||
|
} |
||||
|
|
||||
|
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token.
|
||||
|
// This token is then mapped from *internal.Token into an *oauth2.Token which is
|
||||
|
// returned along with an error.
|
||||
|
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*oauth2.Token, error) { |
||||
|
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.TokenURL, v) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return tokenFromInternal(tk), nil |
||||
|
} |
||||
|
|
||||
|
// Client Credentials Config describes a 2-legged OAuth2 flow, with both the
|
||||
|
// client application information and the server's endpoint URLs.
|
||||
|
type Config struct { |
||||
|
// ClientID is the application's ID.
|
||||
|
ClientID string |
||||
|
|
||||
|
// ClientSecret is the application's secret.
|
||||
|
ClientSecret string |
||||
|
|
||||
|
// TokenURL is the resource server's token endpoint
|
||||
|
// URL. This is a constant specific to each server.
|
||||
|
TokenURL string |
||||
|
|
||||
|
// Scope specifies optional requested permissions.
|
||||
|
Scopes []string |
||||
|
} |
||||
|
|
||||
|
// Token uses client credentials to retrieve a token.
|
||||
|
// The HTTP client to use is derived from the context.
|
||||
|
// If nil, http.DefaultClient is used.
|
||||
|
func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { |
||||
|
return retrieveToken(ctx, c, url.Values{ |
||||
|
"grant_type": {"client_credentials"}, |
||||
|
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// Client returns an HTTP client using the provided token.
|
||||
|
// The token will auto-refresh as necessary. The underlying
|
||||
|
// HTTP transport will be obtained using the provided context.
|
||||
|
// The returned client and its Transport should not be modified.
|
||||
|
func (c *Config) Client(ctx context.Context) *http.Client { |
||||
|
return oauth2.NewClient(ctx, c.TokenSource(ctx)) |
||||
|
} |
||||
|
|
||||
|
// TokenSource returns a TokenSource that returns t until t expires,
|
||||
|
// automatically refreshing it as necessary using the provided context and the
|
||||
|
// client ID and client secret.
|
||||
|
//
|
||||
|
// Most users will use Config.Client instead.
|
||||
|
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { |
||||
|
source := &tokenSource{ |
||||
|
ctx: ctx, |
||||
|
conf: c, |
||||
|
} |
||||
|
return oauth2.ReuseTokenSource(nil, source) |
||||
|
} |
||||
|
|
||||
|
type tokenSource struct { |
||||
|
ctx context.Context |
||||
|
conf *Config |
||||
|
} |
||||
|
|
||||
|
// Token refreshes the token by using a new client credentials request.
|
||||
|
// tokens received this way do not include a refresh token
|
||||
|
func (c *tokenSource) Token() (*oauth2.Token, error) { |
||||
|
return retrieveToken(c.ctx, c.conf, url.Values{ |
||||
|
"grant_type": {"client_credentials"}, |
||||
|
"scope": internal.CondVal(strings.Join(c.conf.Scopes, " ")), |
||||
|
}) |
||||
|
} |
@ -0,0 +1,96 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package clientcredentials |
||||
|
|
||||
|
import ( |
||||
|
"io/ioutil" |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"testing" |
||||
|
|
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
func newConf(url string) *Config { |
||||
|
return &Config{ |
||||
|
ClientID: "CLIENT_ID", |
||||
|
ClientSecret: "CLIENT_SECRET", |
||||
|
Scopes: []string{"scope1", "scope2"}, |
||||
|
TokenURL: url + "/token", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
type mockTransport struct { |
||||
|
rt func(req *http.Request) (resp *http.Response, err error) |
||||
|
} |
||||
|
|
||||
|
func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { |
||||
|
return t.rt(req) |
||||
|
} |
||||
|
|
||||
|
func TestTokenRequest(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if r.URL.String() != "/token" { |
||||
|
t.Errorf("authenticate client request URL = %q; want %q", r.URL, "/token") |
||||
|
} |
||||
|
headerAuth := r.Header.Get("Authorization") |
||||
|
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { |
||||
|
t.Errorf("Unexpected authorization header, %v is found.", headerAuth) |
||||
|
} |
||||
|
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { |
||||
|
t.Errorf("Content-Type header = %q; want %q", got, want) |
||||
|
} |
||||
|
body, err := ioutil.ReadAll(r.Body) |
||||
|
if err != nil { |
||||
|
r.Body.Close() |
||||
|
} |
||||
|
if err != nil { |
||||
|
t.Errorf("failed reading request body: %s.", err) |
||||
|
} |
||||
|
if string(body) != "client_id=CLIENT_ID&grant_type=client_credentials&scope=scope1+scope2" { |
||||
|
t.Errorf("payload = %q; want %q", string(body), "client_id=CLIENT_ID&grant_type=client_credentials&scope=scope1+scope2") |
||||
|
} |
||||
|
w.Header().Set("Content-Type", "application/x-www-form-urlencoded") |
||||
|
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
tok, err := conf.Token(oauth2.NoContext) |
||||
|
if err != nil { |
||||
|
t.Error(err) |
||||
|
} |
||||
|
if !tok.Valid() { |
||||
|
t.Fatalf("token invalid. got: %#v", tok) |
||||
|
} |
||||
|
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { |
||||
|
t.Errorf("Access token = %q; want %q", tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c") |
||||
|
} |
||||
|
if tok.TokenType != "bearer" { |
||||
|
t.Errorf("token type = %q; want %q", tok.TokenType, "bearer") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestTokenRefreshRequest(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if r.URL.String() == "/somethingelse" { |
||||
|
return |
||||
|
} |
||||
|
if r.URL.String() != "/token" { |
||||
|
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) |
||||
|
} |
||||
|
headerContentType := r.Header.Get("Content-Type") |
||||
|
if headerContentType != "application/x-www-form-urlencoded" { |
||||
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
||||
|
} |
||||
|
body, _ := ioutil.ReadAll(r.Body) |
||||
|
if string(body) != "client_id=CLIENT_ID&grant_type=client_credentials&scope=scope1+scope2" { |
||||
|
t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) |
||||
|
} |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
c := conf.Client(oauth2.NoContext) |
||||
|
c.Get(ts.URL + "/somethingelse") |
||||
|
} |
@ -0,0 +1,45 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package oauth2_test |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"log" |
||||
|
|
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
func ExampleConfig() { |
||||
|
conf := &oauth2.Config{ |
||||
|
ClientID: "YOUR_CLIENT_ID", |
||||
|
ClientSecret: "YOUR_CLIENT_SECRET", |
||||
|
Scopes: []string{"SCOPE1", "SCOPE2"}, |
||||
|
Endpoint: oauth2.Endpoint{ |
||||
|
AuthURL: "https://provider.com/o/oauth2/auth", |
||||
|
TokenURL: "https://provider.com/o/oauth2/token", |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
// Redirect user to consent page to ask for permission
|
||||
|
// for the scopes specified above.
|
||||
|
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline) |
||||
|
fmt.Printf("Visit the URL for the auth dialog: %v", url) |
||||
|
|
||||
|
// Use the authorization code that is pushed to the redirect
|
||||
|
// URL. Exchange will do the handshake to retrieve the
|
||||
|
// initial access token. The HTTP Client returned by
|
||||
|
// conf.Client will refresh the token as necessary.
|
||||
|
var code string |
||||
|
if _, err := fmt.Scan(&code); err != nil { |
||||
|
log.Fatal(err) |
||||
|
} |
||||
|
tok, err := conf.Exchange(oauth2.NoContext, code) |
||||
|
if err != nil { |
||||
|
log.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
client := conf.Client(oauth2.NoContext, tok) |
||||
|
client.Get("...") |
||||
|
} |
@ -0,0 +1,16 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package facebook provides constants for using OAuth2 to access Facebook.
|
||||
|
package facebook // import "golang.org/x/oauth2/facebook"
|
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is Facebook's OAuth 2.0 endpoint.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://www.facebook.com/dialog/oauth", |
||||
|
TokenURL: "https://graph.facebook.com/oauth/access_token", |
||||
|
} |
@ -0,0 +1,16 @@ |
|||||
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package fitbit provides constants for using OAuth2 to access the Fitbit API.
|
||||
|
package fitbit // import "golang.org/x/oauth2/fitbit"
|
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is the Fitbit API's OAuth 2.0 endpoint.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://www.fitbit.com/oauth2/authorize", |
||||
|
TokenURL: "https://api.fitbit.com/oauth2/token", |
||||
|
} |
@ -0,0 +1,16 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package github provides constants for using OAuth2 to access Github.
|
||||
|
package github // import "golang.org/x/oauth2/github"
|
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is Github's OAuth 2.0 endpoint.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://github.com/login/oauth/authorize", |
||||
|
TokenURL: "https://github.com/login/oauth/access_token", |
||||
|
} |
@ -0,0 +1,86 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package google |
||||
|
|
||||
|
import ( |
||||
|
"sort" |
||||
|
"strings" |
||||
|
"sync" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// Set at init time by appenginevm_hook.go. If true, we are on App Engine Managed VMs.
|
||||
|
var appengineVM bool |
||||
|
|
||||
|
// Set at init time by appengine_hook.go. If nil, we're not on App Engine.
|
||||
|
var appengineTokenFunc func(c context.Context, scopes ...string) (token string, expiry time.Time, err error) |
||||
|
|
||||
|
// AppEngineTokenSource returns a token source that fetches tokens
|
||||
|
// issued to the current App Engine application's service account.
|
||||
|
// If you are implementing a 3-legged OAuth 2.0 flow on App Engine
|
||||
|
// that involves user accounts, see oauth2.Config instead.
|
||||
|
//
|
||||
|
// The provided context must have come from appengine.NewContext.
|
||||
|
func AppEngineTokenSource(ctx context.Context, scope ...string) oauth2.TokenSource { |
||||
|
if appengineTokenFunc == nil { |
||||
|
panic("google: AppEngineTokenSource can only be used on App Engine.") |
||||
|
} |
||||
|
scopes := append([]string{}, scope...) |
||||
|
sort.Strings(scopes) |
||||
|
return &appEngineTokenSource{ |
||||
|
ctx: ctx, |
||||
|
scopes: scopes, |
||||
|
key: strings.Join(scopes, " "), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// aeTokens helps the fetched tokens to be reused until their expiration.
|
||||
|
var ( |
||||
|
aeTokensMu sync.Mutex |
||||
|
aeTokens = make(map[string]*tokenLock) // key is space-separated scopes
|
||||
|
) |
||||
|
|
||||
|
type tokenLock struct { |
||||
|
mu sync.Mutex // guards t; held while fetching or updating t
|
||||
|
t *oauth2.Token |
||||
|
} |
||||
|
|
||||
|
type appEngineTokenSource struct { |
||||
|
ctx context.Context |
||||
|
scopes []string |
||||
|
key string // to aeTokens map; space-separated scopes
|
||||
|
} |
||||
|
|
||||
|
func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) { |
||||
|
if appengineTokenFunc == nil { |
||||
|
panic("google: AppEngineTokenSource can only be used on App Engine.") |
||||
|
} |
||||
|
|
||||
|
aeTokensMu.Lock() |
||||
|
tok, ok := aeTokens[ts.key] |
||||
|
if !ok { |
||||
|
tok = &tokenLock{} |
||||
|
aeTokens[ts.key] = tok |
||||
|
} |
||||
|
aeTokensMu.Unlock() |
||||
|
|
||||
|
tok.mu.Lock() |
||||
|
defer tok.mu.Unlock() |
||||
|
if tok.t.Valid() { |
||||
|
return tok.t, nil |
||||
|
} |
||||
|
access, exp, err := appengineTokenFunc(ts.ctx, ts.scopes...) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
tok.t = &oauth2.Token{ |
||||
|
AccessToken: access, |
||||
|
Expiry: exp, |
||||
|
} |
||||
|
return tok.t, nil |
||||
|
} |
@ -0,0 +1,13 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build appengine
|
||||
|
|
||||
|
package google |
||||
|
|
||||
|
import "google.golang.org/appengine" |
||||
|
|
||||
|
func init() { |
||||
|
appengineTokenFunc = appengine.AccessToken |
||||
|
} |
@ -0,0 +1,14 @@ |
|||||
|
// Copyright 2015 The oauth2 Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build appenginevm
|
||||
|
|
||||
|
package google |
||||
|
|
||||
|
import "google.golang.org/appengine" |
||||
|
|
||||
|
func init() { |
||||
|
appengineVM = true |
||||
|
appengineTokenFunc = appengine.AccessToken |
||||
|
} |
@ -0,0 +1,155 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package google |
||||
|
|
||||
|
import ( |
||||
|
"encoding/json" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"io/ioutil" |
||||
|
"net/http" |
||||
|
"os" |
||||
|
"path/filepath" |
||||
|
"runtime" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/jwt" |
||||
|
"google.golang.org/cloud/compute/metadata" |
||||
|
) |
||||
|
|
||||
|
// DefaultClient returns an HTTP Client that uses the
|
||||
|
// DefaultTokenSource to obtain authentication credentials.
|
||||
|
//
|
||||
|
// This client should be used when developing services
|
||||
|
// that run on Google App Engine or Google Compute Engine
|
||||
|
// and use "Application Default Credentials."
|
||||
|
//
|
||||
|
// For more details, see:
|
||||
|
// https://developers.google.com/accounts/docs/application-default-credentials
|
||||
|
//
|
||||
|
func DefaultClient(ctx context.Context, scope ...string) (*http.Client, error) { |
||||
|
ts, err := DefaultTokenSource(ctx, scope...) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return oauth2.NewClient(ctx, ts), nil |
||||
|
} |
||||
|
|
||||
|
// DefaultTokenSource is a token source that uses
|
||||
|
// "Application Default Credentials".
|
||||
|
//
|
||||
|
// It looks for credentials in the following places,
|
||||
|
// preferring the first location found:
|
||||
|
//
|
||||
|
// 1. A JSON file whose path is specified by the
|
||||
|
// GOOGLE_APPLICATION_CREDENTIALS environment variable.
|
||||
|
// 2. A JSON file in a location known to the gcloud command-line tool.
|
||||
|
// On Windows, this is %APPDATA%/gcloud/application_default_credentials.json.
|
||||
|
// On other systems, $HOME/.config/gcloud/application_default_credentials.json.
|
||||
|
// 3. On Google App Engine it uses the appengine.AccessToken function.
|
||||
|
// 4. On Google Compute Engine and Google App Engine Managed VMs, it fetches
|
||||
|
// credentials from the metadata server.
|
||||
|
// (In this final case any provided scopes are ignored.)
|
||||
|
//
|
||||
|
// For more details, see:
|
||||
|
// https://developers.google.com/accounts/docs/application-default-credentials
|
||||
|
//
|
||||
|
func DefaultTokenSource(ctx context.Context, scope ...string) (oauth2.TokenSource, error) { |
||||
|
// First, try the environment variable.
|
||||
|
const envVar = "GOOGLE_APPLICATION_CREDENTIALS" |
||||
|
if filename := os.Getenv(envVar); filename != "" { |
||||
|
ts, err := tokenSourceFromFile(ctx, filename, scope) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("google: error getting credentials using %v environment variable: %v", envVar, err) |
||||
|
} |
||||
|
return ts, nil |
||||
|
} |
||||
|
|
||||
|
// Second, try a well-known file.
|
||||
|
filename := wellKnownFile() |
||||
|
_, err := os.Stat(filename) |
||||
|
if err == nil { |
||||
|
ts, err2 := tokenSourceFromFile(ctx, filename, scope) |
||||
|
if err2 == nil { |
||||
|
return ts, nil |
||||
|
} |
||||
|
err = err2 |
||||
|
} else if os.IsNotExist(err) { |
||||
|
err = nil // ignore this error
|
||||
|
} |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("google: error getting credentials using well-known file (%v): %v", filename, err) |
||||
|
} |
||||
|
|
||||
|
// Third, if we're on Google App Engine use those credentials.
|
||||
|
if appengineTokenFunc != nil && !appengineVM { |
||||
|
return AppEngineTokenSource(ctx, scope...), nil |
||||
|
} |
||||
|
|
||||
|
// Fourth, if we're on Google Compute Engine use the metadata server.
|
||||
|
if metadata.OnGCE() { |
||||
|
return ComputeTokenSource(""), nil |
||||
|
} |
||||
|
|
||||
|
// None are found; return helpful error.
|
||||
|
const url = "https://developers.google.com/accounts/docs/application-default-credentials" |
||||
|
return nil, fmt.Errorf("google: could not find default credentials. See %v for more information.", url) |
||||
|
} |
||||
|
|
||||
|
func wellKnownFile() string { |
||||
|
const f = "application_default_credentials.json" |
||||
|
if runtime.GOOS == "windows" { |
||||
|
return filepath.Join(os.Getenv("APPDATA"), "gcloud", f) |
||||
|
} |
||||
|
return filepath.Join(guessUnixHomeDir(), ".config", "gcloud", f) |
||||
|
} |
||||
|
|
||||
|
func tokenSourceFromFile(ctx context.Context, filename string, scopes []string) (oauth2.TokenSource, error) { |
||||
|
b, err := ioutil.ReadFile(filename) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
var d struct { |
||||
|
// Common fields
|
||||
|
Type string |
||||
|
ClientID string `json:"client_id"` |
||||
|
|
||||
|
// User Credential fields
|
||||
|
ClientSecret string `json:"client_secret"` |
||||
|
RefreshToken string `json:"refresh_token"` |
||||
|
|
||||
|
// Service Account fields
|
||||
|
ClientEmail string `json:"client_email"` |
||||
|
PrivateKeyID string `json:"private_key_id"` |
||||
|
PrivateKey string `json:"private_key"` |
||||
|
} |
||||
|
if err := json.Unmarshal(b, &d); err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
switch d.Type { |
||||
|
case "authorized_user": |
||||
|
cfg := &oauth2.Config{ |
||||
|
ClientID: d.ClientID, |
||||
|
ClientSecret: d.ClientSecret, |
||||
|
Scopes: append([]string{}, scopes...), // copy
|
||||
|
Endpoint: Endpoint, |
||||
|
} |
||||
|
tok := &oauth2.Token{RefreshToken: d.RefreshToken} |
||||
|
return cfg.TokenSource(ctx, tok), nil |
||||
|
case "service_account": |
||||
|
cfg := &jwt.Config{ |
||||
|
Email: d.ClientEmail, |
||||
|
PrivateKey: []byte(d.PrivateKey), |
||||
|
Scopes: append([]string{}, scopes...), // copy
|
||||
|
TokenURL: JWTTokenURL, |
||||
|
} |
||||
|
return cfg.TokenSource(ctx), nil |
||||
|
case "": |
||||
|
return nil, errors.New("missing 'type' field in credentials") |
||||
|
default: |
||||
|
return nil, fmt.Errorf("unknown credential type: %q", d.Type) |
||||
|
} |
||||
|
} |
@ -0,0 +1,150 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// +build appenginevm appengine
|
||||
|
|
||||
|
package google_test |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"io/ioutil" |
||||
|
"log" |
||||
|
"net/http" |
||||
|
|
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/google" |
||||
|
"golang.org/x/oauth2/jwt" |
||||
|
"google.golang.org/appengine" |
||||
|
"google.golang.org/appengine/urlfetch" |
||||
|
) |
||||
|
|
||||
|
func ExampleDefaultClient() { |
||||
|
client, err := google.DefaultClient(oauth2.NoContext, |
||||
|
"https://www.googleapis.com/auth/devstorage.full_control") |
||||
|
if err != nil { |
||||
|
log.Fatal(err) |
||||
|
} |
||||
|
client.Get("...") |
||||
|
} |
||||
|
|
||||
|
func Example_webServer() { |
||||
|
// Your credentials should be obtained from the Google
|
||||
|
// Developer Console (https://console.developers.google.com).
|
||||
|
conf := &oauth2.Config{ |
||||
|
ClientID: "YOUR_CLIENT_ID", |
||||
|
ClientSecret: "YOUR_CLIENT_SECRET", |
||||
|
RedirectURL: "YOUR_REDIRECT_URL", |
||||
|
Scopes: []string{ |
||||
|
"https://www.googleapis.com/auth/bigquery", |
||||
|
"https://www.googleapis.com/auth/blogger", |
||||
|
}, |
||||
|
Endpoint: google.Endpoint, |
||||
|
} |
||||
|
// Redirect user to Google's consent page to ask for permission
|
||||
|
// for the scopes specified above.
|
||||
|
url := conf.AuthCodeURL("state") |
||||
|
fmt.Printf("Visit the URL for the auth dialog: %v", url) |
||||
|
|
||||
|
// Handle the exchange code to initiate a transport.
|
||||
|
tok, err := conf.Exchange(oauth2.NoContext, "authorization-code") |
||||
|
if err != nil { |
||||
|
log.Fatal(err) |
||||
|
} |
||||
|
client := conf.Client(oauth2.NoContext, tok) |
||||
|
client.Get("...") |
||||
|
} |
||||
|
|
||||
|
func ExampleJWTConfigFromJSON() { |
||||
|
// Your credentials should be obtained from the Google
|
||||
|
// Developer Console (https://console.developers.google.com).
|
||||
|
// Navigate to your project, then see the "Credentials" page
|
||||
|
// under "APIs & Auth".
|
||||
|
// To create a service account client, click "Create new Client ID",
|
||||
|
// select "Service Account", and click "Create Client ID". A JSON
|
||||
|
// key file will then be downloaded to your computer.
|
||||
|
data, err := ioutil.ReadFile("/path/to/your-project-key.json") |
||||
|
if err != nil { |
||||
|
log.Fatal(err) |
||||
|
} |
||||
|
conf, err := google.JWTConfigFromJSON(data, "https://www.googleapis.com/auth/bigquery") |
||||
|
if err != nil { |
||||
|
log.Fatal(err) |
||||
|
} |
||||
|
// Initiate an http.Client. The following GET request will be
|
||||
|
// authorized and authenticated on the behalf of
|
||||
|
// your service account.
|
||||
|
client := conf.Client(oauth2.NoContext) |
||||
|
client.Get("...") |
||||
|
} |
||||
|
|
||||
|
func ExampleSDKConfig() { |
||||
|
// The credentials will be obtained from the first account that
|
||||
|
// has been authorized with `gcloud auth login`.
|
||||
|
conf, err := google.NewSDKConfig("") |
||||
|
if err != nil { |
||||
|
log.Fatal(err) |
||||
|
} |
||||
|
// Initiate an http.Client. The following GET request will be
|
||||
|
// authorized and authenticated on the behalf of the SDK user.
|
||||
|
client := conf.Client(oauth2.NoContext) |
||||
|
client.Get("...") |
||||
|
} |
||||
|
|
||||
|
func Example_serviceAccount() { |
||||
|
// Your credentials should be obtained from the Google
|
||||
|
// Developer Console (https://console.developers.google.com).
|
||||
|
conf := &jwt.Config{ |
||||
|
Email: "xxx@developer.gserviceaccount.com", |
||||
|
// The contents of your RSA private key or your PEM file
|
||||
|
// that contains a private key.
|
||||
|
// If you have a p12 file instead, you
|
||||
|
// can use `openssl` to export the private key into a pem file.
|
||||
|
//
|
||||
|
// $ openssl pkcs12 -in key.p12 -passin pass:notasecret -out key.pem -nodes
|
||||
|
//
|
||||
|
// The field only supports PEM containers with no passphrase.
|
||||
|
// The openssl command will convert p12 keys to passphrase-less PEM containers.
|
||||
|
PrivateKey: []byte("-----BEGIN RSA PRIVATE KEY-----..."), |
||||
|
Scopes: []string{ |
||||
|
"https://www.googleapis.com/auth/bigquery", |
||||
|
"https://www.googleapis.com/auth/blogger", |
||||
|
}, |
||||
|
TokenURL: google.JWTTokenURL, |
||||
|
// If you would like to impersonate a user, you can
|
||||
|
// create a transport with a subject. The following GET
|
||||
|
// request will be made on the behalf of user@example.com.
|
||||
|
// Optional.
|
||||
|
Subject: "user@example.com", |
||||
|
} |
||||
|
// Initiate an http.Client, the following GET request will be
|
||||
|
// authorized and authenticated on the behalf of user@example.com.
|
||||
|
client := conf.Client(oauth2.NoContext) |
||||
|
client.Get("...") |
||||
|
} |
||||
|
|
||||
|
func ExampleAppEngineTokenSource() { |
||||
|
var req *http.Request // from the ServeHTTP handler
|
||||
|
ctx := appengine.NewContext(req) |
||||
|
client := &http.Client{ |
||||
|
Transport: &oauth2.Transport{ |
||||
|
Source: google.AppEngineTokenSource(ctx, "https://www.googleapis.com/auth/bigquery"), |
||||
|
Base: &urlfetch.Transport{ |
||||
|
Context: ctx, |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
client.Get("...") |
||||
|
} |
||||
|
|
||||
|
func ExampleComputeTokenSource() { |
||||
|
client := &http.Client{ |
||||
|
Transport: &oauth2.Transport{ |
||||
|
// Fetch from Google Compute Engine's metadata server to retrieve
|
||||
|
// an access token for the provided account.
|
||||
|
// If no account is specified, "default" is used.
|
||||
|
Source: google.ComputeTokenSource(""), |
||||
|
}, |
||||
|
} |
||||
|
client.Get("...") |
||||
|
} |
@ -0,0 +1,149 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package google provides support for making OAuth2 authorized and
|
||||
|
// authenticated HTTP requests to Google APIs.
|
||||
|
// It supports the Web server flow, client-side credentials, service accounts,
|
||||
|
// Google Compute Engine service accounts, and Google App Engine service
|
||||
|
// accounts.
|
||||
|
//
|
||||
|
// For more information, please read
|
||||
|
// https://developers.google.com/accounts/docs/OAuth2
|
||||
|
// and
|
||||
|
// https://developers.google.com/accounts/docs/application-default-credentials.
|
||||
|
package google // import "golang.org/x/oauth2/google"
|
||||
|
|
||||
|
import ( |
||||
|
"encoding/json" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"strings" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/jwt" |
||||
|
"google.golang.org/cloud/compute/metadata" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is Google's OAuth 2.0 endpoint.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://accounts.google.com/o/oauth2/auth", |
||||
|
TokenURL: "https://accounts.google.com/o/oauth2/token", |
||||
|
} |
||||
|
|
||||
|
// JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow.
|
||||
|
const JWTTokenURL = "https://accounts.google.com/o/oauth2/token" |
||||
|
|
||||
|
// ConfigFromJSON uses a Google Developers Console client_credentials.json
|
||||
|
// file to construct a config.
|
||||
|
// client_credentials.json can be downloaded from
|
||||
|
// https://console.developers.google.com, under "Credentials". Download the Web
|
||||
|
// application credentials in the JSON format and provide the contents of the
|
||||
|
// file as jsonKey.
|
||||
|
func ConfigFromJSON(jsonKey []byte, scope ...string) (*oauth2.Config, error) { |
||||
|
type cred struct { |
||||
|
ClientID string `json:"client_id"` |
||||
|
ClientSecret string `json:"client_secret"` |
||||
|
RedirectURIs []string `json:"redirect_uris"` |
||||
|
AuthURI string `json:"auth_uri"` |
||||
|
TokenURI string `json:"token_uri"` |
||||
|
} |
||||
|
var j struct { |
||||
|
Web *cred `json:"web"` |
||||
|
Installed *cred `json:"installed"` |
||||
|
} |
||||
|
if err := json.Unmarshal(jsonKey, &j); err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
var c *cred |
||||
|
switch { |
||||
|
case j.Web != nil: |
||||
|
c = j.Web |
||||
|
case j.Installed != nil: |
||||
|
c = j.Installed |
||||
|
default: |
||||
|
return nil, fmt.Errorf("oauth2/google: no credentials found") |
||||
|
} |
||||
|
if len(c.RedirectURIs) < 1 { |
||||
|
return nil, errors.New("oauth2/google: missing redirect URL in the client_credentials.json") |
||||
|
} |
||||
|
return &oauth2.Config{ |
||||
|
ClientID: c.ClientID, |
||||
|
ClientSecret: c.ClientSecret, |
||||
|
RedirectURL: c.RedirectURIs[0], |
||||
|
Scopes: scope, |
||||
|
Endpoint: oauth2.Endpoint{ |
||||
|
AuthURL: c.AuthURI, |
||||
|
TokenURL: c.TokenURI, |
||||
|
}, |
||||
|
}, nil |
||||
|
} |
||||
|
|
||||
|
// JWTConfigFromJSON uses a Google Developers service account JSON key file to read
|
||||
|
// the credentials that authorize and authenticate the requests.
|
||||
|
// Create a service account on "Credentials" for your project at
|
||||
|
// https://console.developers.google.com to download a JSON key file.
|
||||
|
func JWTConfigFromJSON(jsonKey []byte, scope ...string) (*jwt.Config, error) { |
||||
|
var key struct { |
||||
|
Email string `json:"client_email"` |
||||
|
PrivateKey string `json:"private_key"` |
||||
|
PrivateKeyID string `json:"private_key_id"` |
||||
|
} |
||||
|
if err := json.Unmarshal(jsonKey, &key); err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
config := &jwt.Config{ |
||||
|
Email: key.Email, |
||||
|
PrivateKey: []byte(key.PrivateKey), |
||||
|
PrivateKeyID: key.PrivateKeyID, |
||||
|
Scopes: scope, |
||||
|
TokenURL: JWTTokenURL, |
||||
|
} |
||||
|
return config, nil |
||||
|
} |
||||
|
|
||||
|
// ComputeTokenSource returns a token source that fetches access tokens
|
||||
|
// from Google Compute Engine (GCE)'s metadata server. It's only valid to use
|
||||
|
// this token source if your program is running on a GCE instance.
|
||||
|
// If no account is specified, "default" is used.
|
||||
|
// Further information about retrieving access tokens from the GCE metadata
|
||||
|
// server can be found at https://cloud.google.com/compute/docs/authentication.
|
||||
|
func ComputeTokenSource(account string) oauth2.TokenSource { |
||||
|
return oauth2.ReuseTokenSource(nil, computeSource{account: account}) |
||||
|
} |
||||
|
|
||||
|
type computeSource struct { |
||||
|
account string |
||||
|
} |
||||
|
|
||||
|
func (cs computeSource) Token() (*oauth2.Token, error) { |
||||
|
if !metadata.OnGCE() { |
||||
|
return nil, errors.New("oauth2/google: can't get a token from the metadata service; not running on GCE") |
||||
|
} |
||||
|
acct := cs.account |
||||
|
if acct == "" { |
||||
|
acct = "default" |
||||
|
} |
||||
|
tokenJSON, err := metadata.Get("instance/service-accounts/" + acct + "/token") |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
var res struct { |
||||
|
AccessToken string `json:"access_token"` |
||||
|
ExpiresInSec int `json:"expires_in"` |
||||
|
TokenType string `json:"token_type"` |
||||
|
} |
||||
|
err = json.NewDecoder(strings.NewReader(tokenJSON)).Decode(&res) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("oauth2/google: invalid token JSON from metadata: %v", err) |
||||
|
} |
||||
|
if res.ExpiresInSec == 0 || res.AccessToken == "" { |
||||
|
return nil, fmt.Errorf("oauth2/google: incomplete token received from metadata") |
||||
|
} |
||||
|
return &oauth2.Token{ |
||||
|
AccessToken: res.AccessToken, |
||||
|
TokenType: res.TokenType, |
||||
|
Expiry: time.Now().Add(time.Duration(res.ExpiresInSec) * time.Second), |
||||
|
}, nil |
||||
|
} |
@ -0,0 +1,97 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package google |
||||
|
|
||||
|
import ( |
||||
|
"strings" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
var webJSONKey = []byte(` |
||||
|
{ |
||||
|
"web": { |
||||
|
"auth_uri": "https://google.com/o/oauth2/auth", |
||||
|
"client_secret": "3Oknc4jS_wA2r9i", |
||||
|
"token_uri": "https://google.com/o/oauth2/token", |
||||
|
"client_email": "222-nprqovg5k43uum874cs9osjt2koe97g8@developer.gserviceaccount.com", |
||||
|
"redirect_uris": ["https://www.example.com/oauth2callback"], |
||||
|
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/222-nprqovg5k43uum874cs9osjt2koe97g8@developer.gserviceaccount.com", |
||||
|
"client_id": "222-nprqovg5k43uum874cs9osjt2koe97g8.apps.googleusercontent.com", |
||||
|
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", |
||||
|
"javascript_origins": ["https://www.example.com"] |
||||
|
} |
||||
|
}`) |
||||
|
|
||||
|
var installedJSONKey = []byte(`{ |
||||
|
"installed": { |
||||
|
"client_id": "222-installed.apps.googleusercontent.com", |
||||
|
"redirect_uris": ["https://www.example.com/oauth2callback"] |
||||
|
} |
||||
|
}`) |
||||
|
|
||||
|
var jwtJSONKey = []byte(`{ |
||||
|
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b", |
||||
|
"private_key": "super secret key", |
||||
|
"client_email": "gopher@developer.gserviceaccount.com", |
||||
|
"client_id": "gopher.apps.googleusercontent.com", |
||||
|
"type": "service_account" |
||||
|
}`) |
||||
|
|
||||
|
func TestConfigFromJSON(t *testing.T) { |
||||
|
conf, err := ConfigFromJSON(webJSONKey, "scope1", "scope2") |
||||
|
if err != nil { |
||||
|
t.Error(err) |
||||
|
} |
||||
|
if got, want := conf.ClientID, "222-nprqovg5k43uum874cs9osjt2koe97g8.apps.googleusercontent.com"; got != want { |
||||
|
t.Errorf("ClientID = %q; want %q", got, want) |
||||
|
} |
||||
|
if got, want := conf.ClientSecret, "3Oknc4jS_wA2r9i"; got != want { |
||||
|
t.Errorf("ClientSecret = %q; want %q", got, want) |
||||
|
} |
||||
|
if got, want := conf.RedirectURL, "https://www.example.com/oauth2callback"; got != want { |
||||
|
t.Errorf("RedictURL = %q; want %q", got, want) |
||||
|
} |
||||
|
if got, want := strings.Join(conf.Scopes, ","), "scope1,scope2"; got != want { |
||||
|
t.Errorf("Scopes = %q; want %q", got, want) |
||||
|
} |
||||
|
if got, want := conf.Endpoint.AuthURL, "https://google.com/o/oauth2/auth"; got != want { |
||||
|
t.Errorf("AuthURL = %q; want %q", got, want) |
||||
|
} |
||||
|
if got, want := conf.Endpoint.TokenURL, "https://google.com/o/oauth2/token"; got != want { |
||||
|
t.Errorf("TokenURL = %q; want %q", got, want) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestConfigFromJSON_Installed(t *testing.T) { |
||||
|
conf, err := ConfigFromJSON(installedJSONKey) |
||||
|
if err != nil { |
||||
|
t.Error(err) |
||||
|
} |
||||
|
if got, want := conf.ClientID, "222-installed.apps.googleusercontent.com"; got != want { |
||||
|
t.Errorf("ClientID = %q; want %q", got, want) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestJWTConfigFromJSON(t *testing.T) { |
||||
|
conf, err := JWTConfigFromJSON(jwtJSONKey, "scope1", "scope2") |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if got, want := conf.Email, "gopher@developer.gserviceaccount.com"; got != want { |
||||
|
t.Errorf("Email = %q, want %q", got, want) |
||||
|
} |
||||
|
if got, want := string(conf.PrivateKey), "super secret key"; got != want { |
||||
|
t.Errorf("PrivateKey = %q, want %q", got, want) |
||||
|
} |
||||
|
if got, want := conf.PrivateKeyID, "268f54e43a1af97cfc71731688434f45aca15c8b"; got != want { |
||||
|
t.Errorf("PrivateKeyID = %q, want %q", got, want) |
||||
|
} |
||||
|
if got, want := strings.Join(conf.Scopes, ","), "scope1,scope2"; got != want { |
||||
|
t.Errorf("Scopes = %q; want %q", got, want) |
||||
|
} |
||||
|
if got, want := conf.TokenURL, "https://accounts.google.com/o/oauth2/token"; got != want { |
||||
|
t.Errorf("TokenURL = %q; want %q", got, want) |
||||
|
} |
||||
|
} |
@ -0,0 +1,74 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package google |
||||
|
|
||||
|
import ( |
||||
|
"crypto/rsa" |
||||
|
"fmt" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/internal" |
||||
|
"golang.org/x/oauth2/jws" |
||||
|
) |
||||
|
|
||||
|
// JWTAccessTokenSourceFromJSON uses a Google Developers service account JSON
|
||||
|
// key file to read the credentials that authorize and authenticate the
|
||||
|
// requests, and returns a TokenSource that does not use any OAuth2 flow but
|
||||
|
// instead creates a JWT and sends that as the access token.
|
||||
|
// The audience is typically a URL that specifies the scope of the credentials.
|
||||
|
//
|
||||
|
// Note that this is not a standard OAuth flow, but rather an
|
||||
|
// optimization supported by a few Google services.
|
||||
|
// Unless you know otherwise, you should use JWTConfigFromJSON instead.
|
||||
|
func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.TokenSource, error) { |
||||
|
cfg, err := JWTConfigFromJSON(jsonKey) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("google: could not parse JSON key: %v", err) |
||||
|
} |
||||
|
pk, err := internal.ParseKey(cfg.PrivateKey) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("google: could not parse key: %v", err) |
||||
|
} |
||||
|
ts := &jwtAccessTokenSource{ |
||||
|
email: cfg.Email, |
||||
|
audience: audience, |
||||
|
pk: pk, |
||||
|
pkID: cfg.PrivateKeyID, |
||||
|
} |
||||
|
tok, err := ts.Token() |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return oauth2.ReuseTokenSource(tok, ts), nil |
||||
|
} |
||||
|
|
||||
|
type jwtAccessTokenSource struct { |
||||
|
email, audience string |
||||
|
pk *rsa.PrivateKey |
||||
|
pkID string |
||||
|
} |
||||
|
|
||||
|
func (ts *jwtAccessTokenSource) Token() (*oauth2.Token, error) { |
||||
|
iat := time.Now() |
||||
|
exp := iat.Add(time.Hour) |
||||
|
cs := &jws.ClaimSet{ |
||||
|
Iss: ts.email, |
||||
|
Sub: ts.email, |
||||
|
Aud: ts.audience, |
||||
|
Iat: iat.Unix(), |
||||
|
Exp: exp.Unix(), |
||||
|
} |
||||
|
hdr := &jws.Header{ |
||||
|
Algorithm: "RS256", |
||||
|
Typ: "JWT", |
||||
|
KeyID: string(ts.pkID), |
||||
|
} |
||||
|
msg, err := jws.Encode(hdr, cs, ts.pk) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("google: could not encode JWT: %v", err) |
||||
|
} |
||||
|
return &oauth2.Token{AccessToken: msg, TokenType: "Bearer", Expiry: exp}, nil |
||||
|
} |
@ -0,0 +1,91 @@ |
|||||
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package google |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"crypto/rand" |
||||
|
"crypto/rsa" |
||||
|
"crypto/x509" |
||||
|
"encoding/base64" |
||||
|
"encoding/json" |
||||
|
"encoding/pem" |
||||
|
"strings" |
||||
|
"testing" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/oauth2/jws" |
||||
|
) |
||||
|
|
||||
|
func TestJWTAccessTokenSourceFromJSON(t *testing.T) { |
||||
|
// Generate a key we can use in the test data.
|
||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
// Encode the key and substitute into our example JSON.
|
||||
|
enc := pem.EncodeToMemory(&pem.Block{ |
||||
|
Type: "PRIVATE KEY", |
||||
|
Bytes: x509.MarshalPKCS1PrivateKey(privateKey), |
||||
|
}) |
||||
|
enc, err = json.Marshal(string(enc)) |
||||
|
if err != nil { |
||||
|
t.Fatalf("json.Marshal: %v", err) |
||||
|
} |
||||
|
jsonKey := bytes.Replace(jwtJSONKey, []byte(`"super secret key"`), enc, 1) |
||||
|
|
||||
|
ts, err := JWTAccessTokenSourceFromJSON(jsonKey, "audience") |
||||
|
if err != nil { |
||||
|
t.Fatalf("JWTAccessTokenSourceFromJSON: %v\nJSON: %s", err, string(jsonKey)) |
||||
|
} |
||||
|
|
||||
|
tok, err := ts.Token() |
||||
|
if err != nil { |
||||
|
t.Fatalf("Token: %v", err) |
||||
|
} |
||||
|
|
||||
|
if got, want := tok.TokenType, "Bearer"; got != want { |
||||
|
t.Errorf("TokenType = %q, want %q", got, want) |
||||
|
} |
||||
|
if got := tok.Expiry; tok.Expiry.Before(time.Now()) { |
||||
|
t.Errorf("Expiry = %v, should not be expired", got) |
||||
|
} |
||||
|
|
||||
|
err = jws.Verify(tok.AccessToken, &privateKey.PublicKey) |
||||
|
if err != nil { |
||||
|
t.Errorf("jws.Verify on AccessToken: %v", err) |
||||
|
} |
||||
|
|
||||
|
claim, err := jws.Decode(tok.AccessToken) |
||||
|
if err != nil { |
||||
|
t.Fatalf("jws.Decode on AccessToken: %v", err) |
||||
|
} |
||||
|
|
||||
|
if got, want := claim.Iss, "gopher@developer.gserviceaccount.com"; got != want { |
||||
|
t.Errorf("Iss = %q, want %q", got, want) |
||||
|
} |
||||
|
if got, want := claim.Sub, "gopher@developer.gserviceaccount.com"; got != want { |
||||
|
t.Errorf("Sub = %q, want %q", got, want) |
||||
|
} |
||||
|
if got, want := claim.Aud, "audience"; got != want { |
||||
|
t.Errorf("Aud = %q, want %q", got, want) |
||||
|
} |
||||
|
|
||||
|
// Finally, check the header private key.
|
||||
|
parts := strings.Split(tok.AccessToken, ".") |
||||
|
hdrJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) |
||||
|
if err != nil { |
||||
|
t.Fatalf("base64 DecodeString: %v\nString: %q", err, parts[0]) |
||||
|
} |
||||
|
var hdr jws.Header |
||||
|
if err := json.Unmarshal([]byte(hdrJSON), &hdr); err != nil { |
||||
|
t.Fatalf("json.Unmarshal: %v (%q)", err, hdrJSON) |
||||
|
} |
||||
|
|
||||
|
if got, want := hdr.KeyID, "268f54e43a1af97cfc71731688434f45aca15c8b"; got != want { |
||||
|
t.Errorf("Header KeyID = %q, want %q", got, want) |
||||
|
} |
||||
|
} |
@ -0,0 +1,168 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package google |
||||
|
|
||||
|
import ( |
||||
|
"encoding/json" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"net/http" |
||||
|
"os" |
||||
|
"os/user" |
||||
|
"path/filepath" |
||||
|
"runtime" |
||||
|
"strings" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/internal" |
||||
|
) |
||||
|
|
||||
|
type sdkCredentials struct { |
||||
|
Data []struct { |
||||
|
Credential struct { |
||||
|
ClientID string `json:"client_id"` |
||||
|
ClientSecret string `json:"client_secret"` |
||||
|
AccessToken string `json:"access_token"` |
||||
|
RefreshToken string `json:"refresh_token"` |
||||
|
TokenExpiry *time.Time `json:"token_expiry"` |
||||
|
} `json:"credential"` |
||||
|
Key struct { |
||||
|
Account string `json:"account"` |
||||
|
Scope string `json:"scope"` |
||||
|
} `json:"key"` |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// An SDKConfig provides access to tokens from an account already
|
||||
|
// authorized via the Google Cloud SDK.
|
||||
|
type SDKConfig struct { |
||||
|
conf oauth2.Config |
||||
|
initialToken *oauth2.Token |
||||
|
} |
||||
|
|
||||
|
// NewSDKConfig creates an SDKConfig for the given Google Cloud SDK
|
||||
|
// account. If account is empty, the account currently active in
|
||||
|
// Google Cloud SDK properties is used.
|
||||
|
// Google Cloud SDK credentials must be created by running `gcloud auth`
|
||||
|
// before using this function.
|
||||
|
// The Google Cloud SDK is available at https://cloud.google.com/sdk/.
|
||||
|
func NewSDKConfig(account string) (*SDKConfig, error) { |
||||
|
configPath, err := sdkConfigPath() |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("oauth2/google: error getting SDK config path: %v", err) |
||||
|
} |
||||
|
credentialsPath := filepath.Join(configPath, "credentials") |
||||
|
f, err := os.Open(credentialsPath) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("oauth2/google: failed to load SDK credentials: %v", err) |
||||
|
} |
||||
|
defer f.Close() |
||||
|
|
||||
|
var c sdkCredentials |
||||
|
if err := json.NewDecoder(f).Decode(&c); err != nil { |
||||
|
return nil, fmt.Errorf("oauth2/google: failed to decode SDK credentials from %q: %v", credentialsPath, err) |
||||
|
} |
||||
|
if len(c.Data) == 0 { |
||||
|
return nil, fmt.Errorf("oauth2/google: no credentials found in %q, run `gcloud auth login` to create one", credentialsPath) |
||||
|
} |
||||
|
if account == "" { |
||||
|
propertiesPath := filepath.Join(configPath, "properties") |
||||
|
f, err := os.Open(propertiesPath) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("oauth2/google: failed to load SDK properties: %v", err) |
||||
|
} |
||||
|
defer f.Close() |
||||
|
ini, err := internal.ParseINI(f) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("oauth2/google: failed to parse SDK properties %q: %v", propertiesPath, err) |
||||
|
} |
||||
|
core, ok := ini["core"] |
||||
|
if !ok { |
||||
|
return nil, fmt.Errorf("oauth2/google: failed to find [core] section in %v", ini) |
||||
|
} |
||||
|
active, ok := core["account"] |
||||
|
if !ok { |
||||
|
return nil, fmt.Errorf("oauth2/google: failed to find %q attribute in %v", "account", core) |
||||
|
} |
||||
|
account = active |
||||
|
} |
||||
|
|
||||
|
for _, d := range c.Data { |
||||
|
if account == "" || d.Key.Account == account { |
||||
|
if d.Credential.AccessToken == "" && d.Credential.RefreshToken == "" { |
||||
|
return nil, fmt.Errorf("oauth2/google: no token available for account %q", account) |
||||
|
} |
||||
|
var expiry time.Time |
||||
|
if d.Credential.TokenExpiry != nil { |
||||
|
expiry = *d.Credential.TokenExpiry |
||||
|
} |
||||
|
return &SDKConfig{ |
||||
|
conf: oauth2.Config{ |
||||
|
ClientID: d.Credential.ClientID, |
||||
|
ClientSecret: d.Credential.ClientSecret, |
||||
|
Scopes: strings.Split(d.Key.Scope, " "), |
||||
|
Endpoint: Endpoint, |
||||
|
RedirectURL: "oob", |
||||
|
}, |
||||
|
initialToken: &oauth2.Token{ |
||||
|
AccessToken: d.Credential.AccessToken, |
||||
|
RefreshToken: d.Credential.RefreshToken, |
||||
|
Expiry: expiry, |
||||
|
}, |
||||
|
}, nil |
||||
|
} |
||||
|
} |
||||
|
return nil, fmt.Errorf("oauth2/google: no such credentials for account %q", account) |
||||
|
} |
||||
|
|
||||
|
// Client returns an HTTP client using Google Cloud SDK credentials to
|
||||
|
// authorize requests. The token will auto-refresh as necessary. The
|
||||
|
// underlying http.RoundTripper will be obtained using the provided
|
||||
|
// context. The returned client and its Transport should not be
|
||||
|
// modified.
|
||||
|
func (c *SDKConfig) Client(ctx context.Context) *http.Client { |
||||
|
return &http.Client{ |
||||
|
Transport: &oauth2.Transport{ |
||||
|
Source: c.TokenSource(ctx), |
||||
|
}, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// TokenSource returns an oauth2.TokenSource that retrieve tokens from
|
||||
|
// Google Cloud SDK credentials using the provided context.
|
||||
|
// It will returns the current access token stored in the credentials,
|
||||
|
// and refresh it when it expires, but it won't update the credentials
|
||||
|
// with the new access token.
|
||||
|
func (c *SDKConfig) TokenSource(ctx context.Context) oauth2.TokenSource { |
||||
|
return c.conf.TokenSource(ctx, c.initialToken) |
||||
|
} |
||||
|
|
||||
|
// Scopes are the OAuth 2.0 scopes the current account is authorized for.
|
||||
|
func (c *SDKConfig) Scopes() []string { |
||||
|
return c.conf.Scopes |
||||
|
} |
||||
|
|
||||
|
// sdkConfigPath tries to guess where the gcloud config is located.
|
||||
|
// It can be overridden during tests.
|
||||
|
var sdkConfigPath = func() (string, error) { |
||||
|
if runtime.GOOS == "windows" { |
||||
|
return filepath.Join(os.Getenv("APPDATA"), "gcloud"), nil |
||||
|
} |
||||
|
homeDir := guessUnixHomeDir() |
||||
|
if homeDir == "" { |
||||
|
return "", errors.New("unable to get current user home directory: os/user lookup failed; $HOME is empty") |
||||
|
} |
||||
|
return filepath.Join(homeDir, ".config", "gcloud"), nil |
||||
|
} |
||||
|
|
||||
|
func guessUnixHomeDir() string { |
||||
|
usr, err := user.Current() |
||||
|
if err == nil { |
||||
|
return usr.HomeDir |
||||
|
} |
||||
|
return os.Getenv("HOME") |
||||
|
} |
@ -0,0 +1,46 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package google |
||||
|
|
||||
|
import "testing" |
||||
|
|
||||
|
func TestSDKConfig(t *testing.T) { |
||||
|
sdkConfigPath = func() (string, error) { |
||||
|
return "testdata/gcloud", nil |
||||
|
} |
||||
|
|
||||
|
tests := []struct { |
||||
|
account string |
||||
|
accessToken string |
||||
|
err bool |
||||
|
}{ |
||||
|
{"", "bar_access_token", false}, |
||||
|
{"foo@example.com", "foo_access_token", false}, |
||||
|
{"bar@example.com", "bar_access_token", false}, |
||||
|
{"baz@serviceaccount.example.com", "", true}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
c, err := NewSDKConfig(tt.account) |
||||
|
if got, want := err != nil, tt.err; got != want { |
||||
|
if !tt.err { |
||||
|
t.Errorf("got %v, want nil", err) |
||||
|
} else { |
||||
|
t.Errorf("got nil, want error") |
||||
|
} |
||||
|
continue |
||||
|
} |
||||
|
if err != nil { |
||||
|
continue |
||||
|
} |
||||
|
tok := c.initialToken |
||||
|
if tok == nil { |
||||
|
t.Errorf("got nil, want %q", tt.accessToken) |
||||
|
continue |
||||
|
} |
||||
|
if tok.AccessToken != tt.accessToken { |
||||
|
t.Errorf("got %q, want %q", tok.AccessToken, tt.accessToken) |
||||
|
} |
||||
|
} |
||||
|
} |
@ -0,0 +1,122 @@ |
|||||
|
{ |
||||
|
"data": [ |
||||
|
{ |
||||
|
"credential": { |
||||
|
"_class": "OAuth2Credentials", |
||||
|
"_module": "oauth2client.client", |
||||
|
"access_token": "foo_access_token", |
||||
|
"client_id": "foo_client_id", |
||||
|
"client_secret": "foo_client_secret", |
||||
|
"id_token": { |
||||
|
"at_hash": "foo_at_hash", |
||||
|
"aud": "foo_aud", |
||||
|
"azp": "foo_azp", |
||||
|
"cid": "foo_cid", |
||||
|
"email": "foo@example.com", |
||||
|
"email_verified": true, |
||||
|
"exp": 1420573614, |
||||
|
"iat": 1420569714, |
||||
|
"id": "1337", |
||||
|
"iss": "accounts.google.com", |
||||
|
"sub": "1337", |
||||
|
"token_hash": "foo_token_hash", |
||||
|
"verified_email": true |
||||
|
}, |
||||
|
"invalid": false, |
||||
|
"refresh_token": "foo_refresh_token", |
||||
|
"revoke_uri": "https://accounts.google.com/o/oauth2/revoke", |
||||
|
"token_expiry": "2015-01-09T00:51:51Z", |
||||
|
"token_response": { |
||||
|
"access_token": "foo_access_token", |
||||
|
"expires_in": 3600, |
||||
|
"id_token": "foo_id_token", |
||||
|
"token_type": "Bearer" |
||||
|
}, |
||||
|
"token_uri": "https://accounts.google.com/o/oauth2/token", |
||||
|
"user_agent": "Cloud SDK Command Line Tool" |
||||
|
}, |
||||
|
"key": { |
||||
|
"account": "foo@example.com", |
||||
|
"clientId": "foo_client_id", |
||||
|
"scope": "https://www.googleapis.com/auth/appengine.admin https://www.googleapis.com/auth/bigquery https://www.googleapis.com/auth/compute https://www.googleapis.com/auth/devstorage.full_control https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/ndev.cloudman https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/sqlservice.admin https://www.googleapis.com/auth/prediction https://www.googleapis.com/auth/projecthosting", |
||||
|
"type": "google-cloud-sdk" |
||||
|
} |
||||
|
}, |
||||
|
{ |
||||
|
"credential": { |
||||
|
"_class": "OAuth2Credentials", |
||||
|
"_module": "oauth2client.client", |
||||
|
"access_token": "bar_access_token", |
||||
|
"client_id": "bar_client_id", |
||||
|
"client_secret": "bar_client_secret", |
||||
|
"id_token": { |
||||
|
"at_hash": "bar_at_hash", |
||||
|
"aud": "bar_aud", |
||||
|
"azp": "bar_azp", |
||||
|
"cid": "bar_cid", |
||||
|
"email": "bar@example.com", |
||||
|
"email_verified": true, |
||||
|
"exp": 1420573614, |
||||
|
"iat": 1420569714, |
||||
|
"id": "1337", |
||||
|
"iss": "accounts.google.com", |
||||
|
"sub": "1337", |
||||
|
"token_hash": "bar_token_hash", |
||||
|
"verified_email": true |
||||
|
}, |
||||
|
"invalid": false, |
||||
|
"refresh_token": "bar_refresh_token", |
||||
|
"revoke_uri": "https://accounts.google.com/o/oauth2/revoke", |
||||
|
"token_expiry": "2015-01-09T00:51:51Z", |
||||
|
"token_response": { |
||||
|
"access_token": "bar_access_token", |
||||
|
"expires_in": 3600, |
||||
|
"id_token": "bar_id_token", |
||||
|
"token_type": "Bearer" |
||||
|
}, |
||||
|
"token_uri": "https://accounts.google.com/o/oauth2/token", |
||||
|
"user_agent": "Cloud SDK Command Line Tool" |
||||
|
}, |
||||
|
"key": { |
||||
|
"account": "bar@example.com", |
||||
|
"clientId": "bar_client_id", |
||||
|
"scope": "https://www.googleapis.com/auth/appengine.admin https://www.googleapis.com/auth/bigquery https://www.googleapis.com/auth/compute https://www.googleapis.com/auth/devstorage.full_control https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/ndev.cloudman https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/sqlservice.admin https://www.googleapis.com/auth/prediction https://www.googleapis.com/auth/projecthosting", |
||||
|
"type": "google-cloud-sdk" |
||||
|
} |
||||
|
}, |
||||
|
{ |
||||
|
"credential": { |
||||
|
"_class": "ServiceAccountCredentials", |
||||
|
"_kwargs": {}, |
||||
|
"_module": "oauth2client.client", |
||||
|
"_private_key_id": "00000000000000000000000000000000", |
||||
|
"_private_key_pkcs8_text": "-----BEGIN RSA PRIVATE KEY-----\nMIICWwIBAAKBgQCt3fpiynPSaUhWSIKMGV331zudwJ6GkGmvQtwsoK2S2LbvnSwU\nNxgj4fp08kIDR5p26wF4+t/HrKydMwzftXBfZ9UmLVJgRdSswmS5SmChCrfDS5OE\nvFFcN5+6w1w8/Nu657PF/dse8T0bV95YrqyoR0Osy8WHrUOMSIIbC3hRuwIDAQAB\nAoGAJrGE/KFjn0sQ7yrZ6sXmdLawrM3mObo/2uI9T60+k7SpGbBX0/Pi6nFrJMWZ\nTVONG7P3Mu5aCPzzuVRYJB0j8aldSfzABTY3HKoWCczqw1OztJiEseXGiYz4QOyr\nYU3qDyEpdhS6q6wcoLKGH+hqRmz6pcSEsc8XzOOu7s4xW8kCQQDkc75HjhbarCnd\nJJGMe3U76+6UGmdK67ltZj6k6xoB5WbTNChY9TAyI2JC+ppYV89zv3ssj4L+02u3\nHIHFGxsHAkEAwtU1qYb1tScpchPobnYUFiVKJ7KA8EZaHVaJJODW/cghTCV7BxcJ\nbgVvlmk4lFKn3lPKAgWw7PdQsBTVBUcCrQJATPwoIirizrv3u5soJUQxZIkENAqV\nxmybZx9uetrzP7JTrVbFRf0SScMcyN90hdLJiQL8+i4+gaszgFht7sNMnwJAAbfj\nq0UXcauQwALQ7/h2oONfTg5S+MuGC/AxcXPSMZbMRGGoPh3D5YaCv27aIuS/ukQ+\n6dmm/9AGlCb64fsIWQJAPaokbjIifo+LwC5gyK73Mc4t8nAOSZDenzd/2f6TCq76\nS1dcnKiPxaED7W/y6LJiuBT2rbZiQ2L93NJpFZD/UA==\n-----END RSA PRIVATE KEY-----\n", |
||||
|
"_revoke_uri": "https://accounts.google.com/o/oauth2/revoke", |
||||
|
"_scopes": "https://www.googleapis.com/auth/appengine.admin https://www.googleapis.com/auth/bigquery https://www.googleapis.com/auth/compute https://www.googleapis.com/auth/devstorage.full_control https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/ndev.cloudman https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/sqlservice.admin https://www.googleapis.com/auth/prediction https://www.googleapis.com/auth/projecthosting", |
||||
|
"_service_account_email": "baz@serviceaccount.example.com", |
||||
|
"_service_account_id": "baz.serviceaccount.example.com", |
||||
|
"_token_uri": "https://accounts.google.com/o/oauth2/token", |
||||
|
"_user_agent": "Cloud SDK Command Line Tool", |
||||
|
"access_token": null, |
||||
|
"assertion_type": null, |
||||
|
"client_id": null, |
||||
|
"client_secret": null, |
||||
|
"id_token": null, |
||||
|
"invalid": false, |
||||
|
"refresh_token": null, |
||||
|
"revoke_uri": "https://accounts.google.com/o/oauth2/revoke", |
||||
|
"service_account_name": "baz@serviceaccount.example.com", |
||||
|
"token_expiry": null, |
||||
|
"token_response": null, |
||||
|
"user_agent": "Cloud SDK Command Line Tool" |
||||
|
}, |
||||
|
"key": { |
||||
|
"account": "baz@serviceaccount.example.com", |
||||
|
"clientId": "baz_client_id", |
||||
|
"scope": "https://www.googleapis.com/auth/appengine.admin https://www.googleapis.com/auth/bigquery https://www.googleapis.com/auth/compute https://www.googleapis.com/auth/devstorage.full_control https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/ndev.cloudman https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/sqlservice.admin https://www.googleapis.com/auth/prediction https://www.googleapis.com/auth/projecthosting", |
||||
|
"type": "google-cloud-sdk" |
||||
|
} |
||||
|
} |
||||
|
], |
||||
|
"file_version": 1 |
||||
|
} |
@ -0,0 +1,2 @@ |
|||||
|
[core] |
||||
|
account = bar@example.com |
@ -0,0 +1,60 @@ |
|||||
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package hipchat provides constants for using OAuth2 to access HipChat.
|
||||
|
package hipchat // import "golang.org/x/oauth2/hipchat"
|
||||
|
|
||||
|
import ( |
||||
|
"encoding/json" |
||||
|
"errors" |
||||
|
|
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/clientcredentials" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is HipChat's OAuth 2.0 endpoint.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://www.hipchat.com/users/authorize", |
||||
|
TokenURL: "https://api.hipchat.com/v2/oauth/token", |
||||
|
} |
||||
|
|
||||
|
// ServerEndpoint returns a new oauth2.Endpoint for a HipChat Server instance
|
||||
|
// running on the given domain or host.
|
||||
|
func ServerEndpoint(host string) oauth2.Endpoint { |
||||
|
return oauth2.Endpoint{ |
||||
|
AuthURL: "https://" + host + "/users/authorize", |
||||
|
TokenURL: "https://" + host + "/v2/oauth/token", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// ClientCredentialsConfigFromCaps generates a Config from a HipChat API
|
||||
|
// capabilities descriptor. It does not verify the scopes against the
|
||||
|
// capabilities document at this time.
|
||||
|
//
|
||||
|
// For more information see: https://www.hipchat.com/docs/apiv2/method/get_capabilities
|
||||
|
func ClientCredentialsConfigFromCaps(capsJSON []byte, clientID, clientSecret string, scopes ...string) (*clientcredentials.Config, error) { |
||||
|
var caps struct { |
||||
|
Caps struct { |
||||
|
Endpoint struct { |
||||
|
TokenURL string `json:"tokenUrl"` |
||||
|
} `json:"oauth2Provider"` |
||||
|
} `json:"capabilities"` |
||||
|
} |
||||
|
|
||||
|
if err := json.Unmarshal(capsJSON, &caps); err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
// Verify required fields.
|
||||
|
if caps.Caps.Endpoint.TokenURL == "" { |
||||
|
return nil, errors.New("oauth2/hipchat: missing OAuth2 token URL in the capabilities descriptor JSON") |
||||
|
} |
||||
|
|
||||
|
return &clientcredentials.Config{ |
||||
|
ClientID: clientID, |
||||
|
ClientSecret: clientSecret, |
||||
|
Scopes: scopes, |
||||
|
TokenURL: caps.Caps.Endpoint.TokenURL, |
||||
|
}, nil |
||||
|
} |
@ -0,0 +1,76 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package internal contains support packages for oauth2 package.
|
||||
|
package internal |
||||
|
|
||||
|
import ( |
||||
|
"bufio" |
||||
|
"crypto/rsa" |
||||
|
"crypto/x509" |
||||
|
"encoding/pem" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
"strings" |
||||
|
) |
||||
|
|
||||
|
// ParseKey converts the binary contents of a private key file
|
||||
|
// to an *rsa.PrivateKey. It detects whether the private key is in a
|
||||
|
// PEM container or not. If so, it extracts the the private key
|
||||
|
// from PEM container before conversion. It only supports PEM
|
||||
|
// containers with no passphrase.
|
||||
|
func ParseKey(key []byte) (*rsa.PrivateKey, error) { |
||||
|
block, _ := pem.Decode(key) |
||||
|
if block != nil { |
||||
|
key = block.Bytes |
||||
|
} |
||||
|
parsedKey, err := x509.ParsePKCS8PrivateKey(key) |
||||
|
if err != nil { |
||||
|
parsedKey, err = x509.ParsePKCS1PrivateKey(key) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err) |
||||
|
} |
||||
|
} |
||||
|
parsed, ok := parsedKey.(*rsa.PrivateKey) |
||||
|
if !ok { |
||||
|
return nil, errors.New("private key is invalid") |
||||
|
} |
||||
|
return parsed, nil |
||||
|
} |
||||
|
|
||||
|
func ParseINI(ini io.Reader) (map[string]map[string]string, error) { |
||||
|
result := map[string]map[string]string{ |
||||
|
"": map[string]string{}, // root section
|
||||
|
} |
||||
|
scanner := bufio.NewScanner(ini) |
||||
|
currentSection := "" |
||||
|
for scanner.Scan() { |
||||
|
line := strings.TrimSpace(scanner.Text()) |
||||
|
if strings.HasPrefix(line, ";") { |
||||
|
// comment.
|
||||
|
continue |
||||
|
} |
||||
|
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { |
||||
|
currentSection = strings.TrimSpace(line[1 : len(line)-1]) |
||||
|
result[currentSection] = map[string]string{} |
||||
|
continue |
||||
|
} |
||||
|
parts := strings.SplitN(line, "=", 2) |
||||
|
if len(parts) == 2 && parts[0] != "" { |
||||
|
result[currentSection][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) |
||||
|
} |
||||
|
} |
||||
|
if err := scanner.Err(); err != nil { |
||||
|
return nil, fmt.Errorf("error scanning ini: %v", err) |
||||
|
} |
||||
|
return result, nil |
||||
|
} |
||||
|
|
||||
|
func CondVal(v string) []string { |
||||
|
if v == "" { |
||||
|
return nil |
||||
|
} |
||||
|
return []string{v} |
||||
|
} |
@ -0,0 +1,62 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package internal contains support packages for oauth2 package.
|
||||
|
package internal |
||||
|
|
||||
|
import ( |
||||
|
"reflect" |
||||
|
"strings" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestParseINI(t *testing.T) { |
||||
|
tests := []struct { |
||||
|
ini string |
||||
|
want map[string]map[string]string |
||||
|
}{ |
||||
|
{ |
||||
|
`root = toor |
||||
|
[foo] |
||||
|
bar = hop |
||||
|
ini = nin |
||||
|
`, |
||||
|
map[string]map[string]string{ |
||||
|
"": map[string]string{"root": "toor"}, |
||||
|
"foo": map[string]string{"bar": "hop", "ini": "nin"}, |
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
`[empty] |
||||
|
[section] |
||||
|
empty= |
||||
|
`, |
||||
|
map[string]map[string]string{ |
||||
|
"": map[string]string{}, |
||||
|
"empty": map[string]string{}, |
||||
|
"section": map[string]string{"empty": ""}, |
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
`ignore |
||||
|
[invalid |
||||
|
=stuff |
||||
|
;comment=true |
||||
|
`, |
||||
|
map[string]map[string]string{ |
||||
|
"": map[string]string{}, |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
for _, tt := range tests { |
||||
|
result, err := ParseINI(strings.NewReader(tt.ini)) |
||||
|
if err != nil { |
||||
|
t.Errorf("ParseINI(%q) error %v, want: no error", tt.ini, err) |
||||
|
continue |
||||
|
} |
||||
|
if !reflect.DeepEqual(result, tt.want) { |
||||
|
t.Errorf("ParseINI(%q) = %#v, want: %#v", tt.ini, result, tt.want) |
||||
|
} |
||||
|
} |
||||
|
} |
@ -0,0 +1,225 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package internal contains support packages for oauth2 package.
|
||||
|
package internal |
||||
|
|
||||
|
import ( |
||||
|
"encoding/json" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
"io/ioutil" |
||||
|
"mime" |
||||
|
"net/http" |
||||
|
"net/url" |
||||
|
"strconv" |
||||
|
"strings" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
) |
||||
|
|
||||
|
// Token represents the crendentials used to authorize
|
||||
|
// the requests to access protected resources on the OAuth 2.0
|
||||
|
// provider's backend.
|
||||
|
//
|
||||
|
// This type is a mirror of oauth2.Token and exists to break
|
||||
|
// an otherwise-circular dependency. Other internal packages
|
||||
|
// should convert this Token into an oauth2.Token before use.
|
||||
|
type Token struct { |
||||
|
// AccessToken is the token that authorizes and authenticates
|
||||
|
// the requests.
|
||||
|
AccessToken string |
||||
|
|
||||
|
// TokenType is the type of token.
|
||||
|
// The Type method returns either this or "Bearer", the default.
|
||||
|
TokenType string |
||||
|
|
||||
|
// RefreshToken is a token that's used by the application
|
||||
|
// (as opposed to the user) to refresh the access token
|
||||
|
// if it expires.
|
||||
|
RefreshToken string |
||||
|
|
||||
|
// Expiry is the optional expiration time of the access token.
|
||||
|
//
|
||||
|
// If zero, TokenSource implementations will reuse the same
|
||||
|
// token forever and RefreshToken or equivalent
|
||||
|
// mechanisms for that TokenSource will not be used.
|
||||
|
Expiry time.Time |
||||
|
|
||||
|
// Raw optionally contains extra metadata from the server
|
||||
|
// when updating a token.
|
||||
|
Raw interface{} |
||||
|
} |
||||
|
|
||||
|
// tokenJSON is the struct representing the HTTP response from OAuth2
|
||||
|
// providers returning a token in JSON form.
|
||||
|
type tokenJSON struct { |
||||
|
AccessToken string `json:"access_token"` |
||||
|
TokenType string `json:"token_type"` |
||||
|
RefreshToken string `json:"refresh_token"` |
||||
|
ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
|
||||
|
Expires expirationTime `json:"expires"` // broken Facebook spelling of expires_in
|
||||
|
} |
||||
|
|
||||
|
func (e *tokenJSON) expiry() (t time.Time) { |
||||
|
if v := e.ExpiresIn; v != 0 { |
||||
|
return time.Now().Add(time.Duration(v) * time.Second) |
||||
|
} |
||||
|
if v := e.Expires; v != 0 { |
||||
|
return time.Now().Add(time.Duration(v) * time.Second) |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
type expirationTime int32 |
||||
|
|
||||
|
func (e *expirationTime) UnmarshalJSON(b []byte) error { |
||||
|
var n json.Number |
||||
|
err := json.Unmarshal(b, &n) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
i, err := n.Int64() |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
*e = expirationTime(i) |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
var brokenAuthHeaderProviders = []string{ |
||||
|
"https://accounts.google.com/", |
||||
|
"https://api.dropbox.com/", |
||||
|
"https://api.dropboxapi.com/", |
||||
|
"https://api.instagram.com/", |
||||
|
"https://api.netatmo.net/", |
||||
|
"https://api.odnoklassniki.ru/", |
||||
|
"https://api.pushbullet.com/", |
||||
|
"https://api.soundcloud.com/", |
||||
|
"https://api.twitch.tv/", |
||||
|
"https://app.box.com/", |
||||
|
"https://connect.stripe.com/", |
||||
|
"https://login.microsoftonline.com/", |
||||
|
"https://login.salesforce.com/", |
||||
|
"https://oauth.sandbox.trainingpeaks.com/", |
||||
|
"https://oauth.trainingpeaks.com/", |
||||
|
"https://oauth.vk.com/", |
||||
|
"https://openapi.baidu.com/", |
||||
|
"https://slack.com/", |
||||
|
"https://test-sandbox.auth.corp.google.com", |
||||
|
"https://test.salesforce.com/", |
||||
|
"https://user.gini.net/", |
||||
|
"https://www.douban.com/", |
||||
|
"https://www.googleapis.com/", |
||||
|
"https://www.linkedin.com/", |
||||
|
"https://www.strava.com/oauth/", |
||||
|
"https://www.wunderlist.com/oauth/", |
||||
|
"https://api.patreon.com/", |
||||
|
} |
||||
|
|
||||
|
func RegisterBrokenAuthHeaderProvider(tokenURL string) { |
||||
|
brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL) |
||||
|
} |
||||
|
|
||||
|
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
|
||||
|
// implements the OAuth2 spec correctly
|
||||
|
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
||||
|
// In summary:
|
||||
|
// - Reddit only accepts client secret in the Authorization header
|
||||
|
// - Dropbox accepts either it in URL param or Auth header, but not both.
|
||||
|
// - Google only accepts URL param (not spec compliant?), not Auth header
|
||||
|
// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
|
||||
|
func providerAuthHeaderWorks(tokenURL string) bool { |
||||
|
for _, s := range brokenAuthHeaderProviders { |
||||
|
if strings.HasPrefix(tokenURL, s) { |
||||
|
// Some sites fail to implement the OAuth2 spec fully.
|
||||
|
return false |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Assume the provider implements the spec properly
|
||||
|
// otherwise. We can add more exceptions as they're
|
||||
|
// discovered. We will _not_ be adding configurable hooks
|
||||
|
// to this package to let users select server bugs.
|
||||
|
return true |
||||
|
} |
||||
|
|
||||
|
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) { |
||||
|
hc, err := ContextClient(ctx) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
v.Set("client_id", clientID) |
||||
|
bustedAuth := !providerAuthHeaderWorks(tokenURL) |
||||
|
if bustedAuth && clientSecret != "" { |
||||
|
v.Set("client_secret", clientSecret) |
||||
|
} |
||||
|
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode())) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
||||
|
if !bustedAuth { |
||||
|
req.SetBasicAuth(clientID, clientSecret) |
||||
|
} |
||||
|
r, err := hc.Do(req) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
defer r.Body.Close() |
||||
|
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) |
||||
|
} |
||||
|
if code := r.StatusCode; code < 200 || code > 299 { |
||||
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) |
||||
|
} |
||||
|
|
||||
|
var token *Token |
||||
|
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) |
||||
|
switch content { |
||||
|
case "application/x-www-form-urlencoded", "text/plain": |
||||
|
vals, err := url.ParseQuery(string(body)) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
token = &Token{ |
||||
|
AccessToken: vals.Get("access_token"), |
||||
|
TokenType: vals.Get("token_type"), |
||||
|
RefreshToken: vals.Get("refresh_token"), |
||||
|
Raw: vals, |
||||
|
} |
||||
|
e := vals.Get("expires_in") |
||||
|
if e == "" { |
||||
|
// TODO(jbd): Facebook's OAuth2 implementation is broken and
|
||||
|
// returns expires_in field in expires. Remove the fallback to expires,
|
||||
|
// when Facebook fixes their implementation.
|
||||
|
e = vals.Get("expires") |
||||
|
} |
||||
|
expires, _ := strconv.Atoi(e) |
||||
|
if expires != 0 { |
||||
|
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) |
||||
|
} |
||||
|
default: |
||||
|
var tj tokenJSON |
||||
|
if err = json.Unmarshal(body, &tj); err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
token = &Token{ |
||||
|
AccessToken: tj.AccessToken, |
||||
|
TokenType: tj.TokenType, |
||||
|
RefreshToken: tj.RefreshToken, |
||||
|
Expiry: tj.expiry(), |
||||
|
Raw: make(map[string]interface{}), |
||||
|
} |
||||
|
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
|
||||
|
} |
||||
|
// Don't overwrite `RefreshToken` with an empty value
|
||||
|
// if this was a token refreshing request.
|
||||
|
if token.RefreshToken == "" { |
||||
|
token.RefreshToken = v.Get("refresh_token") |
||||
|
} |
||||
|
return token, nil |
||||
|
} |
@ -0,0 +1,36 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package internal contains support packages for oauth2 package.
|
||||
|
package internal |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestRegisterBrokenAuthHeaderProvider(t *testing.T) { |
||||
|
RegisterBrokenAuthHeaderProvider("https://aaa.com/") |
||||
|
tokenURL := "https://aaa.com/token" |
||||
|
if providerAuthHeaderWorks(tokenURL) { |
||||
|
t.Errorf("URL: %s is a broken provider", tokenURL) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func Test_providerAuthHeaderWorks(t *testing.T) { |
||||
|
for _, p := range brokenAuthHeaderProviders { |
||||
|
if providerAuthHeaderWorks(p) { |
||||
|
t.Errorf("URL: %s not found in list", p) |
||||
|
} |
||||
|
p := fmt.Sprintf("%ssomesuffix", p) |
||||
|
if providerAuthHeaderWorks(p) { |
||||
|
t.Errorf("URL: %s not found in list", p) |
||||
|
} |
||||
|
} |
||||
|
p := "https://api.not-in-the-list-example.com/" |
||||
|
if !providerAuthHeaderWorks(p) { |
||||
|
t.Errorf("URL: %s found in list", p) |
||||
|
} |
||||
|
|
||||
|
} |
@ -0,0 +1,69 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package internal contains support packages for oauth2 package.
|
||||
|
package internal |
||||
|
|
||||
|
import ( |
||||
|
"net/http" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
) |
||||
|
|
||||
|
// HTTPClient is the context key to use with golang.org/x/net/context's
|
||||
|
// WithValue function to associate an *http.Client value with a context.
|
||||
|
var HTTPClient ContextKey |
||||
|
|
||||
|
// ContextKey is just an empty struct. It exists so HTTPClient can be
|
||||
|
// an immutable public variable with a unique type. It's immutable
|
||||
|
// because nobody else can create a ContextKey, being unexported.
|
||||
|
type ContextKey struct{} |
||||
|
|
||||
|
// ContextClientFunc is a func which tries to return an *http.Client
|
||||
|
// given a Context value. If it returns an error, the search stops
|
||||
|
// with that error. If it returns (nil, nil), the search continues
|
||||
|
// down the list of registered funcs.
|
||||
|
type ContextClientFunc func(context.Context) (*http.Client, error) |
||||
|
|
||||
|
var contextClientFuncs []ContextClientFunc |
||||
|
|
||||
|
func RegisterContextClientFunc(fn ContextClientFunc) { |
||||
|
contextClientFuncs = append(contextClientFuncs, fn) |
||||
|
} |
||||
|
|
||||
|
func ContextClient(ctx context.Context) (*http.Client, error) { |
||||
|
if ctx != nil { |
||||
|
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { |
||||
|
return hc, nil |
||||
|
} |
||||
|
} |
||||
|
for _, fn := range contextClientFuncs { |
||||
|
c, err := fn(ctx) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
if c != nil { |
||||
|
return c, nil |
||||
|
} |
||||
|
} |
||||
|
return http.DefaultClient, nil |
||||
|
} |
||||
|
|
||||
|
func ContextTransport(ctx context.Context) http.RoundTripper { |
||||
|
hc, err := ContextClient(ctx) |
||||
|
// This is a rare error case (somebody using nil on App Engine).
|
||||
|
if err != nil { |
||||
|
return ErrorTransport{err} |
||||
|
} |
||||
|
return hc.Transport |
||||
|
} |
||||
|
|
||||
|
// ErrorTransport returns the specified error on RoundTrip.
|
||||
|
// This RoundTripper should be used in rare error cases where
|
||||
|
// error handling can be postponed to response handling time.
|
||||
|
type ErrorTransport struct{ Err error } |
||||
|
|
||||
|
func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) { |
||||
|
return nil, t.Err |
||||
|
} |
@ -0,0 +1,38 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package internal |
||||
|
|
||||
|
import ( |
||||
|
"net/http" |
||||
|
"testing" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
) |
||||
|
|
||||
|
func TestContextClient(t *testing.T) { |
||||
|
rc := &http.Client{} |
||||
|
RegisterContextClientFunc(func(context.Context) (*http.Client, error) { |
||||
|
return rc, nil |
||||
|
}) |
||||
|
|
||||
|
c := &http.Client{} |
||||
|
ctx := context.WithValue(context.Background(), HTTPClient, c) |
||||
|
|
||||
|
hc, err := ContextClient(ctx) |
||||
|
if err != nil { |
||||
|
t.Fatalf("want valid client; got err = %v", err) |
||||
|
} |
||||
|
if hc != c { |
||||
|
t.Fatalf("want context client = %p; got = %p", c, hc) |
||||
|
} |
||||
|
|
||||
|
hc, err = ContextClient(context.TODO()) |
||||
|
if err != nil { |
||||
|
t.Fatalf("want valid client; got err = %v", err) |
||||
|
} |
||||
|
if hc != rc { |
||||
|
t.Fatalf("want registered client = %p; got = %p", c, hc) |
||||
|
} |
||||
|
} |
@ -0,0 +1,174 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package jws provides encoding and decoding utilities for
|
||||
|
// signed JWS messages.
|
||||
|
package jws // import "golang.org/x/oauth2/jws"
|
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"crypto" |
||||
|
"crypto/rand" |
||||
|
"crypto/rsa" |
||||
|
"crypto/sha256" |
||||
|
"encoding/base64" |
||||
|
"encoding/json" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"strings" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
// ClaimSet contains information about the JWT signature including the
|
||||
|
// permissions being requested (scopes), the target of the token, the issuer,
|
||||
|
// the time the token was issued, and the lifetime of the token.
|
||||
|
type ClaimSet struct { |
||||
|
Iss string `json:"iss"` // email address of the client_id of the application making the access token request
|
||||
|
Scope string `json:"scope,omitempty"` // space-delimited list of the permissions the application requests
|
||||
|
Aud string `json:"aud"` // descriptor of the intended target of the assertion (Optional).
|
||||
|
Exp int64 `json:"exp"` // the expiration time of the assertion (seconds since Unix epoch)
|
||||
|
Iat int64 `json:"iat"` // the time the assertion was issued (seconds since Unix epoch)
|
||||
|
Typ string `json:"typ,omitempty"` // token type (Optional).
|
||||
|
|
||||
|
// Email for which the application is requesting delegated access (Optional).
|
||||
|
Sub string `json:"sub,omitempty"` |
||||
|
|
||||
|
// The old name of Sub. Client keeps setting Prn to be
|
||||
|
// complaint with legacy OAuth 2.0 providers. (Optional)
|
||||
|
Prn string `json:"prn,omitempty"` |
||||
|
|
||||
|
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
|
||||
|
// This array is marshalled using custom code (see (c *ClaimSet) encode()).
|
||||
|
PrivateClaims map[string]interface{} `json:"-"` |
||||
|
} |
||||
|
|
||||
|
func (c *ClaimSet) encode() (string, error) { |
||||
|
// Reverting time back for machines whose time is not perfectly in sync.
|
||||
|
// If client machine's time is in the future according
|
||||
|
// to Google servers, an access token will not be issued.
|
||||
|
now := time.Now().Add(-10 * time.Second) |
||||
|
if c.Iat == 0 { |
||||
|
c.Iat = now.Unix() |
||||
|
} |
||||
|
if c.Exp == 0 { |
||||
|
c.Exp = now.Add(time.Hour).Unix() |
||||
|
} |
||||
|
if c.Exp < c.Iat { |
||||
|
return "", fmt.Errorf("jws: invalid Exp = %v; must be later than Iat = %v", c.Exp, c.Iat) |
||||
|
} |
||||
|
|
||||
|
b, err := json.Marshal(c) |
||||
|
if err != nil { |
||||
|
return "", err |
||||
|
} |
||||
|
|
||||
|
if len(c.PrivateClaims) == 0 { |
||||
|
return base64.RawURLEncoding.EncodeToString(b), nil |
||||
|
} |
||||
|
|
||||
|
// Marshal private claim set and then append it to b.
|
||||
|
prv, err := json.Marshal(c.PrivateClaims) |
||||
|
if err != nil { |
||||
|
return "", fmt.Errorf("jws: invalid map of private claims %v", c.PrivateClaims) |
||||
|
} |
||||
|
|
||||
|
// Concatenate public and private claim JSON objects.
|
||||
|
if !bytes.HasSuffix(b, []byte{'}'}) { |
||||
|
return "", fmt.Errorf("jws: invalid JSON %s", b) |
||||
|
} |
||||
|
if !bytes.HasPrefix(prv, []byte{'{'}) { |
||||
|
return "", fmt.Errorf("jws: invalid JSON %s", prv) |
||||
|
} |
||||
|
b[len(b)-1] = ',' // Replace closing curly brace with a comma.
|
||||
|
b = append(b, prv[1:]...) // Append private claims.
|
||||
|
return base64.RawURLEncoding.EncodeToString(b), nil |
||||
|
} |
||||
|
|
||||
|
// Header represents the header for the signed JWS payloads.
|
||||
|
type Header struct { |
||||
|
// The algorithm used for signature.
|
||||
|
Algorithm string `json:"alg"` |
||||
|
|
||||
|
// Represents the token type.
|
||||
|
Typ string `json:"typ"` |
||||
|
|
||||
|
// The optional hint of which key is being used.
|
||||
|
KeyID string `json:"kid,omitempty"` |
||||
|
} |
||||
|
|
||||
|
func (h *Header) encode() (string, error) { |
||||
|
b, err := json.Marshal(h) |
||||
|
if err != nil { |
||||
|
return "", err |
||||
|
} |
||||
|
return base64.RawURLEncoding.EncodeToString(b), nil |
||||
|
} |
||||
|
|
||||
|
// Decode decodes a claim set from a JWS payload.
|
||||
|
func Decode(payload string) (*ClaimSet, error) { |
||||
|
// decode returned id token to get expiry
|
||||
|
s := strings.Split(payload, ".") |
||||
|
if len(s) < 2 { |
||||
|
// TODO(jbd): Provide more context about the error.
|
||||
|
return nil, errors.New("jws: invalid token received") |
||||
|
} |
||||
|
decoded, err := base64.RawURLEncoding.DecodeString(s[1]) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
c := &ClaimSet{} |
||||
|
err = json.NewDecoder(bytes.NewBuffer(decoded)).Decode(c) |
||||
|
return c, err |
||||
|
} |
||||
|
|
||||
|
// Signer returns a signature for the given data.
|
||||
|
type Signer func(data []byte) (sig []byte, err error) |
||||
|
|
||||
|
// EncodeWithSigner encodes a header and claim set with the provided signer.
|
||||
|
func EncodeWithSigner(header *Header, c *ClaimSet, sg Signer) (string, error) { |
||||
|
head, err := header.encode() |
||||
|
if err != nil { |
||||
|
return "", err |
||||
|
} |
||||
|
cs, err := c.encode() |
||||
|
if err != nil { |
||||
|
return "", err |
||||
|
} |
||||
|
ss := fmt.Sprintf("%s.%s", head, cs) |
||||
|
sig, err := sg([]byte(ss)) |
||||
|
if err != nil { |
||||
|
return "", err |
||||
|
} |
||||
|
return fmt.Sprintf("%s.%s", ss, base64.RawURLEncoding.EncodeToString(sig)), nil |
||||
|
} |
||||
|
|
||||
|
// Encode encodes a signed JWS with provided header and claim set.
|
||||
|
// This invokes EncodeWithSigner using crypto/rsa.SignPKCS1v15 with the given RSA private key.
|
||||
|
func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) { |
||||
|
sg := func(data []byte) (sig []byte, err error) { |
||||
|
h := sha256.New() |
||||
|
h.Write(data) |
||||
|
return rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h.Sum(nil)) |
||||
|
} |
||||
|
return EncodeWithSigner(header, c, sg) |
||||
|
} |
||||
|
|
||||
|
// Verify tests whether the provided JWT token's signature was produced by the private key
|
||||
|
// associated with the supplied public key.
|
||||
|
func Verify(token string, key *rsa.PublicKey) error { |
||||
|
parts := strings.Split(token, ".") |
||||
|
if len(parts) != 3 { |
||||
|
return errors.New("jws: invalid token received, token must have 3 parts") |
||||
|
} |
||||
|
|
||||
|
signedContent := parts[0] + "." + parts[1] |
||||
|
signatureString, err := base64.RawURLEncoding.DecodeString(parts[2]) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
h := sha256.New() |
||||
|
h.Write([]byte(signedContent)) |
||||
|
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), []byte(signatureString)) |
||||
|
} |
@ -0,0 +1,46 @@ |
|||||
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package jws |
||||
|
|
||||
|
import ( |
||||
|
"crypto/rand" |
||||
|
"crypto/rsa" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestSignAndVerify(t *testing.T) { |
||||
|
header := &Header{ |
||||
|
Algorithm: "RS256", |
||||
|
Typ: "JWT", |
||||
|
} |
||||
|
payload := &ClaimSet{ |
||||
|
Iss: "http://google.com/", |
||||
|
Aud: "", |
||||
|
Exp: 3610, |
||||
|
Iat: 10, |
||||
|
} |
||||
|
|
||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
token, err := Encode(header, payload, privateKey) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
|
||||
|
err = Verify(token, &privateKey.PublicKey) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestVerifyFailsOnMalformedClaim(t *testing.T) { |
||||
|
err := Verify("abc.def", nil) |
||||
|
if err == nil { |
||||
|
t.Error("Improperly formed JWT should fail.") |
||||
|
} |
||||
|
} |
@ -0,0 +1,31 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package jwt_test |
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/jwt" |
||||
|
) |
||||
|
|
||||
|
func ExampleJWTConfig() { |
||||
|
conf := &jwt.Config{ |
||||
|
Email: "xxx@developer.com", |
||||
|
// The contents of your RSA private key or your PEM file
|
||||
|
// that contains a private key.
|
||||
|
// If you have a p12 file instead, you
|
||||
|
// can use `openssl` to export the private key into a pem file.
|
||||
|
//
|
||||
|
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
||||
|
//
|
||||
|
// It only supports PEM containers with no passphrase.
|
||||
|
PrivateKey: []byte("-----BEGIN RSA PRIVATE KEY-----..."), |
||||
|
Subject: "user@example.com", |
||||
|
TokenURL: "https://provider.com/o/oauth2/token", |
||||
|
} |
||||
|
// Initiate an http.Client, the following GET request will be
|
||||
|
// authorized and authenticated on the behalf of user@example.com.
|
||||
|
client := conf.Client(oauth2.NoContext) |
||||
|
client.Get("...") |
||||
|
} |
@ -0,0 +1,157 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package jwt implements the OAuth 2.0 JSON Web Token flow, commonly
|
||||
|
// known as "two-legged OAuth 2.0".
|
||||
|
//
|
||||
|
// See: https://tools.ietf.org/html/draft-ietf-oauth-jwt-bearer-12
|
||||
|
package jwt |
||||
|
|
||||
|
import ( |
||||
|
"encoding/json" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
"io/ioutil" |
||||
|
"net/http" |
||||
|
"net/url" |
||||
|
"strings" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/internal" |
||||
|
"golang.org/x/oauth2/jws" |
||||
|
) |
||||
|
|
||||
|
var ( |
||||
|
defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" |
||||
|
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"} |
||||
|
) |
||||
|
|
||||
|
// Config is the configuration for using JWT to fetch tokens,
|
||||
|
// commonly known as "two-legged OAuth 2.0".
|
||||
|
type Config struct { |
||||
|
// Email is the OAuth client identifier used when communicating with
|
||||
|
// the configured OAuth provider.
|
||||
|
Email string |
||||
|
|
||||
|
// PrivateKey contains the contents of an RSA private key or the
|
||||
|
// contents of a PEM file that contains a private key. The provided
|
||||
|
// private key is used to sign JWT payloads.
|
||||
|
// PEM containers with a passphrase are not supported.
|
||||
|
// Use the following command to convert a PKCS 12 file into a PEM.
|
||||
|
//
|
||||
|
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
||||
|
//
|
||||
|
PrivateKey []byte |
||||
|
|
||||
|
// PrivateKeyID contains an optional hint indicating which key is being
|
||||
|
// used.
|
||||
|
PrivateKeyID string |
||||
|
|
||||
|
// Subject is the optional user to impersonate.
|
||||
|
Subject string |
||||
|
|
||||
|
// Scopes optionally specifies a list of requested permission scopes.
|
||||
|
Scopes []string |
||||
|
|
||||
|
// TokenURL is the endpoint required to complete the 2-legged JWT flow.
|
||||
|
TokenURL string |
||||
|
|
||||
|
// Expires optionally specifies how long the token is valid for.
|
||||
|
Expires time.Duration |
||||
|
} |
||||
|
|
||||
|
// TokenSource returns a JWT TokenSource using the configuration
|
||||
|
// in c and the HTTP client from the provided context.
|
||||
|
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { |
||||
|
return oauth2.ReuseTokenSource(nil, jwtSource{ctx, c}) |
||||
|
} |
||||
|
|
||||
|
// Client returns an HTTP client wrapping the context's
|
||||
|
// HTTP transport and adding Authorization headers with tokens
|
||||
|
// obtained from c.
|
||||
|
//
|
||||
|
// The returned client and its Transport should not be modified.
|
||||
|
func (c *Config) Client(ctx context.Context) *http.Client { |
||||
|
return oauth2.NewClient(ctx, c.TokenSource(ctx)) |
||||
|
} |
||||
|
|
||||
|
// jwtSource is a source that always does a signed JWT request for a token.
|
||||
|
// It should typically be wrapped with a reuseTokenSource.
|
||||
|
type jwtSource struct { |
||||
|
ctx context.Context |
||||
|
conf *Config |
||||
|
} |
||||
|
|
||||
|
func (js jwtSource) Token() (*oauth2.Token, error) { |
||||
|
pk, err := internal.ParseKey(js.conf.PrivateKey) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
hc := oauth2.NewClient(js.ctx, nil) |
||||
|
claimSet := &jws.ClaimSet{ |
||||
|
Iss: js.conf.Email, |
||||
|
Scope: strings.Join(js.conf.Scopes, " "), |
||||
|
Aud: js.conf.TokenURL, |
||||
|
} |
||||
|
if subject := js.conf.Subject; subject != "" { |
||||
|
claimSet.Sub = subject |
||||
|
// prn is the old name of sub. Keep setting it
|
||||
|
// to be compatible with legacy OAuth 2.0 providers.
|
||||
|
claimSet.Prn = subject |
||||
|
} |
||||
|
if t := js.conf.Expires; t > 0 { |
||||
|
claimSet.Exp = time.Now().Add(t).Unix() |
||||
|
} |
||||
|
payload, err := jws.Encode(defaultHeader, claimSet, pk) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
v := url.Values{} |
||||
|
v.Set("grant_type", defaultGrantType) |
||||
|
v.Set("assertion", payload) |
||||
|
resp, err := hc.PostForm(js.conf.TokenURL, v) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) |
||||
|
} |
||||
|
defer resp.Body.Close() |
||||
|
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) |
||||
|
} |
||||
|
if c := resp.StatusCode; c < 200 || c > 299 { |
||||
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body) |
||||
|
} |
||||
|
// tokenRes is the JSON response body.
|
||||
|
var tokenRes struct { |
||||
|
AccessToken string `json:"access_token"` |
||||
|
TokenType string `json:"token_type"` |
||||
|
IDToken string `json:"id_token"` |
||||
|
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
|
||||
|
} |
||||
|
if err := json.Unmarshal(body, &tokenRes); err != nil { |
||||
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) |
||||
|
} |
||||
|
token := &oauth2.Token{ |
||||
|
AccessToken: tokenRes.AccessToken, |
||||
|
TokenType: tokenRes.TokenType, |
||||
|
} |
||||
|
raw := make(map[string]interface{}) |
||||
|
json.Unmarshal(body, &raw) // no error checks for optional fields
|
||||
|
token = token.WithExtra(raw) |
||||
|
|
||||
|
if secs := tokenRes.ExpiresIn; secs > 0 { |
||||
|
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) |
||||
|
} |
||||
|
if v := tokenRes.IDToken; v != "" { |
||||
|
// decode returned id token to get expiry
|
||||
|
claimSet, err := jws.Decode(v) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err) |
||||
|
} |
||||
|
token.Expiry = time.Unix(claimSet.Exp, 0) |
||||
|
} |
||||
|
return token, nil |
||||
|
} |
@ -0,0 +1,134 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package jwt |
||||
|
|
||||
|
import ( |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"testing" |
||||
|
|
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- |
||||
|
MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE |
||||
|
DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY |
||||
|
fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK |
||||
|
1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr |
||||
|
k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9 |
||||
|
/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt |
||||
|
3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn |
||||
|
2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3 |
||||
|
nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK |
||||
|
6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf |
||||
|
5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e |
||||
|
DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1 |
||||
|
M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g |
||||
|
z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y |
||||
|
1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK |
||||
|
J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U |
||||
|
f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx |
||||
|
QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA |
||||
|
cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr |
||||
|
Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw |
||||
|
5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg |
||||
|
KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84 |
||||
|
OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd |
||||
|
mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ |
||||
|
5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg== |
||||
|
-----END RSA PRIVATE KEY-----`) |
||||
|
|
||||
|
func TestJWTFetch_JSONResponse(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
w.Header().Set("Content-Type", "application/json") |
||||
|
w.Write([]byte(`{ |
||||
|
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", |
||||
|
"scope": "user", |
||||
|
"token_type": "bearer", |
||||
|
"expires_in": 3600 |
||||
|
}`)) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
|
||||
|
conf := &Config{ |
||||
|
Email: "aaa@xxx.com", |
||||
|
PrivateKey: dummyPrivateKey, |
||||
|
TokenURL: ts.URL, |
||||
|
} |
||||
|
tok, err := conf.TokenSource(oauth2.NoContext).Token() |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if !tok.Valid() { |
||||
|
t.Errorf("Token invalid") |
||||
|
} |
||||
|
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { |
||||
|
t.Errorf("Unexpected access token, %#v", tok.AccessToken) |
||||
|
} |
||||
|
if tok.TokenType != "bearer" { |
||||
|
t.Errorf("Unexpected token type, %#v", tok.TokenType) |
||||
|
} |
||||
|
if tok.Expiry.IsZero() { |
||||
|
t.Errorf("Unexpected token expiry, %#v", tok.Expiry) |
||||
|
} |
||||
|
scope := tok.Extra("scope") |
||||
|
if scope != "user" { |
||||
|
t.Errorf("Unexpected value for scope: %v", scope) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestJWTFetch_BadResponse(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
w.Header().Set("Content-Type", "application/json") |
||||
|
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
|
||||
|
conf := &Config{ |
||||
|
Email: "aaa@xxx.com", |
||||
|
PrivateKey: dummyPrivateKey, |
||||
|
TokenURL: ts.URL, |
||||
|
} |
||||
|
tok, err := conf.TokenSource(oauth2.NoContext).Token() |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if tok == nil { |
||||
|
t.Fatalf("token is nil") |
||||
|
} |
||||
|
if tok.Valid() { |
||||
|
t.Errorf("token is valid. want invalid.") |
||||
|
} |
||||
|
if tok.AccessToken != "" { |
||||
|
t.Errorf("Unexpected non-empty access token %q.", tok.AccessToken) |
||||
|
} |
||||
|
if want := "bearer"; tok.TokenType != want { |
||||
|
t.Errorf("TokenType = %q; want %q", tok.TokenType, want) |
||||
|
} |
||||
|
scope := tok.Extra("scope") |
||||
|
if want := "user"; scope != want { |
||||
|
t.Errorf("token scope = %q; want %q", scope, want) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestJWTFetch_BadResponseType(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
w.Header().Set("Content-Type", "application/json") |
||||
|
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := &Config{ |
||||
|
Email: "aaa@xxx.com", |
||||
|
PrivateKey: dummyPrivateKey, |
||||
|
TokenURL: ts.URL, |
||||
|
} |
||||
|
tok, err := conf.TokenSource(oauth2.NoContext).Token() |
||||
|
if err == nil { |
||||
|
t.Error("got a token; expected error") |
||||
|
if tok.AccessToken != "" { |
||||
|
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) |
||||
|
} |
||||
|
} |
||||
|
} |
@ -0,0 +1,16 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package linkedin provides constants for using OAuth2 to access LinkedIn.
|
||||
|
package linkedin // import "golang.org/x/oauth2/linkedin"
|
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is LinkedIn's OAuth 2.0 endpoint.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://www.linkedin.com/uas/oauth2/authorization", |
||||
|
TokenURL: "https://www.linkedin.com/uas/oauth2/accessToken", |
||||
|
} |
@ -0,0 +1,16 @@ |
|||||
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package microsoft provides constants for using OAuth2 to access Windows Live ID.
|
||||
|
package microsoft // import "golang.org/x/oauth2/microsoft"
|
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// LiveConnectEndpoint is Windows's Live ID OAuth 2.0 endpoint.
|
||||
|
var LiveConnectEndpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://login.live.com/oauth20_authorize.srf", |
||||
|
TokenURL: "https://login.live.com/oauth20_token.srf", |
||||
|
} |
@ -0,0 +1,337 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package oauth2 provides support for making
|
||||
|
// OAuth2 authorized and authenticated HTTP requests.
|
||||
|
// It can additionally grant authorization with Bearer JWT.
|
||||
|
package oauth2 // import "golang.org/x/oauth2"
|
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"errors" |
||||
|
"net/http" |
||||
|
"net/url" |
||||
|
"strings" |
||||
|
"sync" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/oauth2/internal" |
||||
|
) |
||||
|
|
||||
|
// NoContext is the default context you should supply if not using
|
||||
|
// your own context.Context (see https://golang.org/x/net/context).
|
||||
|
var NoContext = context.TODO() |
||||
|
|
||||
|
// RegisterBrokenAuthHeaderProvider registers an OAuth2 server
|
||||
|
// identified by the tokenURL prefix as an OAuth2 implementation
|
||||
|
// which doesn't support the HTTP Basic authentication
|
||||
|
// scheme to authenticate with the authorization server.
|
||||
|
// Once a server is registered, credentials (client_id and client_secret)
|
||||
|
// will be passed as query parameters rather than being present
|
||||
|
// in the Authorization header.
|
||||
|
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
||||
|
func RegisterBrokenAuthHeaderProvider(tokenURL string) { |
||||
|
internal.RegisterBrokenAuthHeaderProvider(tokenURL) |
||||
|
} |
||||
|
|
||||
|
// Config describes a typical 3-legged OAuth2 flow, with both the
|
||||
|
// client application information and the server's endpoint URLs.
|
||||
|
type Config struct { |
||||
|
// ClientID is the application's ID.
|
||||
|
ClientID string |
||||
|
|
||||
|
// ClientSecret is the application's secret.
|
||||
|
ClientSecret string |
||||
|
|
||||
|
// Endpoint contains the resource server's token endpoint
|
||||
|
// URLs. These are constants specific to each server and are
|
||||
|
// often available via site-specific packages, such as
|
||||
|
// google.Endpoint or github.Endpoint.
|
||||
|
Endpoint Endpoint |
||||
|
|
||||
|
// RedirectURL is the URL to redirect users going through
|
||||
|
// the OAuth flow, after the resource owner's URLs.
|
||||
|
RedirectURL string |
||||
|
|
||||
|
// Scope specifies optional requested permissions.
|
||||
|
Scopes []string |
||||
|
} |
||||
|
|
||||
|
// A TokenSource is anything that can return a token.
|
||||
|
type TokenSource interface { |
||||
|
// Token returns a token or an error.
|
||||
|
// Token must be safe for concurrent use by multiple goroutines.
|
||||
|
// The returned Token must not be modified.
|
||||
|
Token() (*Token, error) |
||||
|
} |
||||
|
|
||||
|
// Endpoint contains the OAuth 2.0 provider's authorization and token
|
||||
|
// endpoint URLs.
|
||||
|
type Endpoint struct { |
||||
|
AuthURL string |
||||
|
TokenURL string |
||||
|
} |
||||
|
|
||||
|
var ( |
||||
|
// AccessTypeOnline and AccessTypeOffline are options passed
|
||||
|
// to the Options.AuthCodeURL method. They modify the
|
||||
|
// "access_type" field that gets sent in the URL returned by
|
||||
|
// AuthCodeURL.
|
||||
|
//
|
||||
|
// Online is the default if neither is specified. If your
|
||||
|
// application needs to refresh access tokens when the user
|
||||
|
// is not present at the browser, then use offline. This will
|
||||
|
// result in your application obtaining a refresh token the
|
||||
|
// first time your application exchanges an authorization
|
||||
|
// code for a user.
|
||||
|
AccessTypeOnline AuthCodeOption = SetAuthURLParam("access_type", "online") |
||||
|
AccessTypeOffline AuthCodeOption = SetAuthURLParam("access_type", "offline") |
||||
|
|
||||
|
// ApprovalForce forces the users to view the consent dialog
|
||||
|
// and confirm the permissions request at the URL returned
|
||||
|
// from AuthCodeURL, even if they've already done so.
|
||||
|
ApprovalForce AuthCodeOption = SetAuthURLParam("approval_prompt", "force") |
||||
|
) |
||||
|
|
||||
|
// An AuthCodeOption is passed to Config.AuthCodeURL.
|
||||
|
type AuthCodeOption interface { |
||||
|
setValue(url.Values) |
||||
|
} |
||||
|
|
||||
|
type setParam struct{ k, v string } |
||||
|
|
||||
|
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) } |
||||
|
|
||||
|
// SetAuthURLParam builds an AuthCodeOption which passes key/value parameters
|
||||
|
// to a provider's authorization endpoint.
|
||||
|
func SetAuthURLParam(key, value string) AuthCodeOption { |
||||
|
return setParam{key, value} |
||||
|
} |
||||
|
|
||||
|
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
||||
|
// that asks for permissions for the required scopes explicitly.
|
||||
|
//
|
||||
|
// State is a token to protect the user from CSRF attacks. You must
|
||||
|
// always provide a non-zero string and validate that it matches the
|
||||
|
// the state query parameter on your redirect callback.
|
||||
|
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||
|
//
|
||||
|
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
||||
|
// as ApprovalForce.
|
||||
|
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { |
||||
|
var buf bytes.Buffer |
||||
|
buf.WriteString(c.Endpoint.AuthURL) |
||||
|
v := url.Values{ |
||||
|
"response_type": {"code"}, |
||||
|
"client_id": {c.ClientID}, |
||||
|
"redirect_uri": internal.CondVal(c.RedirectURL), |
||||
|
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), |
||||
|
"state": internal.CondVal(state), |
||||
|
} |
||||
|
for _, opt := range opts { |
||||
|
opt.setValue(v) |
||||
|
} |
||||
|
if strings.Contains(c.Endpoint.AuthURL, "?") { |
||||
|
buf.WriteByte('&') |
||||
|
} else { |
||||
|
buf.WriteByte('?') |
||||
|
} |
||||
|
buf.WriteString(v.Encode()) |
||||
|
return buf.String() |
||||
|
} |
||||
|
|
||||
|
// PasswordCredentialsToken converts a resource owner username and password
|
||||
|
// pair into a token.
|
||||
|
//
|
||||
|
// Per the RFC, this grant type should only be used "when there is a high
|
||||
|
// degree of trust between the resource owner and the client (e.g., the client
|
||||
|
// is part of the device operating system or a highly privileged application),
|
||||
|
// and when other authorization grant types are not available."
|
||||
|
// See https://tools.ietf.org/html/rfc6749#section-4.3 for more info.
|
||||
|
//
|
||||
|
// The HTTP client to use is derived from the context.
|
||||
|
// If nil, http.DefaultClient is used.
|
||||
|
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) { |
||||
|
return retrieveToken(ctx, c, url.Values{ |
||||
|
"grant_type": {"password"}, |
||||
|
"username": {username}, |
||||
|
"password": {password}, |
||||
|
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// Exchange converts an authorization code into a token.
|
||||
|
//
|
||||
|
// It is used after a resource provider redirects the user back
|
||||
|
// to the Redirect URI (the URL obtained from AuthCodeURL).
|
||||
|
//
|
||||
|
// The HTTP client to use is derived from the context.
|
||||
|
// If a client is not provided via the context, http.DefaultClient is used.
|
||||
|
//
|
||||
|
// The code will be in the *http.Request.FormValue("code"). Before
|
||||
|
// calling Exchange, be sure to validate FormValue("state").
|
||||
|
func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) { |
||||
|
return retrieveToken(ctx, c, url.Values{ |
||||
|
"grant_type": {"authorization_code"}, |
||||
|
"code": {code}, |
||||
|
"redirect_uri": internal.CondVal(c.RedirectURL), |
||||
|
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
// Client returns an HTTP client using the provided token.
|
||||
|
// The token will auto-refresh as necessary. The underlying
|
||||
|
// HTTP transport will be obtained using the provided context.
|
||||
|
// The returned client and its Transport should not be modified.
|
||||
|
func (c *Config) Client(ctx context.Context, t *Token) *http.Client { |
||||
|
return NewClient(ctx, c.TokenSource(ctx, t)) |
||||
|
} |
||||
|
|
||||
|
// TokenSource returns a TokenSource that returns t until t expires,
|
||||
|
// automatically refreshing it as necessary using the provided context.
|
||||
|
//
|
||||
|
// Most users will use Config.Client instead.
|
||||
|
func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { |
||||
|
tkr := &tokenRefresher{ |
||||
|
ctx: ctx, |
||||
|
conf: c, |
||||
|
} |
||||
|
if t != nil { |
||||
|
tkr.refreshToken = t.RefreshToken |
||||
|
} |
||||
|
return &reuseTokenSource{ |
||||
|
t: t, |
||||
|
new: tkr, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
|
||||
|
// HTTP requests to renew a token using a RefreshToken.
|
||||
|
type tokenRefresher struct { |
||||
|
ctx context.Context // used to get HTTP requests
|
||||
|
conf *Config |
||||
|
refreshToken string |
||||
|
} |
||||
|
|
||||
|
// WARNING: Token is not safe for concurrent access, as it
|
||||
|
// updates the tokenRefresher's refreshToken field.
|
||||
|
// Within this package, it is used by reuseTokenSource which
|
||||
|
// synchronizes calls to this method with its own mutex.
|
||||
|
func (tf *tokenRefresher) Token() (*Token, error) { |
||||
|
if tf.refreshToken == "" { |
||||
|
return nil, errors.New("oauth2: token expired and refresh token is not set") |
||||
|
} |
||||
|
|
||||
|
tk, err := retrieveToken(tf.ctx, tf.conf, url.Values{ |
||||
|
"grant_type": {"refresh_token"}, |
||||
|
"refresh_token": {tf.refreshToken}, |
||||
|
}) |
||||
|
|
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
if tf.refreshToken != tk.RefreshToken { |
||||
|
tf.refreshToken = tk.RefreshToken |
||||
|
} |
||||
|
return tk, err |
||||
|
} |
||||
|
|
||||
|
// reuseTokenSource is a TokenSource that holds a single token in memory
|
||||
|
// and validates its expiry before each call to retrieve it with
|
||||
|
// Token. If it's expired, it will be auto-refreshed using the
|
||||
|
// new TokenSource.
|
||||
|
type reuseTokenSource struct { |
||||
|
new TokenSource // called when t is expired.
|
||||
|
|
||||
|
mu sync.Mutex // guards t
|
||||
|
t *Token |
||||
|
} |
||||
|
|
||||
|
// Token returns the current token if it's still valid, else will
|
||||
|
// refresh the current token (using r.Context for HTTP client
|
||||
|
// information) and return the new one.
|
||||
|
func (s *reuseTokenSource) Token() (*Token, error) { |
||||
|
s.mu.Lock() |
||||
|
defer s.mu.Unlock() |
||||
|
if s.t.Valid() { |
||||
|
return s.t, nil |
||||
|
} |
||||
|
t, err := s.new.Token() |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
s.t = t |
||||
|
return t, nil |
||||
|
} |
||||
|
|
||||
|
// StaticTokenSource returns a TokenSource that always returns the same token.
|
||||
|
// Because the provided token t is never refreshed, StaticTokenSource is only
|
||||
|
// useful for tokens that never expire.
|
||||
|
func StaticTokenSource(t *Token) TokenSource { |
||||
|
return staticTokenSource{t} |
||||
|
} |
||||
|
|
||||
|
// staticTokenSource is a TokenSource that always returns the same Token.
|
||||
|
type staticTokenSource struct { |
||||
|
t *Token |
||||
|
} |
||||
|
|
||||
|
func (s staticTokenSource) Token() (*Token, error) { |
||||
|
return s.t, nil |
||||
|
} |
||||
|
|
||||
|
// HTTPClient is the context key to use with golang.org/x/net/context's
|
||||
|
// WithValue function to associate an *http.Client value with a context.
|
||||
|
var HTTPClient internal.ContextKey |
||||
|
|
||||
|
// NewClient creates an *http.Client from a Context and TokenSource.
|
||||
|
// The returned client is not valid beyond the lifetime of the context.
|
||||
|
//
|
||||
|
// As a special case, if src is nil, a non-OAuth2 client is returned
|
||||
|
// using the provided context. This exists to support related OAuth2
|
||||
|
// packages.
|
||||
|
func NewClient(ctx context.Context, src TokenSource) *http.Client { |
||||
|
if src == nil { |
||||
|
c, err := internal.ContextClient(ctx) |
||||
|
if err != nil { |
||||
|
return &http.Client{Transport: internal.ErrorTransport{Err: err}} |
||||
|
} |
||||
|
return c |
||||
|
} |
||||
|
return &http.Client{ |
||||
|
Transport: &Transport{ |
||||
|
Base: internal.ContextTransport(ctx), |
||||
|
Source: ReuseTokenSource(nil, src), |
||||
|
}, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// ReuseTokenSource returns a TokenSource which repeatedly returns the
|
||||
|
// same token as long as it's valid, starting with t.
|
||||
|
// When its cached token is invalid, a new token is obtained from src.
|
||||
|
//
|
||||
|
// ReuseTokenSource is typically used to reuse tokens from a cache
|
||||
|
// (such as a file on disk) between runs of a program, rather than
|
||||
|
// obtaining new tokens unnecessarily.
|
||||
|
//
|
||||
|
// The initial token t may be nil, in which case the TokenSource is
|
||||
|
// wrapped in a caching version if it isn't one already. This also
|
||||
|
// means it's always safe to wrap ReuseTokenSource around any other
|
||||
|
// TokenSource without adverse effects.
|
||||
|
func ReuseTokenSource(t *Token, src TokenSource) TokenSource { |
||||
|
// Don't wrap a reuseTokenSource in itself. That would work,
|
||||
|
// but cause an unnecessary number of mutex operations.
|
||||
|
// Just build the equivalent one.
|
||||
|
if rt, ok := src.(*reuseTokenSource); ok { |
||||
|
if t == nil { |
||||
|
// Just use it directly.
|
||||
|
return rt |
||||
|
} |
||||
|
src = rt.new |
||||
|
} |
||||
|
return &reuseTokenSource{ |
||||
|
t: t, |
||||
|
new: src, |
||||
|
} |
||||
|
} |
@ -0,0 +1,458 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package oauth2 |
||||
|
|
||||
|
import ( |
||||
|
"encoding/json" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"io/ioutil" |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"net/url" |
||||
|
"reflect" |
||||
|
"strconv" |
||||
|
"testing" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
) |
||||
|
|
||||
|
type mockTransport struct { |
||||
|
rt func(req *http.Request) (resp *http.Response, err error) |
||||
|
} |
||||
|
|
||||
|
func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { |
||||
|
return t.rt(req) |
||||
|
} |
||||
|
|
||||
|
func newConf(url string) *Config { |
||||
|
return &Config{ |
||||
|
ClientID: "CLIENT_ID", |
||||
|
ClientSecret: "CLIENT_SECRET", |
||||
|
RedirectURL: "REDIRECT_URL", |
||||
|
Scopes: []string{"scope1", "scope2"}, |
||||
|
Endpoint: Endpoint{ |
||||
|
AuthURL: url + "/auth", |
||||
|
TokenURL: url + "/token", |
||||
|
}, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestAuthCodeURL(t *testing.T) { |
||||
|
conf := newConf("server") |
||||
|
url := conf.AuthCodeURL("foo", AccessTypeOffline, ApprovalForce) |
||||
|
if url != "server/auth?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" { |
||||
|
t.Errorf("Auth code URL doesn't match the expected, found: %v", url) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestAuthCodeURL_CustomParam(t *testing.T) { |
||||
|
conf := newConf("server") |
||||
|
param := SetAuthURLParam("foo", "bar") |
||||
|
url := conf.AuthCodeURL("baz", param) |
||||
|
if url != "server/auth?client_id=CLIENT_ID&foo=bar&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=baz" { |
||||
|
t.Errorf("Auth code URL doesn't match the expected, found: %v", url) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestAuthCodeURL_Optional(t *testing.T) { |
||||
|
conf := &Config{ |
||||
|
ClientID: "CLIENT_ID", |
||||
|
Endpoint: Endpoint{ |
||||
|
AuthURL: "/auth-url", |
||||
|
TokenURL: "/token-url", |
||||
|
}, |
||||
|
} |
||||
|
url := conf.AuthCodeURL("") |
||||
|
if url != "/auth-url?client_id=CLIENT_ID&response_type=code" { |
||||
|
t.Fatalf("Auth code URL doesn't match the expected, found: %v", url) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestExchangeRequest(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if r.URL.String() != "/token" { |
||||
|
t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) |
||||
|
} |
||||
|
headerAuth := r.Header.Get("Authorization") |
||||
|
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { |
||||
|
t.Errorf("Unexpected authorization header, %v is found.", headerAuth) |
||||
|
} |
||||
|
headerContentType := r.Header.Get("Content-Type") |
||||
|
if headerContentType != "application/x-www-form-urlencoded" { |
||||
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
||||
|
} |
||||
|
body, err := ioutil.ReadAll(r.Body) |
||||
|
if err != nil { |
||||
|
t.Errorf("Failed reading request body: %s.", err) |
||||
|
} |
||||
|
if string(body) != "client_id=CLIENT_ID&code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL&scope=scope1+scope2" { |
||||
|
t.Errorf("Unexpected exchange payload, %v is found.", string(body)) |
||||
|
} |
||||
|
w.Header().Set("Content-Type", "application/x-www-form-urlencoded") |
||||
|
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
tok, err := conf.Exchange(NoContext, "exchange-code") |
||||
|
if err != nil { |
||||
|
t.Error(err) |
||||
|
} |
||||
|
if !tok.Valid() { |
||||
|
t.Fatalf("Token invalid. Got: %#v", tok) |
||||
|
} |
||||
|
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { |
||||
|
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) |
||||
|
} |
||||
|
if tok.TokenType != "bearer" { |
||||
|
t.Errorf("Unexpected token type, %#v.", tok.TokenType) |
||||
|
} |
||||
|
scope := tok.Extra("scope") |
||||
|
if scope != "user" { |
||||
|
t.Errorf("Unexpected value for scope: %v", scope) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestExchangeRequest_JSONResponse(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if r.URL.String() != "/token" { |
||||
|
t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) |
||||
|
} |
||||
|
headerAuth := r.Header.Get("Authorization") |
||||
|
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { |
||||
|
t.Errorf("Unexpected authorization header, %v is found.", headerAuth) |
||||
|
} |
||||
|
headerContentType := r.Header.Get("Content-Type") |
||||
|
if headerContentType != "application/x-www-form-urlencoded" { |
||||
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
||||
|
} |
||||
|
body, err := ioutil.ReadAll(r.Body) |
||||
|
if err != nil { |
||||
|
t.Errorf("Failed reading request body: %s.", err) |
||||
|
} |
||||
|
if string(body) != "client_id=CLIENT_ID&code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL&scope=scope1+scope2" { |
||||
|
t.Errorf("Unexpected exchange payload, %v is found.", string(body)) |
||||
|
} |
||||
|
w.Header().Set("Content-Type", "application/json") |
||||
|
w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`)) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
tok, err := conf.Exchange(NoContext, "exchange-code") |
||||
|
if err != nil { |
||||
|
t.Error(err) |
||||
|
} |
||||
|
if !tok.Valid() { |
||||
|
t.Fatalf("Token invalid. Got: %#v", tok) |
||||
|
} |
||||
|
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { |
||||
|
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) |
||||
|
} |
||||
|
if tok.TokenType != "bearer" { |
||||
|
t.Errorf("Unexpected token type, %#v.", tok.TokenType) |
||||
|
} |
||||
|
scope := tok.Extra("scope") |
||||
|
if scope != "user" { |
||||
|
t.Errorf("Unexpected value for scope: %v", scope) |
||||
|
} |
||||
|
expiresIn := tok.Extra("expires_in") |
||||
|
if expiresIn != float64(86400) { |
||||
|
t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestExtraValueRetrieval(t *testing.T) { |
||||
|
values := url.Values{} |
||||
|
|
||||
|
kvmap := map[string]string{ |
||||
|
"scope": "user", "token_type": "bearer", "expires_in": "86400.92", |
||||
|
"server_time": "1443571905.5606415", "referer_ip": "10.0.0.1", |
||||
|
"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400", |
||||
|
"untrimmed": " untrimmed ", |
||||
|
} |
||||
|
|
||||
|
for key, value := range kvmap { |
||||
|
values.Set(key, value) |
||||
|
} |
||||
|
|
||||
|
tok := Token{ |
||||
|
raw: values, |
||||
|
} |
||||
|
|
||||
|
scope := tok.Extra("scope") |
||||
|
if scope != "user" { |
||||
|
t.Errorf("Unexpected scope %v wanted \"user\"", scope) |
||||
|
} |
||||
|
serverTime := tok.Extra("server_time") |
||||
|
if serverTime != 1443571905.5606415 { |
||||
|
t.Errorf("Unexpected non-float64 value for server_time: %v", serverTime) |
||||
|
} |
||||
|
refererIp := tok.Extra("referer_ip") |
||||
|
if refererIp != "10.0.0.1" { |
||||
|
t.Errorf("Unexpected non-string value for referer_ip: %v", refererIp) |
||||
|
} |
||||
|
expires_in := tok.Extra("expires_in") |
||||
|
if expires_in != 86400.92 { |
||||
|
t.Errorf("Unexpected value for expires_in, wanted 86400 got %v", expires_in) |
||||
|
} |
||||
|
requestId := tok.Extra("request_id") |
||||
|
if requestId != int64(86400) { |
||||
|
t.Errorf("Unexpected non-int64 value for request_id: %v", requestId) |
||||
|
} |
||||
|
untrimmed := tok.Extra("untrimmed") |
||||
|
if untrimmed != " untrimmed " { |
||||
|
t.Errorf("Unexpected value for untrimmed, got %q expected \" untrimmed \"", untrimmed) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
const day = 24 * time.Hour |
||||
|
|
||||
|
func TestExchangeRequest_JSONResponse_Expiry(t *testing.T) { |
||||
|
seconds := int32(day.Seconds()) |
||||
|
jsonNumberType := reflect.TypeOf(json.Number("0")) |
||||
|
for _, c := range []struct { |
||||
|
expires string |
||||
|
expect error |
||||
|
}{ |
||||
|
{fmt.Sprintf(`"expires_in": %d`, seconds), nil}, |
||||
|
{fmt.Sprintf(`"expires_in": "%d"`, seconds), nil}, // PayPal case
|
||||
|
{fmt.Sprintf(`"expires": %d`, seconds), nil}, // Facebook case
|
||||
|
{`"expires": false`, &json.UnmarshalTypeError{Value: "bool", Type: jsonNumberType}}, // wrong type
|
||||
|
{`"expires": {}`, &json.UnmarshalTypeError{Value: "object", Type: jsonNumberType}}, // wrong type
|
||||
|
{`"expires": "zzz"`, &strconv.NumError{Func: "ParseInt", Num: "zzz", Err: strconv.ErrSyntax}}, // wrong value
|
||||
|
} { |
||||
|
testExchangeRequest_JSONResponse_expiry(t, c.expires, c.expect) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func testExchangeRequest_JSONResponse_expiry(t *testing.T, exp string, expect error) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
w.Header().Set("Content-Type", "application/json") |
||||
|
w.Write([]byte(fmt.Sprintf(`{"access_token": "90d", "scope": "user", "token_type": "bearer", %s}`, exp))) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
t1 := time.Now().Add(day) |
||||
|
tok, err := conf.Exchange(NoContext, "exchange-code") |
||||
|
t2 := time.Now().Add(day) |
||||
|
// Do a fmt.Sprint comparison so either side can be
|
||||
|
// nil. fmt.Sprint just stringifies them to "<nil>", and no
|
||||
|
// non-nil expected error ever stringifies as "<nil>", so this
|
||||
|
// isn't terribly disgusting. We do this because Go 1.4 and
|
||||
|
// Go 1.5 return a different deep value for
|
||||
|
// json.UnmarshalTypeError. In Go 1.5, the
|
||||
|
// json.UnmarshalTypeError contains a new field with a new
|
||||
|
// non-zero value. Rather than ignore it here with reflect or
|
||||
|
// add new files and +build tags, just look at the strings.
|
||||
|
if fmt.Sprint(err) != fmt.Sprint(expect) { |
||||
|
t.Errorf("Error = %v; want %v", err, expect) |
||||
|
} |
||||
|
if err != nil { |
||||
|
return |
||||
|
} |
||||
|
if !tok.Valid() { |
||||
|
t.Fatalf("Token invalid. Got: %#v", tok) |
||||
|
} |
||||
|
expiry := tok.Expiry |
||||
|
if expiry.Before(t1) || expiry.After(t2) { |
||||
|
t.Errorf("Unexpected value for Expiry: %v (shold be between %v and %v)", expiry, t1, t2) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestExchangeRequest_BadResponse(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
w.Header().Set("Content-Type", "application/json") |
||||
|
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
tok, err := conf.Exchange(NoContext, "code") |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
if tok.AccessToken != "" { |
||||
|
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestExchangeRequest_BadResponseType(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
w.Header().Set("Content-Type", "application/json") |
||||
|
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
_, err := conf.Exchange(NoContext, "exchange-code") |
||||
|
if err == nil { |
||||
|
t.Error("expected error from invalid access_token type") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestExchangeRequest_NonBasicAuth(t *testing.T) { |
||||
|
tr := &mockTransport{ |
||||
|
rt: func(r *http.Request) (w *http.Response, err error) { |
||||
|
headerAuth := r.Header.Get("Authorization") |
||||
|
if headerAuth != "" { |
||||
|
t.Errorf("Unexpected authorization header, %v is found.", headerAuth) |
||||
|
} |
||||
|
return nil, errors.New("no response") |
||||
|
}, |
||||
|
} |
||||
|
c := &http.Client{Transport: tr} |
||||
|
conf := &Config{ |
||||
|
ClientID: "CLIENT_ID", |
||||
|
Endpoint: Endpoint{ |
||||
|
AuthURL: "https://accounts.google.com/auth", |
||||
|
TokenURL: "https://accounts.google.com/token", |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
ctx := context.WithValue(context.Background(), HTTPClient, c) |
||||
|
conf.Exchange(ctx, "code") |
||||
|
} |
||||
|
|
||||
|
func TestPasswordCredentialsTokenRequest(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
defer r.Body.Close() |
||||
|
expected := "/token" |
||||
|
if r.URL.String() != expected { |
||||
|
t.Errorf("URL = %q; want %q", r.URL, expected) |
||||
|
} |
||||
|
headerAuth := r.Header.Get("Authorization") |
||||
|
expected = "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" |
||||
|
if headerAuth != expected { |
||||
|
t.Errorf("Authorization header = %q; want %q", headerAuth, expected) |
||||
|
} |
||||
|
headerContentType := r.Header.Get("Content-Type") |
||||
|
expected = "application/x-www-form-urlencoded" |
||||
|
if headerContentType != expected { |
||||
|
t.Errorf("Content-Type header = %q; want %q", headerContentType, expected) |
||||
|
} |
||||
|
body, err := ioutil.ReadAll(r.Body) |
||||
|
if err != nil { |
||||
|
t.Errorf("Failed reading request body: %s.", err) |
||||
|
} |
||||
|
expected = "client_id=CLIENT_ID&grant_type=password&password=password1&scope=scope1+scope2&username=user1" |
||||
|
if string(body) != expected { |
||||
|
t.Errorf("res.Body = %q; want %q", string(body), expected) |
||||
|
} |
||||
|
w.Header().Set("Content-Type", "application/x-www-form-urlencoded") |
||||
|
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
tok, err := conf.PasswordCredentialsToken(NoContext, "user1", "password1") |
||||
|
if err != nil { |
||||
|
t.Error(err) |
||||
|
} |
||||
|
if !tok.Valid() { |
||||
|
t.Fatalf("Token invalid. Got: %#v", tok) |
||||
|
} |
||||
|
expected := "90d64460d14870c08c81352a05dedd3465940a7c" |
||||
|
if tok.AccessToken != expected { |
||||
|
t.Errorf("AccessToken = %q; want %q", tok.AccessToken, expected) |
||||
|
} |
||||
|
expected = "bearer" |
||||
|
if tok.TokenType != expected { |
||||
|
t.Errorf("TokenType = %q; want %q", tok.TokenType, expected) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestTokenRefreshRequest(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if r.URL.String() == "/somethingelse" { |
||||
|
return |
||||
|
} |
||||
|
if r.URL.String() != "/token" { |
||||
|
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) |
||||
|
} |
||||
|
headerContentType := r.Header.Get("Content-Type") |
||||
|
if headerContentType != "application/x-www-form-urlencoded" { |
||||
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
||||
|
} |
||||
|
body, _ := ioutil.ReadAll(r.Body) |
||||
|
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { |
||||
|
t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) |
||||
|
} |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
c := conf.Client(NoContext, &Token{RefreshToken: "REFRESH_TOKEN"}) |
||||
|
c.Get(ts.URL + "/somethingelse") |
||||
|
} |
||||
|
|
||||
|
func TestFetchWithNoRefreshToken(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if r.URL.String() == "/somethingelse" { |
||||
|
return |
||||
|
} |
||||
|
if r.URL.String() != "/token" { |
||||
|
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) |
||||
|
} |
||||
|
headerContentType := r.Header.Get("Content-Type") |
||||
|
if headerContentType != "application/x-www-form-urlencoded" { |
||||
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
||||
|
} |
||||
|
body, _ := ioutil.ReadAll(r.Body) |
||||
|
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { |
||||
|
t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) |
||||
|
} |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
c := conf.Client(NoContext, nil) |
||||
|
_, err := c.Get(ts.URL + "/somethingelse") |
||||
|
if err == nil { |
||||
|
t.Errorf("Fetch should return an error if no refresh token is set") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestRefreshToken_RefreshTokenReplacement(t *testing.T) { |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
w.Header().Set("Content-Type", "application/json") |
||||
|
w.Write([]byte(`{"access_token":"ACCESS TOKEN", "scope": "user", "token_type": "bearer", "refresh_token": "NEW REFRESH TOKEN"}`)) |
||||
|
return |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
tkr := tokenRefresher{ |
||||
|
conf: conf, |
||||
|
ctx: NoContext, |
||||
|
refreshToken: "OLD REFRESH TOKEN", |
||||
|
} |
||||
|
tk, err := tkr.Token() |
||||
|
if err != nil { |
||||
|
t.Errorf("Unexpected refreshToken error returned: %v", err) |
||||
|
return |
||||
|
} |
||||
|
if tk.RefreshToken != tkr.refreshToken { |
||||
|
t.Errorf("tokenRefresher.refresh_token = %s; want %s", tkr.refreshToken, tk.RefreshToken) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestConfigClientWithToken(t *testing.T) { |
||||
|
tok := &Token{ |
||||
|
AccessToken: "abc123", |
||||
|
} |
||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if got, want := r.Header.Get("Authorization"), fmt.Sprintf("Bearer %s", tok.AccessToken); got != want { |
||||
|
t.Errorf("Authorization header = %q; want %q", got, want) |
||||
|
} |
||||
|
return |
||||
|
})) |
||||
|
defer ts.Close() |
||||
|
conf := newConf(ts.URL) |
||||
|
|
||||
|
c := conf.Client(NoContext, tok) |
||||
|
req, err := http.NewRequest("GET", ts.URL, nil) |
||||
|
if err != nil { |
||||
|
t.Error(err) |
||||
|
} |
||||
|
_, err = c.Do(req) |
||||
|
if err != nil { |
||||
|
t.Error(err) |
||||
|
} |
||||
|
} |
@ -0,0 +1,16 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package odnoklassniki provides constants for using OAuth2 to access Odnoklassniki.
|
||||
|
package odnoklassniki // import "golang.org/x/oauth2/odnoklassniki"
|
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is Odnoklassniki's OAuth 2.0 endpoint.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://www.odnoklassniki.ru/oauth/authorize", |
||||
|
TokenURL: "https://api.odnoklassniki.ru/oauth/token.do", |
||||
|
} |
@ -0,0 +1,22 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package paypal provides constants for using OAuth2 to access PayPal.
|
||||
|
package paypal // import "golang.org/x/oauth2/paypal"
|
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is PayPal's OAuth 2.0 endpoint in live (production) environment.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://www.paypal.com/webapps/auth/protocol/openidconnect/v1/authorize", |
||||
|
TokenURL: "https://api.paypal.com/v1/identity/openidconnect/tokenservice", |
||||
|
} |
||||
|
|
||||
|
// SandboxEndpoint is PayPal's OAuth 2.0 endpoint in sandbox (testing) environment.
|
||||
|
var SandboxEndpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://www.sandbox.paypal.com/webapps/auth/protocol/openidconnect/v1/authorize", |
||||
|
TokenURL: "https://api.sandbox.paypal.com/v1/identity/openidconnect/tokenservice", |
||||
|
} |
@ -0,0 +1,16 @@ |
|||||
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package slack provides constants for using OAuth2 to access Slack.
|
||||
|
package slack // import "golang.org/x/oauth2/slack"
|
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is Slack's OAuth 2.0 endpoint.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://slack.com/oauth/authorize", |
||||
|
TokenURL: "https://slack.com/api/oauth.access", |
||||
|
} |
@ -0,0 +1,158 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package oauth2 |
||||
|
|
||||
|
import ( |
||||
|
"net/http" |
||||
|
"net/url" |
||||
|
"strconv" |
||||
|
"strings" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/oauth2/internal" |
||||
|
) |
||||
|
|
||||
|
// expiryDelta determines how earlier a token should be considered
|
||||
|
// expired than its actual expiration time. It is used to avoid late
|
||||
|
// expirations due to client-server time mismatches.
|
||||
|
const expiryDelta = 10 * time.Second |
||||
|
|
||||
|
// Token represents the crendentials used to authorize
|
||||
|
// the requests to access protected resources on the OAuth 2.0
|
||||
|
// provider's backend.
|
||||
|
//
|
||||
|
// Most users of this package should not access fields of Token
|
||||
|
// directly. They're exported mostly for use by related packages
|
||||
|
// implementing derivative OAuth2 flows.
|
||||
|
type Token struct { |
||||
|
// AccessToken is the token that authorizes and authenticates
|
||||
|
// the requests.
|
||||
|
AccessToken string `json:"access_token"` |
||||
|
|
||||
|
// TokenType is the type of token.
|
||||
|
// The Type method returns either this or "Bearer", the default.
|
||||
|
TokenType string `json:"token_type,omitempty"` |
||||
|
|
||||
|
// RefreshToken is a token that's used by the application
|
||||
|
// (as opposed to the user) to refresh the access token
|
||||
|
// if it expires.
|
||||
|
RefreshToken string `json:"refresh_token,omitempty"` |
||||
|
|
||||
|
// Expiry is the optional expiration time of the access token.
|
||||
|
//
|
||||
|
// If zero, TokenSource implementations will reuse the same
|
||||
|
// token forever and RefreshToken or equivalent
|
||||
|
// mechanisms for that TokenSource will not be used.
|
||||
|
Expiry time.Time `json:"expiry,omitempty"` |
||||
|
|
||||
|
// raw optionally contains extra metadata from the server
|
||||
|
// when updating a token.
|
||||
|
raw interface{} |
||||
|
} |
||||
|
|
||||
|
// Type returns t.TokenType if non-empty, else "Bearer".
|
||||
|
func (t *Token) Type() string { |
||||
|
if strings.EqualFold(t.TokenType, "bearer") { |
||||
|
return "Bearer" |
||||
|
} |
||||
|
if strings.EqualFold(t.TokenType, "mac") { |
||||
|
return "MAC" |
||||
|
} |
||||
|
if strings.EqualFold(t.TokenType, "basic") { |
||||
|
return "Basic" |
||||
|
} |
||||
|
if t.TokenType != "" { |
||||
|
return t.TokenType |
||||
|
} |
||||
|
return "Bearer" |
||||
|
} |
||||
|
|
||||
|
// SetAuthHeader sets the Authorization header to r using the access
|
||||
|
// token in t.
|
||||
|
//
|
||||
|
// This method is unnecessary when using Transport or an HTTP Client
|
||||
|
// returned by this package.
|
||||
|
func (t *Token) SetAuthHeader(r *http.Request) { |
||||
|
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken) |
||||
|
} |
||||
|
|
||||
|
// WithExtra returns a new Token that's a clone of t, but using the
|
||||
|
// provided raw extra map. This is only intended for use by packages
|
||||
|
// implementing derivative OAuth2 flows.
|
||||
|
func (t *Token) WithExtra(extra interface{}) *Token { |
||||
|
t2 := new(Token) |
||||
|
*t2 = *t |
||||
|
t2.raw = extra |
||||
|
return t2 |
||||
|
} |
||||
|
|
||||
|
// Extra returns an extra field.
|
||||
|
// Extra fields are key-value pairs returned by the server as a
|
||||
|
// part of the token retrieval response.
|
||||
|
func (t *Token) Extra(key string) interface{} { |
||||
|
if raw, ok := t.raw.(map[string]interface{}); ok { |
||||
|
return raw[key] |
||||
|
} |
||||
|
|
||||
|
vals, ok := t.raw.(url.Values) |
||||
|
if !ok { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
v := vals.Get(key) |
||||
|
switch s := strings.TrimSpace(v); strings.Count(s, ".") { |
||||
|
case 0: // Contains no "."; try to parse as int
|
||||
|
if i, err := strconv.ParseInt(s, 10, 64); err == nil { |
||||
|
return i |
||||
|
} |
||||
|
case 1: // Contains a single "."; try to parse as float
|
||||
|
if f, err := strconv.ParseFloat(s, 64); err == nil { |
||||
|
return f |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return v |
||||
|
} |
||||
|
|
||||
|
// expired reports whether the token is expired.
|
||||
|
// t must be non-nil.
|
||||
|
func (t *Token) expired() bool { |
||||
|
if t.Expiry.IsZero() { |
||||
|
return false |
||||
|
} |
||||
|
return t.Expiry.Add(-expiryDelta).Before(time.Now()) |
||||
|
} |
||||
|
|
||||
|
// Valid reports whether t is non-nil, has an AccessToken, and is not expired.
|
||||
|
func (t *Token) Valid() bool { |
||||
|
return t != nil && t.AccessToken != "" && !t.expired() |
||||
|
} |
||||
|
|
||||
|
// tokenFromInternal maps an *internal.Token struct into
|
||||
|
// a *Token struct.
|
||||
|
func tokenFromInternal(t *internal.Token) *Token { |
||||
|
if t == nil { |
||||
|
return nil |
||||
|
} |
||||
|
return &Token{ |
||||
|
AccessToken: t.AccessToken, |
||||
|
TokenType: t.TokenType, |
||||
|
RefreshToken: t.RefreshToken, |
||||
|
Expiry: t.Expiry, |
||||
|
raw: t.Raw, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token.
|
||||
|
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
|
||||
|
// with an error..
|
||||
|
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { |
||||
|
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return tokenFromInternal(tk), nil |
||||
|
} |
@ -0,0 +1,72 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package oauth2 |
||||
|
|
||||
|
import ( |
||||
|
"testing" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
func TestTokenExtra(t *testing.T) { |
||||
|
type testCase struct { |
||||
|
key string |
||||
|
val interface{} |
||||
|
want interface{} |
||||
|
} |
||||
|
const key = "extra-key" |
||||
|
cases := []testCase{ |
||||
|
{key: key, val: "abc", want: "abc"}, |
||||
|
{key: key, val: 123, want: 123}, |
||||
|
{key: key, val: "", want: ""}, |
||||
|
{key: "other-key", val: "def", want: nil}, |
||||
|
} |
||||
|
for _, tc := range cases { |
||||
|
extra := make(map[string]interface{}) |
||||
|
extra[tc.key] = tc.val |
||||
|
tok := &Token{raw: extra} |
||||
|
if got, want := tok.Extra(key), tc.want; got != want { |
||||
|
t.Errorf("Extra(%q) = %q; want %q", key, got, want) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestTokenExpiry(t *testing.T) { |
||||
|
now := time.Now() |
||||
|
cases := []struct { |
||||
|
name string |
||||
|
tok *Token |
||||
|
want bool |
||||
|
}{ |
||||
|
{name: "12 seconds", tok: &Token{Expiry: now.Add(12 * time.Second)}, want: false}, |
||||
|
{name: "10 seconds", tok: &Token{Expiry: now.Add(expiryDelta)}, want: true}, |
||||
|
{name: "-1 hour", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, want: true}, |
||||
|
} |
||||
|
for _, tc := range cases { |
||||
|
if got, want := tc.tok.expired(), tc.want; got != want { |
||||
|
t.Errorf("expired (%q) = %v; want %v", tc.name, got, want) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestTokenTypeMethod(t *testing.T) { |
||||
|
cases := []struct { |
||||
|
name string |
||||
|
tok *Token |
||||
|
want string |
||||
|
}{ |
||||
|
{name: "bearer-mixed_case", tok: &Token{TokenType: "beAREr"}, want: "Bearer"}, |
||||
|
{name: "default-bearer", tok: &Token{}, want: "Bearer"}, |
||||
|
{name: "basic", tok: &Token{TokenType: "basic"}, want: "Basic"}, |
||||
|
{name: "basic-capitalized", tok: &Token{TokenType: "Basic"}, want: "Basic"}, |
||||
|
{name: "mac", tok: &Token{TokenType: "mac"}, want: "MAC"}, |
||||
|
{name: "mac-caps", tok: &Token{TokenType: "MAC"}, want: "MAC"}, |
||||
|
{name: "mac-mixed_case", tok: &Token{TokenType: "mAc"}, want: "MAC"}, |
||||
|
} |
||||
|
for _, tc := range cases { |
||||
|
if got, want := tc.tok.Type(), tc.want; got != want { |
||||
|
t.Errorf("TokenType(%q) = %v; want %v", tc.name, got, want) |
||||
|
} |
||||
|
} |
||||
|
} |
@ -0,0 +1,132 @@ |
|||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
package oauth2 |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"io" |
||||
|
"net/http" |
||||
|
"sync" |
||||
|
) |
||||
|
|
||||
|
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
|
||||
|
// wrapping a base RoundTripper and adding an Authorization header
|
||||
|
// with a token from the supplied Sources.
|
||||
|
//
|
||||
|
// Transport is a low-level mechanism. Most code will use the
|
||||
|
// higher-level Config.Client method instead.
|
||||
|
type Transport struct { |
||||
|
// Source supplies the token to add to outgoing requests'
|
||||
|
// Authorization headers.
|
||||
|
Source TokenSource |
||||
|
|
||||
|
// Base is the base RoundTripper used to make HTTP requests.
|
||||
|
// If nil, http.DefaultTransport is used.
|
||||
|
Base http.RoundTripper |
||||
|
|
||||
|
mu sync.Mutex // guards modReq
|
||||
|
modReq map[*http.Request]*http.Request // original -> modified
|
||||
|
} |
||||
|
|
||||
|
// RoundTrip authorizes and authenticates the request with an
|
||||
|
// access token. If no token exists or token is expired,
|
||||
|
// tries to refresh/fetch a new token.
|
||||
|
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { |
||||
|
if t.Source == nil { |
||||
|
return nil, errors.New("oauth2: Transport's Source is nil") |
||||
|
} |
||||
|
token, err := t.Source.Token() |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
req2 := cloneRequest(req) // per RoundTripper contract
|
||||
|
token.SetAuthHeader(req2) |
||||
|
t.setModReq(req, req2) |
||||
|
res, err := t.base().RoundTrip(req2) |
||||
|
if err != nil { |
||||
|
t.setModReq(req, nil) |
||||
|
return nil, err |
||||
|
} |
||||
|
res.Body = &onEOFReader{ |
||||
|
rc: res.Body, |
||||
|
fn: func() { t.setModReq(req, nil) }, |
||||
|
} |
||||
|
return res, nil |
||||
|
} |
||||
|
|
||||
|
// CancelRequest cancels an in-flight request by closing its connection.
|
||||
|
func (t *Transport) CancelRequest(req *http.Request) { |
||||
|
type canceler interface { |
||||
|
CancelRequest(*http.Request) |
||||
|
} |
||||
|
if cr, ok := t.base().(canceler); ok { |
||||
|
t.mu.Lock() |
||||
|
modReq := t.modReq[req] |
||||
|
delete(t.modReq, req) |
||||
|
t.mu.Unlock() |
||||
|
cr.CancelRequest(modReq) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (t *Transport) base() http.RoundTripper { |
||||
|
if t.Base != nil { |
||||
|
return t.Base |
||||
|
} |
||||
|
return http.DefaultTransport |
||||
|
} |
||||
|
|
||||
|
func (t *Transport) setModReq(orig, mod *http.Request) { |
||||
|
t.mu.Lock() |
||||
|
defer t.mu.Unlock() |
||||
|
if t.modReq == nil { |
||||
|
t.modReq = make(map[*http.Request]*http.Request) |
||||
|
} |
||||
|
if mod == nil { |
||||
|
delete(t.modReq, orig) |
||||
|
} else { |
||||
|
t.modReq[orig] = mod |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// cloneRequest returns a clone of the provided *http.Request.
|
||||
|
// The clone is a shallow copy of the struct and its Header map.
|
||||
|
func cloneRequest(r *http.Request) *http.Request { |
||||
|
// shallow copy of the struct
|
||||
|
r2 := new(http.Request) |
||||
|
*r2 = *r |
||||
|
// deep copy of the Header
|
||||
|
r2.Header = make(http.Header, len(r.Header)) |
||||
|
for k, s := range r.Header { |
||||
|
r2.Header[k] = append([]string(nil), s...) |
||||
|
} |
||||
|
return r2 |
||||
|
} |
||||
|
|
||||
|
type onEOFReader struct { |
||||
|
rc io.ReadCloser |
||||
|
fn func() |
||||
|
} |
||||
|
|
||||
|
func (r *onEOFReader) Read(p []byte) (n int, err error) { |
||||
|
n, err = r.rc.Read(p) |
||||
|
if err == io.EOF { |
||||
|
r.runFunc() |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
func (r *onEOFReader) Close() error { |
||||
|
err := r.rc.Close() |
||||
|
r.runFunc() |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
func (r *onEOFReader) runFunc() { |
||||
|
if fn := r.fn; fn != nil { |
||||
|
fn() |
||||
|
r.fn = nil |
||||
|
} |
||||
|
} |
@ -0,0 +1,108 @@ |
|||||
|
package oauth2 |
||||
|
|
||||
|
import ( |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"testing" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
type tokenSource struct{ token *Token } |
||||
|
|
||||
|
func (t *tokenSource) Token() (*Token, error) { |
||||
|
return t.token, nil |
||||
|
} |
||||
|
|
||||
|
func TestTransportNilTokenSource(t *testing.T) { |
||||
|
tr := &Transport{} |
||||
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {}) |
||||
|
defer server.Close() |
||||
|
client := &http.Client{Transport: tr} |
||||
|
res, err := client.Get(server.URL) |
||||
|
if err == nil { |
||||
|
t.Errorf("a nil Source was passed into the transport expected an error") |
||||
|
} |
||||
|
if res != nil { |
||||
|
t.Errorf("expected a nil response, got %v", res) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestTransportTokenSource(t *testing.T) { |
||||
|
ts := &tokenSource{ |
||||
|
token: &Token{ |
||||
|
AccessToken: "abc", |
||||
|
}, |
||||
|
} |
||||
|
tr := &Transport{ |
||||
|
Source: ts, |
||||
|
} |
||||
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if r.Header.Get("Authorization") != "Bearer abc" { |
||||
|
t.Errorf("Transport doesn't set the Authorization header from the fetched token") |
||||
|
} |
||||
|
}) |
||||
|
defer server.Close() |
||||
|
client := &http.Client{Transport: tr} |
||||
|
res, err := client.Get(server.URL) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
res.Body.Close() |
||||
|
} |
||||
|
|
||||
|
// Test for case-sensitive token types, per https://github.com/golang/oauth2/issues/113
|
||||
|
func TestTransportTokenSourceTypes(t *testing.T) { |
||||
|
const val = "abc" |
||||
|
tests := []struct { |
||||
|
key string |
||||
|
val string |
||||
|
want string |
||||
|
}{ |
||||
|
{key: "bearer", val: val, want: "Bearer abc"}, |
||||
|
{key: "mac", val: val, want: "MAC abc"}, |
||||
|
{key: "basic", val: val, want: "Basic abc"}, |
||||
|
} |
||||
|
for _, tc := range tests { |
||||
|
ts := &tokenSource{ |
||||
|
token: &Token{ |
||||
|
AccessToken: tc.val, |
||||
|
TokenType: tc.key, |
||||
|
}, |
||||
|
} |
||||
|
tr := &Transport{ |
||||
|
Source: ts, |
||||
|
} |
||||
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) { |
||||
|
if got, want := r.Header.Get("Authorization"), tc.want; got != want { |
||||
|
t.Errorf("Authorization header (%q) = %q; want %q", val, got, want) |
||||
|
} |
||||
|
}) |
||||
|
defer server.Close() |
||||
|
client := &http.Client{Transport: tr} |
||||
|
res, err := client.Get(server.URL) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
res.Body.Close() |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestTokenValidNoAccessToken(t *testing.T) { |
||||
|
token := &Token{} |
||||
|
if token.Valid() { |
||||
|
t.Errorf("Token should not be valid with no access token") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestExpiredWithExpiry(t *testing.T) { |
||||
|
token := &Token{ |
||||
|
Expiry: time.Now().Add(-5 * time.Hour), |
||||
|
} |
||||
|
if token.Valid() { |
||||
|
t.Errorf("Token should not be valid if it expired in the past") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server { |
||||
|
return httptest.NewServer(http.HandlerFunc(handler)) |
||||
|
} |
@ -0,0 +1,16 @@ |
|||||
|
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
|
// Use of this source code is governed by a BSD-style
|
||||
|
// license that can be found in the LICENSE file.
|
||||
|
|
||||
|
// Package vk provides constants for using OAuth2 to access VK.com.
|
||||
|
package vk // import "golang.org/x/oauth2/vk"
|
||||
|
|
||||
|
import ( |
||||
|
"golang.org/x/oauth2" |
||||
|
) |
||||
|
|
||||
|
// Endpoint is VK's OAuth 2.0 endpoint.
|
||||
|
var Endpoint = oauth2.Endpoint{ |
||||
|
AuthURL: "https://oauth.vk.com/authorize", |
||||
|
TokenURL: "https://oauth.vk.com/access_token", |
||||
|
} |
@ -0,0 +1,438 @@ |
|||||
|
// Copyright 2014 Google Inc. All Rights Reserved.
|
||||
|
//
|
||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
// you may not use this file except in compliance with the License.
|
||||
|
// You may obtain a copy of the License at
|
||||
|
//
|
||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
//
|
||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
// See the License for the specific language governing permissions and
|
||||
|
// limitations under the License.
|
||||
|
|
||||
|
// Package metadata provides access to Google Compute Engine (GCE)
|
||||
|
// metadata and API service accounts.
|
||||
|
//
|
||||
|
// This package is a wrapper around the GCE metadata service,
|
||||
|
// as documented at https://developers.google.com/compute/docs/metadata.
|
||||
|
package metadata // import "google.golang.org/cloud/compute/metadata"
|
||||
|
|
||||
|
import ( |
||||
|
"encoding/json" |
||||
|
"fmt" |
||||
|
"io/ioutil" |
||||
|
"net" |
||||
|
"net/http" |
||||
|
"net/url" |
||||
|
"os" |
||||
|
"runtime" |
||||
|
"strings" |
||||
|
"sync" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/net/context/ctxhttp" |
||||
|
|
||||
|
"google.golang.org/cloud/internal" |
||||
|
) |
||||
|
|
||||
|
const ( |
||||
|
// metadataIP is the documented metadata server IP address.
|
||||
|
metadataIP = "169.254.169.254" |
||||
|
|
||||
|
// metadataHostEnv is the environment variable specifying the
|
||||
|
// GCE metadata hostname. If empty, the default value of
|
||||
|
// metadataIP ("169.254.169.254") is used instead.
|
||||
|
// This is variable name is not defined by any spec, as far as
|
||||
|
// I know; it was made up for the Go package.
|
||||
|
metadataHostEnv = "GCE_METADATA_HOST" |
||||
|
) |
||||
|
|
||||
|
type cachedValue struct { |
||||
|
k string |
||||
|
trim bool |
||||
|
mu sync.Mutex |
||||
|
v string |
||||
|
} |
||||
|
|
||||
|
var ( |
||||
|
projID = &cachedValue{k: "project/project-id", trim: true} |
||||
|
projNum = &cachedValue{k: "project/numeric-project-id", trim: true} |
||||
|
instID = &cachedValue{k: "instance/id", trim: true} |
||||
|
) |
||||
|
|
||||
|
var ( |
||||
|
metaClient = &http.Client{ |
||||
|
Transport: &internal.Transport{ |
||||
|
Base: &http.Transport{ |
||||
|
Dial: (&net.Dialer{ |
||||
|
Timeout: 2 * time.Second, |
||||
|
KeepAlive: 30 * time.Second, |
||||
|
}).Dial, |
||||
|
ResponseHeaderTimeout: 2 * time.Second, |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
subscribeClient = &http.Client{ |
||||
|
Transport: &internal.Transport{ |
||||
|
Base: &http.Transport{ |
||||
|
Dial: (&net.Dialer{ |
||||
|
Timeout: 2 * time.Second, |
||||
|
KeepAlive: 30 * time.Second, |
||||
|
}).Dial, |
||||
|
}, |
||||
|
}, |
||||
|
} |
||||
|
) |
||||
|
|
||||
|
// NotDefinedError is returned when requested metadata is not defined.
|
||||
|
//
|
||||
|
// The underlying string is the suffix after "/computeMetadata/v1/".
|
||||
|
//
|
||||
|
// This error is not returned if the value is defined to be the empty
|
||||
|
// string.
|
||||
|
type NotDefinedError string |
||||
|
|
||||
|
func (suffix NotDefinedError) Error() string { |
||||
|
return fmt.Sprintf("metadata: GCE metadata %q not defined", string(suffix)) |
||||
|
} |
||||
|
|
||||
|
// Get returns a value from the metadata service.
|
||||
|
// The suffix is appended to "http://${GCE_METADATA_HOST}/computeMetadata/v1/".
|
||||
|
//
|
||||
|
// If the GCE_METADATA_HOST environment variable is not defined, a default of
|
||||
|
// 169.254.169.254 will be used instead.
|
||||
|
//
|
||||
|
// If the requested metadata is not defined, the returned error will
|
||||
|
// be of type NotDefinedError.
|
||||
|
func Get(suffix string) (string, error) { |
||||
|
val, _, err := getETag(metaClient, suffix) |
||||
|
return val, err |
||||
|
} |
||||
|
|
||||
|
// getETag returns a value from the metadata service as well as the associated
|
||||
|
// ETag using the provided client. This func is otherwise equivalent to Get.
|
||||
|
func getETag(client *http.Client, suffix string) (value, etag string, err error) { |
||||
|
// Using a fixed IP makes it very difficult to spoof the metadata service in
|
||||
|
// a container, which is an important use-case for local testing of cloud
|
||||
|
// deployments. To enable spoofing of the metadata service, the environment
|
||||
|
// variable GCE_METADATA_HOST is first inspected to decide where metadata
|
||||
|
// requests shall go.
|
||||
|
host := os.Getenv(metadataHostEnv) |
||||
|
if host == "" { |
||||
|
// Using 169.254.169.254 instead of "metadata" here because Go
|
||||
|
// binaries built with the "netgo" tag and without cgo won't
|
||||
|
// know the search suffix for "metadata" is
|
||||
|
// ".google.internal", and this IP address is documented as
|
||||
|
// being stable anyway.
|
||||
|
host = metadataIP |
||||
|
} |
||||
|
url := "http://" + host + "/computeMetadata/v1/" + suffix |
||||
|
req, _ := http.NewRequest("GET", url, nil) |
||||
|
req.Header.Set("Metadata-Flavor", "Google") |
||||
|
res, err := client.Do(req) |
||||
|
if err != nil { |
||||
|
return "", "", err |
||||
|
} |
||||
|
defer res.Body.Close() |
||||
|
if res.StatusCode == http.StatusNotFound { |
||||
|
return "", "", NotDefinedError(suffix) |
||||
|
} |
||||
|
if res.StatusCode != 200 { |
||||
|
return "", "", fmt.Errorf("status code %d trying to fetch %s", res.StatusCode, url) |
||||
|
} |
||||
|
all, err := ioutil.ReadAll(res.Body) |
||||
|
if err != nil { |
||||
|
return "", "", err |
||||
|
} |
||||
|
return string(all), res.Header.Get("Etag"), nil |
||||
|
} |
||||
|
|
||||
|
func getTrimmed(suffix string) (s string, err error) { |
||||
|
s, err = Get(suffix) |
||||
|
s = strings.TrimSpace(s) |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
func (c *cachedValue) get() (v string, err error) { |
||||
|
defer c.mu.Unlock() |
||||
|
c.mu.Lock() |
||||
|
if c.v != "" { |
||||
|
return c.v, nil |
||||
|
} |
||||
|
if c.trim { |
||||
|
v, err = getTrimmed(c.k) |
||||
|
} else { |
||||
|
v, err = Get(c.k) |
||||
|
} |
||||
|
if err == nil { |
||||
|
c.v = v |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
var ( |
||||
|
onGCEOnce sync.Once |
||||
|
onGCE bool |
||||
|
) |
||||
|
|
||||
|
// OnGCE reports whether this process is running on Google Compute Engine.
|
||||
|
func OnGCE() bool { |
||||
|
onGCEOnce.Do(initOnGCE) |
||||
|
return onGCE |
||||
|
} |
||||
|
|
||||
|
func initOnGCE() { |
||||
|
onGCE = testOnGCE() |
||||
|
} |
||||
|
|
||||
|
func testOnGCE() bool { |
||||
|
// The user explicitly said they're on GCE, so trust them.
|
||||
|
if os.Getenv(metadataHostEnv) != "" { |
||||
|
return true |
||||
|
} |
||||
|
|
||||
|
ctx, cancel := context.WithCancel(context.Background()) |
||||
|
defer cancel() |
||||
|
|
||||
|
resc := make(chan bool, 2) |
||||
|
|
||||
|
// Try two strategies in parallel.
|
||||
|
// See https://github.com/GoogleCloudPlatform/gcloud-golang/issues/194
|
||||
|
go func() { |
||||
|
res, err := ctxhttp.Get(ctx, metaClient, "http://"+metadataIP) |
||||
|
if err != nil { |
||||
|
resc <- false |
||||
|
return |
||||
|
} |
||||
|
defer res.Body.Close() |
||||
|
resc <- res.Header.Get("Metadata-Flavor") == "Google" |
||||
|
}() |
||||
|
|
||||
|
go func() { |
||||
|
addrs, err := net.LookupHost("metadata.google.internal") |
||||
|
if err != nil || len(addrs) == 0 { |
||||
|
resc <- false |
||||
|
return |
||||
|
} |
||||
|
resc <- strsContains(addrs, metadataIP) |
||||
|
}() |
||||
|
|
||||
|
tryHarder := systemInfoSuggestsGCE() |
||||
|
if tryHarder { |
||||
|
res := <-resc |
||||
|
if res { |
||||
|
// The first strategy succeeded, so let's use it.
|
||||
|
return true |
||||
|
} |
||||
|
// Wait for either the DNS or metadata server probe to
|
||||
|
// contradict the other one and say we are running on
|
||||
|
// GCE. Give it a lot of time to do so, since the system
|
||||
|
// info already suggests we're running on a GCE BIOS.
|
||||
|
timer := time.NewTimer(5 * time.Second) |
||||
|
defer timer.Stop() |
||||
|
select { |
||||
|
case res = <-resc: |
||||
|
return res |
||||
|
case <-timer.C: |
||||
|
// Too slow. Who knows what this system is.
|
||||
|
return false |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// There's no hint from the system info that we're running on
|
||||
|
// GCE, so use the first probe's result as truth, whether it's
|
||||
|
// true or false. The goal here is to optimize for speed for
|
||||
|
// users who are NOT running on GCE. We can't assume that
|
||||
|
// either a DNS lookup or an HTTP request to a blackholed IP
|
||||
|
// address is fast. Worst case this should return when the
|
||||
|
// metaClient's Transport.ResponseHeaderTimeout or
|
||||
|
// Transport.Dial.Timeout fires (in two seconds).
|
||||
|
return <-resc |
||||
|
} |
||||
|
|
||||
|
// systemInfoSuggestsGCE reports whether the local system (without
|
||||
|
// doing network requests) suggests that we're running on GCE. If this
|
||||
|
// returns true, testOnGCE tries a bit harder to reach its metadata
|
||||
|
// server.
|
||||
|
func systemInfoSuggestsGCE() bool { |
||||
|
if runtime.GOOS != "linux" { |
||||
|
// We don't have any non-Linux clues available, at least yet.
|
||||
|
return false |
||||
|
} |
||||
|
slurp, _ := ioutil.ReadFile("/sys/class/dmi/id/product_name") |
||||
|
name := strings.TrimSpace(string(slurp)) |
||||
|
return name == "Google" || name == "Google Compute Engine" |
||||
|
} |
||||
|
|
||||
|
// Subscribe subscribes to a value from the metadata service.
|
||||
|
// The suffix is appended to "http://${GCE_METADATA_HOST}/computeMetadata/v1/".
|
||||
|
// The suffix may contain query parameters.
|
||||
|
//
|
||||
|
// Subscribe calls fn with the latest metadata value indicated by the provided
|
||||
|
// suffix. If the metadata value is deleted, fn is called with the empty string
|
||||
|
// and ok false. Subscribe blocks until fn returns a non-nil error or the value
|
||||
|
// is deleted. Subscribe returns the error value returned from the last call to
|
||||
|
// fn, which may be nil when ok == false.
|
||||
|
func Subscribe(suffix string, fn func(v string, ok bool) error) error { |
||||
|
const failedSubscribeSleep = time.Second * 5 |
||||
|
|
||||
|
// First check to see if the metadata value exists at all.
|
||||
|
val, lastETag, err := getETag(subscribeClient, suffix) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
if err := fn(val, true); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
ok := true |
||||
|
if strings.ContainsRune(suffix, '?') { |
||||
|
suffix += "&wait_for_change=true&last_etag=" |
||||
|
} else { |
||||
|
suffix += "?wait_for_change=true&last_etag=" |
||||
|
} |
||||
|
for { |
||||
|
val, etag, err := getETag(subscribeClient, suffix+url.QueryEscape(lastETag)) |
||||
|
if err != nil { |
||||
|
if _, deleted := err.(NotDefinedError); !deleted { |
||||
|
time.Sleep(failedSubscribeSleep) |
||||
|
continue // Retry on other errors.
|
||||
|
} |
||||
|
ok = false |
||||
|
} |
||||
|
lastETag = etag |
||||
|
|
||||
|
if err := fn(val, ok); err != nil || !ok { |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// ProjectID returns the current instance's project ID string.
|
||||
|
func ProjectID() (string, error) { return projID.get() } |
||||
|
|
||||
|
// NumericProjectID returns the current instance's numeric project ID.
|
||||
|
func NumericProjectID() (string, error) { return projNum.get() } |
||||
|
|
||||
|
// InternalIP returns the instance's primary internal IP address.
|
||||
|
func InternalIP() (string, error) { |
||||
|
return getTrimmed("instance/network-interfaces/0/ip") |
||||
|
} |
||||
|
|
||||
|
// ExternalIP returns the instance's primary external (public) IP address.
|
||||
|
func ExternalIP() (string, error) { |
||||
|
return getTrimmed("instance/network-interfaces/0/access-configs/0/external-ip") |
||||
|
} |
||||
|
|
||||
|
// Hostname returns the instance's hostname. This will be of the form
|
||||
|
// "<instanceID>.c.<projID>.internal".
|
||||
|
func Hostname() (string, error) { |
||||
|
return getTrimmed("instance/hostname") |
||||
|
} |
||||
|
|
||||
|
// InstanceTags returns the list of user-defined instance tags,
|
||||
|
// assigned when initially creating a GCE instance.
|
||||
|
func InstanceTags() ([]string, error) { |
||||
|
var s []string |
||||
|
j, err := Get("instance/tags") |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
if err := json.NewDecoder(strings.NewReader(j)).Decode(&s); err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return s, nil |
||||
|
} |
||||
|
|
||||
|
// InstanceID returns the current VM's numeric instance ID.
|
||||
|
func InstanceID() (string, error) { |
||||
|
return instID.get() |
||||
|
} |
||||
|
|
||||
|
// InstanceName returns the current VM's instance ID string.
|
||||
|
func InstanceName() (string, error) { |
||||
|
host, err := Hostname() |
||||
|
if err != nil { |
||||
|
return "", err |
||||
|
} |
||||
|
return strings.Split(host, ".")[0], nil |
||||
|
} |
||||
|
|
||||
|
// Zone returns the current VM's zone, such as "us-central1-b".
|
||||
|
func Zone() (string, error) { |
||||
|
zone, err := getTrimmed("instance/zone") |
||||
|
// zone is of the form "projects/<projNum>/zones/<zoneName>".
|
||||
|
if err != nil { |
||||
|
return "", err |
||||
|
} |
||||
|
return zone[strings.LastIndex(zone, "/")+1:], nil |
||||
|
} |
||||
|
|
||||
|
// InstanceAttributes returns the list of user-defined attributes,
|
||||
|
// assigned when initially creating a GCE VM instance. The value of an
|
||||
|
// attribute can be obtained with InstanceAttributeValue.
|
||||
|
func InstanceAttributes() ([]string, error) { return lines("instance/attributes/") } |
||||
|
|
||||
|
// ProjectAttributes returns the list of user-defined attributes
|
||||
|
// applying to the project as a whole, not just this VM. The value of
|
||||
|
// an attribute can be obtained with ProjectAttributeValue.
|
||||
|
func ProjectAttributes() ([]string, error) { return lines("project/attributes/") } |
||||
|
|
||||
|
func lines(suffix string) ([]string, error) { |
||||
|
j, err := Get(suffix) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
s := strings.Split(strings.TrimSpace(j), "\n") |
||||
|
for i := range s { |
||||
|
s[i] = strings.TrimSpace(s[i]) |
||||
|
} |
||||
|
return s, nil |
||||
|
} |
||||
|
|
||||
|
// InstanceAttributeValue returns the value of the provided VM
|
||||
|
// instance attribute.
|
||||
|
//
|
||||
|
// If the requested attribute is not defined, the returned error will
|
||||
|
// be of type NotDefinedError.
|
||||
|
//
|
||||
|
// InstanceAttributeValue may return ("", nil) if the attribute was
|
||||
|
// defined to be the empty string.
|
||||
|
func InstanceAttributeValue(attr string) (string, error) { |
||||
|
return Get("instance/attributes/" + attr) |
||||
|
} |
||||
|
|
||||
|
// ProjectAttributeValue returns the value of the provided
|
||||
|
// project attribute.
|
||||
|
//
|
||||
|
// If the requested attribute is not defined, the returned error will
|
||||
|
// be of type NotDefinedError.
|
||||
|
//
|
||||
|
// ProjectAttributeValue may return ("", nil) if the attribute was
|
||||
|
// defined to be the empty string.
|
||||
|
func ProjectAttributeValue(attr string) (string, error) { |
||||
|
return Get("project/attributes/" + attr) |
||||
|
} |
||||
|
|
||||
|
// Scopes returns the service account scopes for the given account.
|
||||
|
// The account may be empty or the string "default" to use the instance's
|
||||
|
// main account.
|
||||
|
func Scopes(serviceAccount string) ([]string, error) { |
||||
|
if serviceAccount == "" { |
||||
|
serviceAccount = "default" |
||||
|
} |
||||
|
return lines("instance/service-accounts/" + serviceAccount + "/scopes") |
||||
|
} |
||||
|
|
||||
|
func strsContains(ss []string, s string) bool { |
||||
|
for _, v := range ss { |
||||
|
if v == s { |
||||
|
return true |
||||
|
} |
||||
|
} |
||||
|
return false |
||||
|
} |
@ -0,0 +1,48 @@ |
|||||
|
// Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
//
|
||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
// you may not use this file except in compliance with the License.
|
||||
|
// You may obtain a copy of the License at
|
||||
|
//
|
||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
//
|
||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
// See the License for the specific language governing permissions and
|
||||
|
// limitations under the License.
|
||||
|
|
||||
|
package metadata |
||||
|
|
||||
|
import ( |
||||
|
"os" |
||||
|
"sync" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestOnGCE_Stress(t *testing.T) { |
||||
|
if testing.Short() { |
||||
|
t.Skip("skipping in -short mode") |
||||
|
} |
||||
|
var last bool |
||||
|
for i := 0; i < 100; i++ { |
||||
|
onGCEOnce = sync.Once{} |
||||
|
|
||||
|
now := OnGCE() |
||||
|
if i > 0 && now != last { |
||||
|
t.Errorf("%d. changed from %v to %v", i, last, now) |
||||
|
} |
||||
|
last = now |
||||
|
} |
||||
|
t.Logf("OnGCE() = %v", last) |
||||
|
} |
||||
|
|
||||
|
func TestOnGCE_Force(t *testing.T) { |
||||
|
onGCEOnce = sync.Once{} |
||||
|
old := os.Getenv(metadataHostEnv) |
||||
|
defer os.Setenv(metadataHostEnv, old) |
||||
|
os.Setenv(metadataHostEnv, "127.0.0.1") |
||||
|
if !OnGCE() { |
||||
|
t.Error("OnGCE() = false; want true") |
||||
|
} |
||||
|
} |
@ -0,0 +1,257 @@ |
|||||
|
// Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
//
|
||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
// you may not use this file except in compliance with the License.
|
||||
|
// You may obtain a copy of the License at
|
||||
|
//
|
||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
//
|
||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
// See the License for the specific language governing permissions and
|
||||
|
// limitations under the License.
|
||||
|
|
||||
|
// Package bundler supports bundling (batching) of items. Bundling amortizes an
|
||||
|
// action with fixed costs over multiple items. For example, if an API provides
|
||||
|
// an RPC that accepts a list of items as input, but clients would prefer
|
||||
|
// adding items one at a time, then a Bundler can accept individual items from
|
||||
|
// the client and bundle many of them into a single RPC.
|
||||
|
package bundler |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"reflect" |
||||
|
"sync" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
const ( |
||||
|
DefaultDelayThreshold = time.Second |
||||
|
DefaultBundleCountThreshold = 10 |
||||
|
DefaultBundleByteThreshold = 1e6 // 1M
|
||||
|
DefaultBufferedByteLimit = 1e9 // 1G
|
||||
|
) |
||||
|
|
||||
|
var ( |
||||
|
// ErrOverflow indicates that Bundler's stored bytes exceeds its BufferedByteLimit.
|
||||
|
ErrOverflow = errors.New("bundler reached buffered byte limit") |
||||
|
|
||||
|
// ErrOversizedItem indicates that an item's size exceeds the maximum bundle size.
|
||||
|
ErrOversizedItem = errors.New("item size exceeds bundle byte limit") |
||||
|
) |
||||
|
|
||||
|
// A Bundler collects items added to it into a bundle until the bundle
|
||||
|
// exceeds a given size, then calls a user-provided function to handle the bundle.
|
||||
|
type Bundler struct { |
||||
|
// Starting from the time that the first message is added to a bundle, once
|
||||
|
// this delay has passed, handle the bundle. The default is DefaultDelayThreshold.
|
||||
|
DelayThreshold time.Duration |
||||
|
|
||||
|
// Once a bundle has this many items, handle the bundle. Since only one
|
||||
|
// item at a time is added to a bundle, no bundle will exceed this
|
||||
|
// threshold, so it also serves as a limit. The default is
|
||||
|
// DefaultBundleCountThreshold.
|
||||
|
BundleCountThreshold int |
||||
|
|
||||
|
// Once the number of bytes in current bundle reaches this threshold, handle
|
||||
|
// the bundle. The default is DefaultBundleByteThreshold. This triggers handling,
|
||||
|
// but does not cap the total size of a bundle.
|
||||
|
BundleByteThreshold int |
||||
|
|
||||
|
// The maximum size of a bundle, in bytes. Zero means unlimited.
|
||||
|
BundleByteLimit int |
||||
|
|
||||
|
// The maximum number of bytes that the Bundler will keep in memory before
|
||||
|
// returning ErrOverflow. The default is DefaultBufferedByteLimit.
|
||||
|
BufferedByteLimit int |
||||
|
|
||||
|
handler func(interface{}) // called to handle a bundle
|
||||
|
itemSliceZero reflect.Value // nil (zero value) for slice of items
|
||||
|
donec chan struct{} // closed when the Bundler is closed
|
||||
|
handlec chan int // sent to when a bundle is ready for handling
|
||||
|
timer *time.Timer // implements DelayThreshold
|
||||
|
|
||||
|
mu sync.Mutex |
||||
|
bufferedSize int // total bytes buffered
|
||||
|
closedBundles []bundle // bundles waiting to be handled
|
||||
|
curBundle bundle // incoming items added to this bundle
|
||||
|
calledc chan struct{} // closed and re-created after handler is called
|
||||
|
} |
||||
|
|
||||
|
type bundle struct { |
||||
|
items reflect.Value // slice of item type
|
||||
|
size int // size in bytes of all items
|
||||
|
} |
||||
|
|
||||
|
// NewBundler creates a new Bundler. When you are finished with a Bundler, call
|
||||
|
// its Close method.
|
||||
|
//
|
||||
|
// itemExample is a value of the type that will be bundled. For example, if you
|
||||
|
// want to create bundles of *Entry, you could pass &Entry{} for itemExample.
|
||||
|
//
|
||||
|
// handler is a function that will be called on each bundle. If itemExample is
|
||||
|
// of type T, the the argument to handler is of type []T.
|
||||
|
func NewBundler(itemExample interface{}, handler func(interface{})) *Bundler { |
||||
|
b := &Bundler{ |
||||
|
DelayThreshold: DefaultDelayThreshold, |
||||
|
BundleCountThreshold: DefaultBundleCountThreshold, |
||||
|
BundleByteThreshold: DefaultBundleByteThreshold, |
||||
|
BufferedByteLimit: DefaultBufferedByteLimit, |
||||
|
|
||||
|
handler: handler, |
||||
|
itemSliceZero: reflect.Zero(reflect.SliceOf(reflect.TypeOf(itemExample))), |
||||
|
donec: make(chan struct{}), |
||||
|
handlec: make(chan int, 1), |
||||
|
calledc: make(chan struct{}), |
||||
|
timer: time.NewTimer(1000 * time.Hour), // harmless initial timeout
|
||||
|
} |
||||
|
b.curBundle.items = b.itemSliceZero |
||||
|
go b.background() |
||||
|
return b |
||||
|
} |
||||
|
|
||||
|
// Add adds item to the current bundle. It marks the bundle for handling and
|
||||
|
// starts a new one if any of the thresholds or limits are exceeded.
|
||||
|
//
|
||||
|
// If the item's size exceeds the maximum bundle size (Bundler.BundleByteLimit), then
|
||||
|
// the item can never be handled. Add returns ErrOversizedItem in this case.
|
||||
|
//
|
||||
|
// If adding the item would exceed the maximum memory allowed (Bundler.BufferedByteLimit),
|
||||
|
// Add returns ErrOverflow.
|
||||
|
//
|
||||
|
// Add never blocks.
|
||||
|
func (b *Bundler) Add(item interface{}, size int) error { |
||||
|
// If this item exceeds the maximum size of a bundle,
|
||||
|
// we can never send it.
|
||||
|
if b.BundleByteLimit > 0 && size > b.BundleByteLimit { |
||||
|
return ErrOversizedItem |
||||
|
} |
||||
|
b.mu.Lock() |
||||
|
defer b.mu.Unlock() |
||||
|
// If adding this item would exceed our allotted memory
|
||||
|
// footprint, we can't accept it.
|
||||
|
if b.bufferedSize+size > b.BufferedByteLimit { |
||||
|
return ErrOverflow |
||||
|
} |
||||
|
// If adding this item to the current bundle would cause it to exceed the
|
||||
|
// maximum bundle size, close the current bundle and start a new one.
|
||||
|
if b.BundleByteLimit > 0 && b.curBundle.size+size > b.BundleByteLimit { |
||||
|
b.closeAndHandleBundle() |
||||
|
} |
||||
|
// Add the item.
|
||||
|
b.curBundle.items = reflect.Append(b.curBundle.items, reflect.ValueOf(item)) |
||||
|
b.curBundle.size += size |
||||
|
b.bufferedSize += size |
||||
|
// If this is the first item in the bundle, restart the timer.
|
||||
|
if b.curBundle.items.Len() == 1 { |
||||
|
b.timer.Reset(b.DelayThreshold) |
||||
|
} |
||||
|
// If the current bundle equals the count threshold, close it.
|
||||
|
if b.curBundle.items.Len() == b.BundleCountThreshold { |
||||
|
b.closeAndHandleBundle() |
||||
|
} |
||||
|
// If the current bundle equals or exceeds the byte threshold, close it.
|
||||
|
if b.curBundle.size >= b.BundleByteThreshold { |
||||
|
b.closeAndHandleBundle() |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Flush waits until all items in the Bundler have been handled.
|
||||
|
func (b *Bundler) Flush() { |
||||
|
b.mu.Lock() |
||||
|
b.closeBundle() |
||||
|
// Unconditionally trigger the handling goroutine, to ensure calledc is closed
|
||||
|
// even if there are no outstanding bundles.
|
||||
|
select { |
||||
|
case b.handlec <- 1: |
||||
|
default: |
||||
|
} |
||||
|
calledc := b.calledc // remember locally, because it may change
|
||||
|
b.mu.Unlock() |
||||
|
<-calledc |
||||
|
} |
||||
|
|
||||
|
// Close calls Flush, then shuts down the Bundler. Close should always be
|
||||
|
// called on a Bundler when it is no longer needed. You must wait for all calls
|
||||
|
// to Add to complete before calling Close. Calling Add concurrently with Close
|
||||
|
// may result in the added items being ignored.
|
||||
|
func (b *Bundler) Close() { |
||||
|
b.Flush() |
||||
|
b.mu.Lock() |
||||
|
b.timer.Stop() |
||||
|
b.mu.Unlock() |
||||
|
close(b.donec) |
||||
|
} |
||||
|
|
||||
|
func (b *Bundler) closeAndHandleBundle() { |
||||
|
if b.closeBundle() { |
||||
|
// We have created a closed bundle.
|
||||
|
// Send to handlec without blocking.
|
||||
|
select { |
||||
|
case b.handlec <- 1: |
||||
|
default: |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// closeBundle finishes the current bundle, adds it to the list of closed
|
||||
|
// bundles and informs the background goroutine that there are bundles ready
|
||||
|
// for processing.
|
||||
|
//
|
||||
|
// This should always be called with b.mu held.
|
||||
|
func (b *Bundler) closeBundle() bool { |
||||
|
if b.curBundle.items.Len() == 0 { |
||||
|
return false |
||||
|
} |
||||
|
b.closedBundles = append(b.closedBundles, b.curBundle) |
||||
|
b.curBundle.items = b.itemSliceZero |
||||
|
b.curBundle.size = 0 |
||||
|
return true |
||||
|
} |
||||
|
|
||||
|
// background runs in a separate goroutine, waiting for events and handling
|
||||
|
// bundles.
|
||||
|
func (b *Bundler) background() { |
||||
|
done := false |
||||
|
for { |
||||
|
timedOut := false |
||||
|
// Wait for something to happen.
|
||||
|
select { |
||||
|
case <-b.handlec: |
||||
|
case <-b.donec: |
||||
|
done = true |
||||
|
case <-b.timer.C: |
||||
|
timedOut = true |
||||
|
} |
||||
|
// Handle closed bundles.
|
||||
|
b.mu.Lock() |
||||
|
if timedOut { |
||||
|
b.closeBundle() |
||||
|
} |
||||
|
buns := b.closedBundles |
||||
|
b.closedBundles = nil |
||||
|
// Closing calledc means we've sent all bundles. We need
|
||||
|
// a new channel for the next set of bundles, which may start
|
||||
|
// accumulating as soon as we release the lock.
|
||||
|
calledc := b.calledc |
||||
|
b.calledc = make(chan struct{}) |
||||
|
b.mu.Unlock() |
||||
|
for i, bun := range buns { |
||||
|
b.handler(bun.items.Interface()) |
||||
|
// Drop the bundle's items, reducing our memory footprint.
|
||||
|
buns[i].items = reflect.Value{} // buns[i] because bun is a copy
|
||||
|
// Note immediately that we have more space, so Adds that occur
|
||||
|
// during this loop will have a chance of succeeding.
|
||||
|
b.mu.Lock() |
||||
|
b.bufferedSize -= bun.size |
||||
|
b.mu.Unlock() |
||||
|
} |
||||
|
// Signal that we've sent all outstanding bundles.
|
||||
|
close(calledc) |
||||
|
if done { |
||||
|
break |
||||
|
} |
||||
|
} |
||||
|
} |
@ -0,0 +1,209 @@ |
|||||
|
// Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
//
|
||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
// you may not use this file except in compliance with the License.
|
||||
|
// You may obtain a copy of the License at
|
||||
|
//
|
||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
//
|
||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
// See the License for the specific language governing permissions and
|
||||
|
// limitations under the License.
|
||||
|
|
||||
|
package bundler |
||||
|
|
||||
|
import ( |
||||
|
"reflect" |
||||
|
"testing" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
func TestBundlerCount1(t *testing.T) { |
||||
|
// Unbundled case: one item per bundle.
|
||||
|
handler := &testHandler{} |
||||
|
b := NewBundler(int(0), handler.handle) |
||||
|
b.BundleCountThreshold = 1 |
||||
|
b.DelayThreshold = time.Second |
||||
|
|
||||
|
for i := 0; i < 3; i++ { |
||||
|
if err := b.Add(i, 1); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
} |
||||
|
b.Close() |
||||
|
got := handler.bundles |
||||
|
want := [][]int{{0}, {1}, {2}} |
||||
|
if !reflect.DeepEqual(got, want) { |
||||
|
t.Errorf("bundles: got %v, want %v", got, want) |
||||
|
} |
||||
|
// All bundles should have been handled "immediately": much less
|
||||
|
// than the delay threshold of 1s.
|
||||
|
tgot := quantizeTimes(handler.times, 100*time.Millisecond) |
||||
|
twant := []int{0, 0, 0} |
||||
|
if !reflect.DeepEqual(tgot, twant) { |
||||
|
t.Errorf("times: got %v, want %v", tgot, twant) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestBundlerCount3(t *testing.T) { |
||||
|
handler := &testHandler{} |
||||
|
b := NewBundler(int(0), handler.handle) |
||||
|
b.BundleCountThreshold = 3 |
||||
|
b.DelayThreshold = 1 * time.Millisecond |
||||
|
// Add 8 items.
|
||||
|
// The first two bundles of 3 should both be handled quickly.
|
||||
|
// The third bundle of 2 should not be handled for about 1 ms.
|
||||
|
for i := 0; i < 8; i++ { |
||||
|
if err := b.Add(i, 1); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
} |
||||
|
time.Sleep(5 * b.DelayThreshold) |
||||
|
// We should not need to close the bundler.
|
||||
|
|
||||
|
bgot := handler.bundles |
||||
|
bwant := [][]int{{0, 1, 2}, {3, 4, 5}, {6, 7}} |
||||
|
if !reflect.DeepEqual(bgot, bwant) { |
||||
|
t.Errorf("bundles: got %v, want %v", bgot, bwant) |
||||
|
} |
||||
|
|
||||
|
tgot := quantizeTimes(handler.times, b.DelayThreshold) |
||||
|
twant := []int{0, 0, 1} |
||||
|
if !reflect.DeepEqual(tgot, twant) { |
||||
|
t.Errorf("times: got %v, want %v", tgot, twant) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestBundlerByteThreshold(t *testing.T) { |
||||
|
handler := &testHandler{} |
||||
|
b := NewBundler(int(0), handler.handle) |
||||
|
b.BundleCountThreshold = 10 |
||||
|
b.BundleByteThreshold = 3 |
||||
|
add := func(i interface{}, s int) { |
||||
|
if err := b.Add(i, s); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
add(1, 1) |
||||
|
add(2, 2) |
||||
|
// Hit byte threshold: bundle = 1, 2
|
||||
|
add(3, 1) |
||||
|
add(4, 1) |
||||
|
add(5, 2) |
||||
|
// Passed byte threshold, but not limit: bundle = 3, 4, 5
|
||||
|
add(6, 1) |
||||
|
b.Close() |
||||
|
bgot := handler.bundles |
||||
|
bwant := [][]int{{1, 2}, {3, 4, 5}, {6}} |
||||
|
if !reflect.DeepEqual(bgot, bwant) { |
||||
|
t.Errorf("bundles: got %v, want %v", bgot, bwant) |
||||
|
} |
||||
|
tgot := quantizeTimes(handler.times, b.DelayThreshold) |
||||
|
twant := []int{0, 0, 0} |
||||
|
if !reflect.DeepEqual(tgot, twant) { |
||||
|
t.Errorf("times: got %v, want %v", tgot, twant) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestBundlerLimit(t *testing.T) { |
||||
|
handler := &testHandler{} |
||||
|
b := NewBundler(int(0), handler.handle) |
||||
|
b.BundleCountThreshold = 10 |
||||
|
b.BundleByteLimit = 3 |
||||
|
add := func(i interface{}, s int) { |
||||
|
if err := b.Add(i, s); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
add(1, 1) |
||||
|
add(2, 2) |
||||
|
// Hit byte limit: bundle = 1, 2
|
||||
|
add(3, 1) |
||||
|
add(4, 1) |
||||
|
add(5, 2) |
||||
|
// Exceeded byte limit: bundle = 3, 4
|
||||
|
add(6, 2) |
||||
|
// Exceeded byte limit: bundle = 5
|
||||
|
b.Close() |
||||
|
bgot := handler.bundles |
||||
|
bwant := [][]int{{1, 2}, {3, 4}, {5}, {6}} |
||||
|
if !reflect.DeepEqual(bgot, bwant) { |
||||
|
t.Errorf("bundles: got %v, want %v", bgot, bwant) |
||||
|
} |
||||
|
tgot := quantizeTimes(handler.times, b.DelayThreshold) |
||||
|
twant := []int{0, 0, 0, 0} |
||||
|
if !reflect.DeepEqual(tgot, twant) { |
||||
|
t.Errorf("times: got %v, want %v", tgot, twant) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestBundlerErrors(t *testing.T) { |
||||
|
// Use a handler that blocks forever, to force the bundler to run out of
|
||||
|
// memory.
|
||||
|
b := NewBundler(int(0), func(interface{}) { select {} }) |
||||
|
b.BundleByteLimit = 3 |
||||
|
b.BufferedByteLimit = 10 |
||||
|
|
||||
|
if got, want := b.Add(1, 4), ErrOversizedItem; got != want { |
||||
|
t.Fatalf("got %v, want %v", got, want) |
||||
|
} |
||||
|
|
||||
|
for i := 0; i < 5; i++ { |
||||
|
if err := b.Add(i, 2); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
} |
||||
|
if got, want := b.Add(5, 1), ErrOverflow; got != want { |
||||
|
t.Fatalf("got %v, want %v", got, want) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
type testHandler struct { |
||||
|
bundles [][]int |
||||
|
times []time.Time |
||||
|
} |
||||
|
|
||||
|
func (t *testHandler) handle(b interface{}) { |
||||
|
t.bundles = append(t.bundles, b.([]int)) |
||||
|
t.times = append(t.times, time.Now()) |
||||
|
} |
||||
|
|
||||
|
// Round times to the nearest q and express them as the number of q
|
||||
|
// since the first time.
|
||||
|
// E.g. if q is 100ms, then a time within 50ms of the first time
|
||||
|
// will be represented as 0, a time 150 to 250ms of the first time
|
||||
|
// we be represented as 1, etc.
|
||||
|
func quantizeTimes(times []time.Time, q time.Duration) []int { |
||||
|
var rs []int |
||||
|
for _, t := range times { |
||||
|
d := t.Sub(times[0]) |
||||
|
r := int((d + q/2) / q) |
||||
|
rs = append(rs, r) |
||||
|
} |
||||
|
return rs |
||||
|
} |
||||
|
|
||||
|
func TestQuantizeTimes(t *testing.T) { |
||||
|
quantum := 100 * time.Millisecond |
||||
|
for _, test := range []struct { |
||||
|
millis []int // times in milliseconds
|
||||
|
want []int |
||||
|
}{ |
||||
|
{[]int{10, 20, 30}, []int{0, 0, 0}}, |
||||
|
{[]int{0, 49, 50, 90}, []int{0, 0, 1, 1}}, |
||||
|
{[]int{0, 95, 170, 315}, []int{0, 1, 2, 3}}, |
||||
|
} { |
||||
|
var times []time.Time |
||||
|
for _, ms := range test.millis { |
||||
|
times = append(times, time.Unix(0, int64(ms*1e6))) |
||||
|
} |
||||
|
got := quantizeTimes(times, quantum) |
||||
|
if !reflect.DeepEqual(got, test.want) { |
||||
|
t.Errorf("%v: got %v, want %v", test.millis, got, test.want) |
||||
|
} |
||||
|
} |
||||
|
} |
@ -0,0 +1,128 @@ |
|||||
|
// Copyright 2014 Google Inc. All Rights Reserved.
|
||||
|
//
|
||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
// you may not use this file except in compliance with the License.
|
||||
|
// You may obtain a copy of the License at
|
||||
|
//
|
||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
//
|
||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
// See the License for the specific language governing permissions and
|
||||
|
// limitations under the License.
|
||||
|
|
||||
|
// Package internal provides support for the cloud packages.
|
||||
|
//
|
||||
|
// Users should not import this package directly.
|
||||
|
package internal |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"net/http" |
||||
|
"sync" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
) |
||||
|
|
||||
|
type contextKey struct{} |
||||
|
|
||||
|
func WithContext(parent context.Context, projID string, c *http.Client) context.Context { |
||||
|
if c == nil { |
||||
|
panic("nil *http.Client passed to WithContext") |
||||
|
} |
||||
|
if projID == "" { |
||||
|
panic("empty project ID passed to WithContext") |
||||
|
} |
||||
|
return context.WithValue(parent, contextKey{}, &cloudContext{ |
||||
|
ProjectID: projID, |
||||
|
HTTPClient: c, |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
const userAgent = "gcloud-golang/0.1" |
||||
|
|
||||
|
type cloudContext struct { |
||||
|
ProjectID string |
||||
|
HTTPClient *http.Client |
||||
|
|
||||
|
mu sync.Mutex // guards svc
|
||||
|
svc map[string]interface{} // e.g. "storage" => *rawStorage.Service
|
||||
|
} |
||||
|
|
||||
|
// Service returns the result of the fill function if it's never been
|
||||
|
// called before for the given name (which is assumed to be an API
|
||||
|
// service name, like "datastore"). If it has already been cached, the fill
|
||||
|
// func is not run.
|
||||
|
// It's safe for concurrent use by multiple goroutines.
|
||||
|
func Service(ctx context.Context, name string, fill func(*http.Client) interface{}) interface{} { |
||||
|
return cc(ctx).service(name, fill) |
||||
|
} |
||||
|
|
||||
|
func (c *cloudContext) service(name string, fill func(*http.Client) interface{}) interface{} { |
||||
|
c.mu.Lock() |
||||
|
defer c.mu.Unlock() |
||||
|
|
||||
|
if c.svc == nil { |
||||
|
c.svc = make(map[string]interface{}) |
||||
|
} else if v, ok := c.svc[name]; ok { |
||||
|
return v |
||||
|
} |
||||
|
v := fill(c.HTTPClient) |
||||
|
c.svc[name] = v |
||||
|
return v |
||||
|
} |
||||
|
|
||||
|
// Transport is an http.RoundTripper that appends
|
||||
|
// Google Cloud client's user-agent to the original
|
||||
|
// request's user-agent header.
|
||||
|
type Transport struct { |
||||
|
// Base is the actual http.RoundTripper
|
||||
|
// requests will use. It must not be nil.
|
||||
|
Base http.RoundTripper |
||||
|
} |
||||
|
|
||||
|
// RoundTrip appends a user-agent to the existing user-agent
|
||||
|
// header and delegates the request to the base http.RoundTripper.
|
||||
|
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { |
||||
|
req = cloneRequest(req) |
||||
|
ua := req.Header.Get("User-Agent") |
||||
|
if ua == "" { |
||||
|
ua = userAgent |
||||
|
} else { |
||||
|
ua = fmt.Sprintf("%s %s", ua, userAgent) |
||||
|
} |
||||
|
req.Header.Set("User-Agent", ua) |
||||
|
return t.Base.RoundTrip(req) |
||||
|
} |
||||
|
|
||||
|
// cloneRequest returns a clone of the provided *http.Request.
|
||||
|
// The clone is a shallow copy of the struct and its Header map.
|
||||
|
func cloneRequest(r *http.Request) *http.Request { |
||||
|
// shallow copy of the struct
|
||||
|
r2 := new(http.Request) |
||||
|
*r2 = *r |
||||
|
// deep copy of the Header
|
||||
|
r2.Header = make(http.Header) |
||||
|
for k, s := range r.Header { |
||||
|
r2.Header[k] = s |
||||
|
} |
||||
|
return r2 |
||||
|
} |
||||
|
|
||||
|
func ProjID(ctx context.Context) string { |
||||
|
return cc(ctx).ProjectID |
||||
|
} |
||||
|
|
||||
|
func HTTPClient(ctx context.Context) *http.Client { |
||||
|
return cc(ctx).HTTPClient |
||||
|
} |
||||
|
|
||||
|
// cc returns the internal *cloudContext (cc) state for a context.Context.
|
||||
|
// It panics if the user did it wrong.
|
||||
|
func cc(ctx context.Context) *cloudContext { |
||||
|
if c, ok := ctx.Value(contextKey{}).(*cloudContext); ok { |
||||
|
return c |
||||
|
} |
||||
|
panic("invalid context.Context type; it should be created with cloud.NewContext") |
||||
|
} |
@ -0,0 +1,60 @@ |
|||||
|
// Copyright 2014 Google Inc. All Rights Reserved.
|
||||
|
//
|
||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
// you may not use this file except in compliance with the License.
|
||||
|
// You may obtain a copy of the License at
|
||||
|
//
|
||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
//
|
||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
// See the License for the specific language governing permissions and
|
||||
|
// limitations under the License.
|
||||
|
|
||||
|
// Package testutil contains helper functions for writing tests.
|
||||
|
package testutil |
||||
|
|
||||
|
import ( |
||||
|
"io/ioutil" |
||||
|
"log" |
||||
|
"os" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"golang.org/x/oauth2" |
||||
|
"golang.org/x/oauth2/google" |
||||
|
) |
||||
|
|
||||
|
const ( |
||||
|
envProjID = "GCLOUD_TESTS_GOLANG_PROJECT_ID" |
||||
|
envPrivateKey = "GCLOUD_TESTS_GOLANG_KEY" |
||||
|
) |
||||
|
|
||||
|
// ProjID returns the project ID to use in integration tests, or the empty
|
||||
|
// string if none is configured.
|
||||
|
func ProjID() string { |
||||
|
projID := os.Getenv(envProjID) |
||||
|
if projID == "" { |
||||
|
return "" |
||||
|
} |
||||
|
return projID |
||||
|
} |
||||
|
|
||||
|
// TokenSource returns the OAuth2 token source to use in integration tests,
|
||||
|
// or nil if none is configured. TokenSource will log.Fatal if the token
|
||||
|
// source is specified but missing or invalid.
|
||||
|
func TokenSource(ctx context.Context, scopes ...string) oauth2.TokenSource { |
||||
|
key := os.Getenv(envPrivateKey) |
||||
|
if key == "" { |
||||
|
return nil |
||||
|
} |
||||
|
jsonKey, err := ioutil.ReadFile(key) |
||||
|
if err != nil { |
||||
|
log.Fatalf("Cannot read the JSON key file, err: %v", err) |
||||
|
} |
||||
|
conf, err := google.JWTConfigFromJSON(jsonKey, scopes...) |
||||
|
if err != nil { |
||||
|
log.Fatalf("google.JWTConfigFromJSON: %v", err) |
||||
|
} |
||||
|
return conf.TokenSource(ctx) |
||||
|
} |
@ -0,0 +1,186 @@ |
|||||
|
/* |
||||
|
Copyright 2016 Google Inc. All Rights Reserved. |
||||
|
|
||||
|
Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
you may not use this file except in compliance with the License. |
||||
|
You may obtain a copy of the License at |
||||
|
|
||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
|
||||
|
Unless required by applicable law or agreed to in writing, software |
||||
|
distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
See the License for the specific language governing permissions and |
||||
|
limitations under the License. |
||||
|
*/ |
||||
|
|
||||
|
package testutil |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"reflect" |
||||
|
) |
||||
|
|
||||
|
// TestIteratorNext tests the Next method of a standard client iterator (see
|
||||
|
// https://github.com/GoogleCloudPlatform/gcloud-golang/wiki/Iterator-Guidelines).
|
||||
|
//
|
||||
|
// This function assumes that an iterator has already been created, and that
|
||||
|
// the underlying sequence to be iterated over already exists. want should be a
|
||||
|
// slice that contains the elements of this sequence. It must contain at least
|
||||
|
// one element.
|
||||
|
//
|
||||
|
// done is the special error value that signals that the iterator has provided all the items.
|
||||
|
//
|
||||
|
// next is a function that returns the result of calling Next on the iterator. It can usually be
|
||||
|
// defined as
|
||||
|
// func() (interface{}, error) { return iter.Next() }
|
||||
|
//
|
||||
|
// TestIteratorNext checks that the iterator returns all the elements of want
|
||||
|
// in order, followed by (zero, done). It also confirms that subsequent calls
|
||||
|
// to next also return (zero, done).
|
||||
|
//
|
||||
|
// On success, TestIteratorNext returns ("", true). On failure, it returns a
|
||||
|
// suitable error message and false.
|
||||
|
func TestIteratorNext(want interface{}, done error, next func() (interface{}, error)) (string, bool) { |
||||
|
wVal := reflect.ValueOf(want) |
||||
|
if wVal.Kind() != reflect.Slice { |
||||
|
return "'want' must be a slice", false |
||||
|
} |
||||
|
for i := 0; i < wVal.Len(); i++ { |
||||
|
got, err := next() |
||||
|
if err != nil { |
||||
|
return fmt.Sprintf("#%d: got %v, expected an item", i, err), false |
||||
|
} |
||||
|
w := wVal.Index(i).Interface() |
||||
|
if !reflect.DeepEqual(got, w) { |
||||
|
return fmt.Sprintf("#%d: got %+v, want %+v", i, got, w), false |
||||
|
} |
||||
|
} |
||||
|
// We now should see (<zero value of item type>, done), no matter how many
|
||||
|
// additional calls we make.
|
||||
|
zero := reflect.Zero(wVal.Type().Elem()).Interface() |
||||
|
for i := 0; i < 3; i++ { |
||||
|
got, err := next() |
||||
|
if err != done { |
||||
|
return fmt.Sprintf("at end: got error %v, want done", err), false |
||||
|
} |
||||
|
// Since err == done, got should be zero.
|
||||
|
if got != zero { |
||||
|
return fmt.Sprintf("got %+v with done, want zero %T", got, zero), false |
||||
|
} |
||||
|
} |
||||
|
return "", true |
||||
|
} |
||||
|
|
||||
|
// PagingIterator describes the standard client iterator pattern with paging as best as possible in Go.
|
||||
|
// See https://github.com/GoogleCloudPlatform/gcloud-golang/wiki/Iterator-Guidelines.
|
||||
|
type PagingIterator interface { |
||||
|
SetPageSize(int32) |
||||
|
SetPageToken(string) |
||||
|
NextPageToken() string |
||||
|
// NextPage() ([]T, error)
|
||||
|
} |
||||
|
|
||||
|
// TestIteratorNextPageExact tests the NextPage method of a standard client
|
||||
|
// iterator with paging (see PagingIterator).
|
||||
|
//
|
||||
|
// This function assumes that the underlying sequence to be iterated over
|
||||
|
// already exists. want should be a slice that contains the elements of this
|
||||
|
// sequence. It must contain at least three elements, in order to test
|
||||
|
// non-trivial paging behavior.
|
||||
|
//
|
||||
|
// done is the special error value that signals that the iterator has provided all the items.
|
||||
|
//
|
||||
|
// defaultPageSize is the page size to use when the user has not called SetPageSize, or calls
|
||||
|
// it with a value <= 0.
|
||||
|
//
|
||||
|
// newIter should return a new iterator each time it is called.
|
||||
|
//
|
||||
|
// nextPage should return the result of calling NextPage on the iterator. It can usually be
|
||||
|
// defined as
|
||||
|
// func(i testutil.PagingIterator) (interface{}, error) { return i.(*<iteratorType>).NextPage() }
|
||||
|
//
|
||||
|
// TestIteratorNextPageExact checks that the iterator returns all the elements
|
||||
|
// of want in order, divided into pages of the exactly the right size. It
|
||||
|
// confirms that if the last page is partial, done is returned along with it,
|
||||
|
// and in any case, done is returned subsequently along with a zero-length
|
||||
|
// slice.
|
||||
|
//
|
||||
|
// On success, TestIteratorNextPageExact returns ("", true). On failure, it returns a
|
||||
|
// suitable error message and false.
|
||||
|
func TestIteratorNextPageExact(want interface{}, done error, defaultPageSize int32, newIter func() PagingIterator, nextPage func(PagingIterator) (interface{}, error)) (string, bool) { |
||||
|
wVal := reflect.ValueOf(want) |
||||
|
if wVal.Kind() != reflect.Slice { |
||||
|
return "'want' must be a slice", false |
||||
|
} |
||||
|
if wVal.Len() < 3 { |
||||
|
return "need at least 3 values for 'want' to effectively test paging", false |
||||
|
} |
||||
|
const doNotSetPageSize = -999 |
||||
|
for _, pageSize := range []int32{doNotSetPageSize, -7, 0, 1, 2, int32(wVal.Len()), int32(wVal.Len()) + 10} { |
||||
|
adjustedPageSize := int(pageSize) |
||||
|
if pageSize <= 0 { |
||||
|
adjustedPageSize = int(defaultPageSize) |
||||
|
} |
||||
|
// Create the pages we expect to see.
|
||||
|
var wantPages []interface{} |
||||
|
for i, j := 0, adjustedPageSize; i < wVal.Len(); i, j = j, j+adjustedPageSize { |
||||
|
if j > wVal.Len() { |
||||
|
j = wVal.Len() |
||||
|
} |
||||
|
wantPages = append(wantPages, wVal.Slice(i, j).Interface()) |
||||
|
} |
||||
|
for _, usePageToken := range []bool{false, true} { |
||||
|
it := newIter() |
||||
|
if pageSize != doNotSetPageSize { |
||||
|
it.SetPageSize(pageSize) |
||||
|
} |
||||
|
for i, wantPage := range wantPages { |
||||
|
gotPage, err := nextPage(it) |
||||
|
if err != nil && err != done { |
||||
|
return fmt.Sprintf("usePageToken %v, pageSize %d, #%d: got %v, expected a page", |
||||
|
usePageToken, pageSize, i, err), false |
||||
|
} |
||||
|
if !reflect.DeepEqual(gotPage, wantPage) { |
||||
|
return fmt.Sprintf("usePageToken %v, pageSize %d, #%d:\ngot %v\nwant %+v", |
||||
|
usePageToken, pageSize, i, gotPage, wantPage), false |
||||
|
} |
||||
|
// If the last page is partial, NextPage must return done.
|
||||
|
if reflect.ValueOf(gotPage).Len() < adjustedPageSize && err != done { |
||||
|
return fmt.Sprintf("usePageToken %v, pageSize %d, #%d: expected done on partial page, got %v", |
||||
|
usePageToken, pageSize, i, err), false |
||||
|
} |
||||
|
if usePageToken { |
||||
|
// Pretend that we are displaying a paginated listing on the web, and the next
|
||||
|
// page may be served by a different process.
|
||||
|
// Empty page token implies done, and vice versa.
|
||||
|
if (it.NextPageToken() == "") != (err == done) { |
||||
|
return fmt.Sprintf("pageSize %d: next page token = %q and err = %v; expected empty page token iff done", |
||||
|
pageSize, it.NextPageToken(), err), false |
||||
|
} |
||||
|
if err == nil { |
||||
|
token := it.NextPageToken() |
||||
|
it = newIter() |
||||
|
it.SetPageSize(pageSize) |
||||
|
it.SetPageToken(token) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
// We now should see (<zero-length or nil slice>, done), no matter how many
|
||||
|
// additional calls we make.
|
||||
|
for i := 0; i < 3; i++ { |
||||
|
gotPage, err := nextPage(it) |
||||
|
if err != done { |
||||
|
return fmt.Sprintf("usePageToken %v, pageSize %d, at end: got error %v, want done", |
||||
|
usePageToken, pageSize, err), false |
||||
|
} |
||||
|
pVal := reflect.ValueOf(gotPage) |
||||
|
if pVal.Kind() != reflect.Slice || pVal.Len() != 0 { |
||||
|
return fmt.Sprintf("usePageToken %v, pageSize %d, at end: got %+v with done, want zero-length slice", |
||||
|
usePageToken, pageSize, gotPage), false |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
return "", true |
||||
|
} |
@ -0,0 +1,73 @@ |
|||||
|
/* |
||||
|
Copyright 2016 Google Inc. All Rights Reserved. |
||||
|
|
||||
|
Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
you may not use this file except in compliance with the License. |
||||
|
You may obtain a copy of the License at |
||||
|
|
||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
|
||||
|
Unless required by applicable law or agreed to in writing, software |
||||
|
distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
See the License for the specific language governing permissions and |
||||
|
limitations under the License. |
||||
|
*/ |
||||
|
|
||||
|
package testutil |
||||
|
|
||||
|
import ( |
||||
|
"net" |
||||
|
|
||||
|
grpc "google.golang.org/grpc" |
||||
|
) |
||||
|
|
||||
|
// A Server is an in-process gRPC server, listening on a system-chosen port on
|
||||
|
// the local loopback interface. Servers are for testing only and are not
|
||||
|
// intended to be used in production code.
|
||||
|
//
|
||||
|
// To create a server, make a new Server, register your handlers, then call
|
||||
|
// Start:
|
||||
|
//
|
||||
|
// srv, err := NewServer()
|
||||
|
// ...
|
||||
|
// mypb.RegisterMyServiceServer(srv.Gsrv, &myHandler)
|
||||
|
// ....
|
||||
|
// srv.Start()
|
||||
|
//
|
||||
|
// Clients should connect to the server with no security:
|
||||
|
//
|
||||
|
// conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
|
||||
|
// ...
|
||||
|
type Server struct { |
||||
|
Addr string |
||||
|
l net.Listener |
||||
|
Gsrv *grpc.Server |
||||
|
} |
||||
|
|
||||
|
// NewServer creates a new Server. The Server will be listening for gRPC connections
|
||||
|
// at the address named by the Addr field, without TLS.
|
||||
|
func NewServer() (*Server, error) { |
||||
|
l, err := net.Listen("tcp", "127.0.0.1:0") |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
s := &Server{ |
||||
|
Addr: l.Addr().String(), |
||||
|
l: l, |
||||
|
Gsrv: grpc.NewServer(), |
||||
|
} |
||||
|
return s, nil |
||||
|
} |
||||
|
|
||||
|
// Start causes the server to start accepting incoming connections.
|
||||
|
// Call Start after registering handlers.
|
||||
|
func (s *Server) Start() { |
||||
|
go s.Gsrv.Serve(s.l) |
||||
|
} |
||||
|
|
||||
|
// Close shuts down the server.
|
||||
|
func (s *Server) Close() { |
||||
|
s.Gsrv.Stop() |
||||
|
s.l.Close() |
||||
|
} |
@ -0,0 +1,35 @@ |
|||||
|
// Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
//
|
||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
// you may not use this file except in compliance with the License.
|
||||
|
// You may obtain a copy of the License at
|
||||
|
//
|
||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
//
|
||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
// See the License for the specific language governing permissions and
|
||||
|
// limitations under the License.
|
||||
|
|
||||
|
package testutil |
||||
|
|
||||
|
import ( |
||||
|
"testing" |
||||
|
|
||||
|
grpc "google.golang.org/grpc" |
||||
|
) |
||||
|
|
||||
|
func TestNewServer(t *testing.T) { |
||||
|
srv, err := NewServer() |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
srv.Start() |
||||
|
conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure()) |
||||
|
if err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
conn.Close() |
||||
|
srv.Close() |
||||
|
} |
@ -0,0 +1,61 @@ |
|||||
|
// Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
//
|
||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
// you may not use this file except in compliance with the License.
|
||||
|
// You may obtain a copy of the License at
|
||||
|
//
|
||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
//
|
||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
// See the License for the specific language governing permissions and
|
||||
|
// limitations under the License.
|
||||
|
|
||||
|
package transport |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"net/http" |
||||
|
|
||||
|
"golang.org/x/net/context" |
||||
|
"google.golang.org/api/option" |
||||
|
"google.golang.org/api/transport" |
||||
|
"google.golang.org/cloud" |
||||
|
"google.golang.org/grpc" |
||||
|
) |
||||
|
|
||||
|
// ErrHTTP is returned when on a non-200 HTTP response.
|
||||
|
type ErrHTTP struct { |
||||
|
StatusCode int |
||||
|
Body []byte |
||||
|
err error |
||||
|
} |
||||
|
|
||||
|
func (e *ErrHTTP) Error() string { |
||||
|
if e.err == nil { |
||||
|
return fmt.Sprintf("error during call, http status code: %v %s", e.StatusCode, e.Body) |
||||
|
} |
||||
|
return e.err.Error() |
||||
|
} |
||||
|
|
||||
|
// NewHTTPClient returns an HTTP client for use communicating with a Google cloud
|
||||
|
// service, configured with the given ClientOptions. It also returns the endpoint
|
||||
|
// for the service as specified in the options.
|
||||
|
func NewHTTPClient(ctx context.Context, opt ...cloud.ClientOption) (*http.Client, string, error) { |
||||
|
o := make([]option.ClientOption, 0, len(opt)) |
||||
|
for _, opt := range opt { |
||||
|
o = append(o, opt.Resolve()) |
||||
|
} |
||||
|
return transport.NewHTTPClient(ctx, o...) |
||||
|
} |
||||
|
|
||||
|
// DialGRPC returns a GRPC connection for use communicating with a Google cloud
|
||||
|
// service, configured with the given ClientOptions.
|
||||
|
func DialGRPC(ctx context.Context, opt ...cloud.ClientOption) (*grpc.ClientConn, error) { |
||||
|
o := make([]option.ClientOption, 0, len(opt)) |
||||
|
for _, opt := range opt { |
||||
|
o = append(o, opt.Resolve()) |
||||
|
} |
||||
|
return transport.DialGRPC(ctx, o...) |
||||
|
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue