mirror of https://github.com/matrix-org/go-neb.git
Kegan Dougal
9 years ago
71 changed files with 7198 additions and 0 deletions
-
27vendor/manifest
-
156vendor/src/golang.org/x/net/context/context.go
-
577vendor/src/golang.org/x/net/context/context_test.go
-
74vendor/src/golang.org/x/net/context/ctxhttp/ctxhttp.go
-
28vendor/src/golang.org/x/net/context/ctxhttp/ctxhttp_17_test.go
-
147vendor/src/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go
-
79vendor/src/golang.org/x/net/context/ctxhttp/ctxhttp_pre17_test.go
-
105vendor/src/golang.org/x/net/context/ctxhttp/ctxhttp_test.go
-
72vendor/src/golang.org/x/net/context/go17.go
-
300vendor/src/golang.org/x/net/context/pre_go17.go
-
26vendor/src/golang.org/x/net/context/withtimeout_test.go
-
3vendor/src/golang.org/x/oauth2/AUTHORS
-
31vendor/src/golang.org/x/oauth2/CONTRIBUTING.md
-
3vendor/src/golang.org/x/oauth2/CONTRIBUTORS
-
27vendor/src/golang.org/x/oauth2/LICENSE
-
64vendor/src/golang.org/x/oauth2/README.md
-
16vendor/src/golang.org/x/oauth2/bitbucket/bitbucket.go
-
25vendor/src/golang.org/x/oauth2/client_appengine.go
-
112vendor/src/golang.org/x/oauth2/clientcredentials/clientcredentials.go
-
96vendor/src/golang.org/x/oauth2/clientcredentials/clientcredentials_test.go
-
45vendor/src/golang.org/x/oauth2/example_test.go
-
16vendor/src/golang.org/x/oauth2/facebook/facebook.go
-
16vendor/src/golang.org/x/oauth2/fitbit/fitbit.go
-
16vendor/src/golang.org/x/oauth2/github/github.go
-
86vendor/src/golang.org/x/oauth2/google/appengine.go
-
13vendor/src/golang.org/x/oauth2/google/appengine_hook.go
-
14vendor/src/golang.org/x/oauth2/google/appenginevm_hook.go
-
155vendor/src/golang.org/x/oauth2/google/default.go
-
150vendor/src/golang.org/x/oauth2/google/example_test.go
-
149vendor/src/golang.org/x/oauth2/google/google.go
-
97vendor/src/golang.org/x/oauth2/google/google_test.go
-
74vendor/src/golang.org/x/oauth2/google/jwt.go
-
91vendor/src/golang.org/x/oauth2/google/jwt_test.go
-
168vendor/src/golang.org/x/oauth2/google/sdk.go
-
46vendor/src/golang.org/x/oauth2/google/sdk_test.go
-
122vendor/src/golang.org/x/oauth2/google/testdata/gcloud/credentials
-
2vendor/src/golang.org/x/oauth2/google/testdata/gcloud/properties
-
60vendor/src/golang.org/x/oauth2/hipchat/hipchat.go
-
76vendor/src/golang.org/x/oauth2/internal/oauth2.go
-
62vendor/src/golang.org/x/oauth2/internal/oauth2_test.go
-
225vendor/src/golang.org/x/oauth2/internal/token.go
-
36vendor/src/golang.org/x/oauth2/internal/token_test.go
-
69vendor/src/golang.org/x/oauth2/internal/transport.go
-
38vendor/src/golang.org/x/oauth2/internal/transport_test.go
-
174vendor/src/golang.org/x/oauth2/jws/jws.go
-
46vendor/src/golang.org/x/oauth2/jws/jws_test.go
-
31vendor/src/golang.org/x/oauth2/jwt/example_test.go
-
157vendor/src/golang.org/x/oauth2/jwt/jwt.go
-
134vendor/src/golang.org/x/oauth2/jwt/jwt_test.go
-
16vendor/src/golang.org/x/oauth2/linkedin/linkedin.go
-
16vendor/src/golang.org/x/oauth2/microsoft/microsoft.go
-
337vendor/src/golang.org/x/oauth2/oauth2.go
-
458vendor/src/golang.org/x/oauth2/oauth2_test.go
-
16vendor/src/golang.org/x/oauth2/odnoklassniki/odnoklassniki.go
-
22vendor/src/golang.org/x/oauth2/paypal/paypal.go
-
16vendor/src/golang.org/x/oauth2/slack/slack.go
-
158vendor/src/golang.org/x/oauth2/token.go
-
72vendor/src/golang.org/x/oauth2/token_test.go
-
132vendor/src/golang.org/x/oauth2/transport.go
-
108vendor/src/golang.org/x/oauth2/transport_test.go
-
16vendor/src/golang.org/x/oauth2/vk/vk.go
-
438vendor/src/google.golang.org/cloud/compute/metadata/metadata.go
-
48vendor/src/google.golang.org/cloud/compute/metadata/metadata_test.go
-
257vendor/src/google.golang.org/cloud/internal/bundler/bundler.go
-
209vendor/src/google.golang.org/cloud/internal/bundler/bundler_test.go
-
128vendor/src/google.golang.org/cloud/internal/cloud.go
-
60vendor/src/google.golang.org/cloud/internal/testutil/context.go
-
186vendor/src/google.golang.org/cloud/internal/testutil/iterators.go
-
73vendor/src/google.golang.org/cloud/internal/testutil/server.go
-
35vendor/src/google.golang.org/cloud/internal/testutil/server_test.go
-
61vendor/src/google.golang.org/cloud/internal/transport/dial.go
@ -0,0 +1,156 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package context defines the Context type, which carries deadlines,
|
|||
// cancelation signals, and other request-scoped values across API boundaries
|
|||
// and between processes.
|
|||
//
|
|||
// Incoming requests to a server should create a Context, and outgoing calls to
|
|||
// servers should accept a Context. The chain of function calls between must
|
|||
// propagate the Context, optionally replacing it with a modified copy created
|
|||
// using WithDeadline, WithTimeout, WithCancel, or WithValue.
|
|||
//
|
|||
// Programs that use Contexts should follow these rules to keep interfaces
|
|||
// consistent across packages and enable static analysis tools to check context
|
|||
// propagation:
|
|||
//
|
|||
// Do not store Contexts inside a struct type; instead, pass a Context
|
|||
// explicitly to each function that needs it. The Context should be the first
|
|||
// parameter, typically named ctx:
|
|||
//
|
|||
// func DoSomething(ctx context.Context, arg Arg) error {
|
|||
// // ... use ctx ...
|
|||
// }
|
|||
//
|
|||
// Do not pass a nil Context, even if a function permits it. Pass context.TODO
|
|||
// if you are unsure about which Context to use.
|
|||
//
|
|||
// Use context Values only for request-scoped data that transits processes and
|
|||
// APIs, not for passing optional parameters to functions.
|
|||
//
|
|||
// The same Context may be passed to functions running in different goroutines;
|
|||
// Contexts are safe for simultaneous use by multiple goroutines.
|
|||
//
|
|||
// See http://blog.golang.org/context for example code for a server that uses
|
|||
// Contexts.
|
|||
package context // import "golang.org/x/net/context"
|
|||
|
|||
import "time" |
|||
|
|||
// A Context carries a deadline, a cancelation signal, and other values across
|
|||
// API boundaries.
|
|||
//
|
|||
// Context's methods may be called by multiple goroutines simultaneously.
|
|||
type Context interface { |
|||
// Deadline returns the time when work done on behalf of this context
|
|||
// should be canceled. Deadline returns ok==false when no deadline is
|
|||
// set. Successive calls to Deadline return the same results.
|
|||
Deadline() (deadline time.Time, ok bool) |
|||
|
|||
// Done returns a channel that's closed when work done on behalf of this
|
|||
// context should be canceled. Done may return nil if this context can
|
|||
// never be canceled. Successive calls to Done return the same value.
|
|||
//
|
|||
// WithCancel arranges for Done to be closed when cancel is called;
|
|||
// WithDeadline arranges for Done to be closed when the deadline
|
|||
// expires; WithTimeout arranges for Done to be closed when the timeout
|
|||
// elapses.
|
|||
//
|
|||
// Done is provided for use in select statements:
|
|||
//
|
|||
// // Stream generates values with DoSomething and sends them to out
|
|||
// // until DoSomething returns an error or ctx.Done is closed.
|
|||
// func Stream(ctx context.Context, out chan<- Value) error {
|
|||
// for {
|
|||
// v, err := DoSomething(ctx)
|
|||
// if err != nil {
|
|||
// return err
|
|||
// }
|
|||
// select {
|
|||
// case <-ctx.Done():
|
|||
// return ctx.Err()
|
|||
// case out <- v:
|
|||
// }
|
|||
// }
|
|||
// }
|
|||
//
|
|||
// See http://blog.golang.org/pipelines for more examples of how to use
|
|||
// a Done channel for cancelation.
|
|||
Done() <-chan struct{} |
|||
|
|||
// Err returns a non-nil error value after Done is closed. Err returns
|
|||
// Canceled if the context was canceled or DeadlineExceeded if the
|
|||
// context's deadline passed. No other values for Err are defined.
|
|||
// After Done is closed, successive calls to Err return the same value.
|
|||
Err() error |
|||
|
|||
// Value returns the value associated with this context for key, or nil
|
|||
// if no value is associated with key. Successive calls to Value with
|
|||
// the same key returns the same result.
|
|||
//
|
|||
// Use context values only for request-scoped data that transits
|
|||
// processes and API boundaries, not for passing optional parameters to
|
|||
// functions.
|
|||
//
|
|||
// A key identifies a specific value in a Context. Functions that wish
|
|||
// to store values in Context typically allocate a key in a global
|
|||
// variable then use that key as the argument to context.WithValue and
|
|||
// Context.Value. A key can be any type that supports equality;
|
|||
// packages should define keys as an unexported type to avoid
|
|||
// collisions.
|
|||
//
|
|||
// Packages that define a Context key should provide type-safe accessors
|
|||
// for the values stores using that key:
|
|||
//
|
|||
// // Package user defines a User type that's stored in Contexts.
|
|||
// package user
|
|||
//
|
|||
// import "golang.org/x/net/context"
|
|||
//
|
|||
// // User is the type of value stored in the Contexts.
|
|||
// type User struct {...}
|
|||
//
|
|||
// // key is an unexported type for keys defined in this package.
|
|||
// // This prevents collisions with keys defined in other packages.
|
|||
// type key int
|
|||
//
|
|||
// // userKey is the key for user.User values in Contexts. It is
|
|||
// // unexported; clients use user.NewContext and user.FromContext
|
|||
// // instead of using this key directly.
|
|||
// var userKey key = 0
|
|||
//
|
|||
// // NewContext returns a new Context that carries value u.
|
|||
// func NewContext(ctx context.Context, u *User) context.Context {
|
|||
// return context.WithValue(ctx, userKey, u)
|
|||
// }
|
|||
//
|
|||
// // FromContext returns the User value stored in ctx, if any.
|
|||
// func FromContext(ctx context.Context) (*User, bool) {
|
|||
// u, ok := ctx.Value(userKey).(*User)
|
|||
// return u, ok
|
|||
// }
|
|||
Value(key interface{}) interface{} |
|||
} |
|||
|
|||
// Background returns a non-nil, empty Context. It is never canceled, has no
|
|||
// values, and has no deadline. It is typically used by the main function,
|
|||
// initialization, and tests, and as the top-level Context for incoming
|
|||
// requests.
|
|||
func Background() Context { |
|||
return background |
|||
} |
|||
|
|||
// TODO returns a non-nil, empty Context. Code should use context.TODO when
|
|||
// it's unclear which Context to use or it is not yet available (because the
|
|||
// surrounding function has not yet been extended to accept a Context
|
|||
// parameter). TODO is recognized by static analysis tools that determine
|
|||
// whether Contexts are propagated correctly in a program.
|
|||
func TODO() Context { |
|||
return todo |
|||
} |
|||
|
|||
// A CancelFunc tells an operation to abandon its work.
|
|||
// A CancelFunc does not wait for the work to stop.
|
|||
// After the first call, subsequent calls to a CancelFunc do nothing.
|
|||
type CancelFunc func() |
@ -0,0 +1,577 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build !go1.7
|
|||
|
|||
package context |
|||
|
|||
import ( |
|||
"fmt" |
|||
"math/rand" |
|||
"runtime" |
|||
"strings" |
|||
"sync" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
// otherContext is a Context that's not one of the types defined in context.go.
|
|||
// This lets us test code paths that differ based on the underlying type of the
|
|||
// Context.
|
|||
type otherContext struct { |
|||
Context |
|||
} |
|||
|
|||
func TestBackground(t *testing.T) { |
|||
c := Background() |
|||
if c == nil { |
|||
t.Fatalf("Background returned nil") |
|||
} |
|||
select { |
|||
case x := <-c.Done(): |
|||
t.Errorf("<-c.Done() == %v want nothing (it should block)", x) |
|||
default: |
|||
} |
|||
if got, want := fmt.Sprint(c), "context.Background"; got != want { |
|||
t.Errorf("Background().String() = %q want %q", got, want) |
|||
} |
|||
} |
|||
|
|||
func TestTODO(t *testing.T) { |
|||
c := TODO() |
|||
if c == nil { |
|||
t.Fatalf("TODO returned nil") |
|||
} |
|||
select { |
|||
case x := <-c.Done(): |
|||
t.Errorf("<-c.Done() == %v want nothing (it should block)", x) |
|||
default: |
|||
} |
|||
if got, want := fmt.Sprint(c), "context.TODO"; got != want { |
|||
t.Errorf("TODO().String() = %q want %q", got, want) |
|||
} |
|||
} |
|||
|
|||
func TestWithCancel(t *testing.T) { |
|||
c1, cancel := WithCancel(Background()) |
|||
|
|||
if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want { |
|||
t.Errorf("c1.String() = %q want %q", got, want) |
|||
} |
|||
|
|||
o := otherContext{c1} |
|||
c2, _ := WithCancel(o) |
|||
contexts := []Context{c1, o, c2} |
|||
|
|||
for i, c := range contexts { |
|||
if d := c.Done(); d == nil { |
|||
t.Errorf("c[%d].Done() == %v want non-nil", i, d) |
|||
} |
|||
if e := c.Err(); e != nil { |
|||
t.Errorf("c[%d].Err() == %v want nil", i, e) |
|||
} |
|||
|
|||
select { |
|||
case x := <-c.Done(): |
|||
t.Errorf("<-c.Done() == %v want nothing (it should block)", x) |
|||
default: |
|||
} |
|||
} |
|||
|
|||
cancel() |
|||
time.Sleep(100 * time.Millisecond) // let cancelation propagate
|
|||
|
|||
for i, c := range contexts { |
|||
select { |
|||
case <-c.Done(): |
|||
default: |
|||
t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i) |
|||
} |
|||
if e := c.Err(); e != Canceled { |
|||
t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestParentFinishesChild(t *testing.T) { |
|||
// Context tree:
|
|||
// parent -> cancelChild
|
|||
// parent -> valueChild -> timerChild
|
|||
parent, cancel := WithCancel(Background()) |
|||
cancelChild, stop := WithCancel(parent) |
|||
defer stop() |
|||
valueChild := WithValue(parent, "key", "value") |
|||
timerChild, stop := WithTimeout(valueChild, 10000*time.Hour) |
|||
defer stop() |
|||
|
|||
select { |
|||
case x := <-parent.Done(): |
|||
t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) |
|||
case x := <-cancelChild.Done(): |
|||
t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x) |
|||
case x := <-timerChild.Done(): |
|||
t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x) |
|||
case x := <-valueChild.Done(): |
|||
t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x) |
|||
default: |
|||
} |
|||
|
|||
// The parent's children should contain the two cancelable children.
|
|||
pc := parent.(*cancelCtx) |
|||
cc := cancelChild.(*cancelCtx) |
|||
tc := timerChild.(*timerCtx) |
|||
pc.mu.Lock() |
|||
if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] { |
|||
t.Errorf("bad linkage: pc.children = %v, want %v and %v", |
|||
pc.children, cc, tc) |
|||
} |
|||
pc.mu.Unlock() |
|||
|
|||
if p, ok := parentCancelCtx(cc.Context); !ok || p != pc { |
|||
t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc) |
|||
} |
|||
if p, ok := parentCancelCtx(tc.Context); !ok || p != pc { |
|||
t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc) |
|||
} |
|||
|
|||
cancel() |
|||
|
|||
pc.mu.Lock() |
|||
if len(pc.children) != 0 { |
|||
t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children) |
|||
} |
|||
pc.mu.Unlock() |
|||
|
|||
// parent and children should all be finished.
|
|||
check := func(ctx Context, name string) { |
|||
select { |
|||
case <-ctx.Done(): |
|||
default: |
|||
t.Errorf("<-%s.Done() blocked, but shouldn't have", name) |
|||
} |
|||
if e := ctx.Err(); e != Canceled { |
|||
t.Errorf("%s.Err() == %v want %v", name, e, Canceled) |
|||
} |
|||
} |
|||
check(parent, "parent") |
|||
check(cancelChild, "cancelChild") |
|||
check(valueChild, "valueChild") |
|||
check(timerChild, "timerChild") |
|||
|
|||
// WithCancel should return a canceled context on a canceled parent.
|
|||
precanceledChild := WithValue(parent, "key", "value") |
|||
select { |
|||
case <-precanceledChild.Done(): |
|||
default: |
|||
t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have") |
|||
} |
|||
if e := precanceledChild.Err(); e != Canceled { |
|||
t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled) |
|||
} |
|||
} |
|||
|
|||
func TestChildFinishesFirst(t *testing.T) { |
|||
cancelable, stop := WithCancel(Background()) |
|||
defer stop() |
|||
for _, parent := range []Context{Background(), cancelable} { |
|||
child, cancel := WithCancel(parent) |
|||
|
|||
select { |
|||
case x := <-parent.Done(): |
|||
t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) |
|||
case x := <-child.Done(): |
|||
t.Errorf("<-child.Done() == %v want nothing (it should block)", x) |
|||
default: |
|||
} |
|||
|
|||
cc := child.(*cancelCtx) |
|||
pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background()
|
|||
if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) { |
|||
t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok) |
|||
} |
|||
|
|||
if pcok { |
|||
pc.mu.Lock() |
|||
if len(pc.children) != 1 || !pc.children[cc] { |
|||
t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc) |
|||
} |
|||
pc.mu.Unlock() |
|||
} |
|||
|
|||
cancel() |
|||
|
|||
if pcok { |
|||
pc.mu.Lock() |
|||
if len(pc.children) != 0 { |
|||
t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children) |
|||
} |
|||
pc.mu.Unlock() |
|||
} |
|||
|
|||
// child should be finished.
|
|||
select { |
|||
case <-child.Done(): |
|||
default: |
|||
t.Errorf("<-child.Done() blocked, but shouldn't have") |
|||
} |
|||
if e := child.Err(); e != Canceled { |
|||
t.Errorf("child.Err() == %v want %v", e, Canceled) |
|||
} |
|||
|
|||
// parent should not be finished.
|
|||
select { |
|||
case x := <-parent.Done(): |
|||
t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) |
|||
default: |
|||
} |
|||
if e := parent.Err(); e != nil { |
|||
t.Errorf("parent.Err() == %v want nil", e) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func testDeadline(c Context, wait time.Duration, t *testing.T) { |
|||
select { |
|||
case <-time.After(wait): |
|||
t.Fatalf("context should have timed out") |
|||
case <-c.Done(): |
|||
} |
|||
if e := c.Err(); e != DeadlineExceeded { |
|||
t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded) |
|||
} |
|||
} |
|||
|
|||
func TestDeadline(t *testing.T) { |
|||
c, _ := WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) |
|||
if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { |
|||
t.Errorf("c.String() = %q want prefix %q", got, prefix) |
|||
} |
|||
testDeadline(c, 200*time.Millisecond, t) |
|||
|
|||
c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) |
|||
o := otherContext{c} |
|||
testDeadline(o, 200*time.Millisecond, t) |
|||
|
|||
c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) |
|||
o = otherContext{c} |
|||
c, _ = WithDeadline(o, time.Now().Add(300*time.Millisecond)) |
|||
testDeadline(c, 200*time.Millisecond, t) |
|||
} |
|||
|
|||
func TestTimeout(t *testing.T) { |
|||
c, _ := WithTimeout(Background(), 100*time.Millisecond) |
|||
if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { |
|||
t.Errorf("c.String() = %q want prefix %q", got, prefix) |
|||
} |
|||
testDeadline(c, 200*time.Millisecond, t) |
|||
|
|||
c, _ = WithTimeout(Background(), 100*time.Millisecond) |
|||
o := otherContext{c} |
|||
testDeadline(o, 200*time.Millisecond, t) |
|||
|
|||
c, _ = WithTimeout(Background(), 100*time.Millisecond) |
|||
o = otherContext{c} |
|||
c, _ = WithTimeout(o, 300*time.Millisecond) |
|||
testDeadline(c, 200*time.Millisecond, t) |
|||
} |
|||
|
|||
func TestCanceledTimeout(t *testing.T) { |
|||
c, _ := WithTimeout(Background(), 200*time.Millisecond) |
|||
o := otherContext{c} |
|||
c, cancel := WithTimeout(o, 400*time.Millisecond) |
|||
cancel() |
|||
time.Sleep(100 * time.Millisecond) // let cancelation propagate
|
|||
select { |
|||
case <-c.Done(): |
|||
default: |
|||
t.Errorf("<-c.Done() blocked, but shouldn't have") |
|||
} |
|||
if e := c.Err(); e != Canceled { |
|||
t.Errorf("c.Err() == %v want %v", e, Canceled) |
|||
} |
|||
} |
|||
|
|||
type key1 int |
|||
type key2 int |
|||
|
|||
var k1 = key1(1) |
|||
var k2 = key2(1) // same int as k1, different type
|
|||
var k3 = key2(3) // same type as k2, different int
|
|||
|
|||
func TestValues(t *testing.T) { |
|||
check := func(c Context, nm, v1, v2, v3 string) { |
|||
if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 { |
|||
t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0) |
|||
} |
|||
if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 { |
|||
t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0) |
|||
} |
|||
if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 { |
|||
t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0) |
|||
} |
|||
} |
|||
|
|||
c0 := Background() |
|||
check(c0, "c0", "", "", "") |
|||
|
|||
c1 := WithValue(Background(), k1, "c1k1") |
|||
check(c1, "c1", "c1k1", "", "") |
|||
|
|||
if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want { |
|||
t.Errorf("c.String() = %q want %q", got, want) |
|||
} |
|||
|
|||
c2 := WithValue(c1, k2, "c2k2") |
|||
check(c2, "c2", "c1k1", "c2k2", "") |
|||
|
|||
c3 := WithValue(c2, k3, "c3k3") |
|||
check(c3, "c2", "c1k1", "c2k2", "c3k3") |
|||
|
|||
c4 := WithValue(c3, k1, nil) |
|||
check(c4, "c4", "", "c2k2", "c3k3") |
|||
|
|||
o0 := otherContext{Background()} |
|||
check(o0, "o0", "", "", "") |
|||
|
|||
o1 := otherContext{WithValue(Background(), k1, "c1k1")} |
|||
check(o1, "o1", "c1k1", "", "") |
|||
|
|||
o2 := WithValue(o1, k2, "o2k2") |
|||
check(o2, "o2", "c1k1", "o2k2", "") |
|||
|
|||
o3 := otherContext{c4} |
|||
check(o3, "o3", "", "c2k2", "c3k3") |
|||
|
|||
o4 := WithValue(o3, k3, nil) |
|||
check(o4, "o4", "", "c2k2", "") |
|||
} |
|||
|
|||
func TestAllocs(t *testing.T) { |
|||
bg := Background() |
|||
for _, test := range []struct { |
|||
desc string |
|||
f func() |
|||
limit float64 |
|||
gccgoLimit float64 |
|||
}{ |
|||
{ |
|||
desc: "Background()", |
|||
f: func() { Background() }, |
|||
limit: 0, |
|||
gccgoLimit: 0, |
|||
}, |
|||
{ |
|||
desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1), |
|||
f: func() { |
|||
c := WithValue(bg, k1, nil) |
|||
c.Value(k1) |
|||
}, |
|||
limit: 3, |
|||
gccgoLimit: 3, |
|||
}, |
|||
{ |
|||
desc: "WithTimeout(bg, 15*time.Millisecond)", |
|||
f: func() { |
|||
c, _ := WithTimeout(bg, 15*time.Millisecond) |
|||
<-c.Done() |
|||
}, |
|||
limit: 8, |
|||
gccgoLimit: 16, |
|||
}, |
|||
{ |
|||
desc: "WithCancel(bg)", |
|||
f: func() { |
|||
c, cancel := WithCancel(bg) |
|||
cancel() |
|||
<-c.Done() |
|||
}, |
|||
limit: 5, |
|||
gccgoLimit: 8, |
|||
}, |
|||
{ |
|||
desc: "WithTimeout(bg, 100*time.Millisecond)", |
|||
f: func() { |
|||
c, cancel := WithTimeout(bg, 100*time.Millisecond) |
|||
cancel() |
|||
<-c.Done() |
|||
}, |
|||
limit: 8, |
|||
gccgoLimit: 25, |
|||
}, |
|||
} { |
|||
limit := test.limit |
|||
if runtime.Compiler == "gccgo" { |
|||
// gccgo does not yet do escape analysis.
|
|||
// TODO(iant): Remove this when gccgo does do escape analysis.
|
|||
limit = test.gccgoLimit |
|||
} |
|||
if n := testing.AllocsPerRun(100, test.f); n > limit { |
|||
t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit)) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestSimultaneousCancels(t *testing.T) { |
|||
root, cancel := WithCancel(Background()) |
|||
m := map[Context]CancelFunc{root: cancel} |
|||
q := []Context{root} |
|||
// Create a tree of contexts.
|
|||
for len(q) != 0 && len(m) < 100 { |
|||
parent := q[0] |
|||
q = q[1:] |
|||
for i := 0; i < 4; i++ { |
|||
ctx, cancel := WithCancel(parent) |
|||
m[ctx] = cancel |
|||
q = append(q, ctx) |
|||
} |
|||
} |
|||
// Start all the cancels in a random order.
|
|||
var wg sync.WaitGroup |
|||
wg.Add(len(m)) |
|||
for _, cancel := range m { |
|||
go func(cancel CancelFunc) { |
|||
cancel() |
|||
wg.Done() |
|||
}(cancel) |
|||
} |
|||
// Wait on all the contexts in a random order.
|
|||
for ctx := range m { |
|||
select { |
|||
case <-ctx.Done(): |
|||
case <-time.After(1 * time.Second): |
|||
buf := make([]byte, 10<<10) |
|||
n := runtime.Stack(buf, true) |
|||
t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n]) |
|||
} |
|||
} |
|||
// Wait for all the cancel functions to return.
|
|||
done := make(chan struct{}) |
|||
go func() { |
|||
wg.Wait() |
|||
close(done) |
|||
}() |
|||
select { |
|||
case <-done: |
|||
case <-time.After(1 * time.Second): |
|||
buf := make([]byte, 10<<10) |
|||
n := runtime.Stack(buf, true) |
|||
t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n]) |
|||
} |
|||
} |
|||
|
|||
func TestInterlockedCancels(t *testing.T) { |
|||
parent, cancelParent := WithCancel(Background()) |
|||
child, cancelChild := WithCancel(parent) |
|||
go func() { |
|||
parent.Done() |
|||
cancelChild() |
|||
}() |
|||
cancelParent() |
|||
select { |
|||
case <-child.Done(): |
|||
case <-time.After(1 * time.Second): |
|||
buf := make([]byte, 10<<10) |
|||
n := runtime.Stack(buf, true) |
|||
t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n]) |
|||
} |
|||
} |
|||
|
|||
func TestLayersCancel(t *testing.T) { |
|||
testLayers(t, time.Now().UnixNano(), false) |
|||
} |
|||
|
|||
func TestLayersTimeout(t *testing.T) { |
|||
testLayers(t, time.Now().UnixNano(), true) |
|||
} |
|||
|
|||
func testLayers(t *testing.T, seed int64, testTimeout bool) { |
|||
rand.Seed(seed) |
|||
errorf := func(format string, a ...interface{}) { |
|||
t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...) |
|||
} |
|||
const ( |
|||
timeout = 200 * time.Millisecond |
|||
minLayers = 30 |
|||
) |
|||
type value int |
|||
var ( |
|||
vals []*value |
|||
cancels []CancelFunc |
|||
numTimers int |
|||
ctx = Background() |
|||
) |
|||
for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ { |
|||
switch rand.Intn(3) { |
|||
case 0: |
|||
v := new(value) |
|||
ctx = WithValue(ctx, v, v) |
|||
vals = append(vals, v) |
|||
case 1: |
|||
var cancel CancelFunc |
|||
ctx, cancel = WithCancel(ctx) |
|||
cancels = append(cancels, cancel) |
|||
case 2: |
|||
var cancel CancelFunc |
|||
ctx, cancel = WithTimeout(ctx, timeout) |
|||
cancels = append(cancels, cancel) |
|||
numTimers++ |
|||
} |
|||
} |
|||
checkValues := func(when string) { |
|||
for _, key := range vals { |
|||
if val := ctx.Value(key).(*value); key != val { |
|||
errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key) |
|||
} |
|||
} |
|||
} |
|||
select { |
|||
case <-ctx.Done(): |
|||
errorf("ctx should not be canceled yet") |
|||
default: |
|||
} |
|||
if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) { |
|||
t.Errorf("ctx.String() = %q want prefix %q", s, prefix) |
|||
} |
|||
t.Log(ctx) |
|||
checkValues("before cancel") |
|||
if testTimeout { |
|||
select { |
|||
case <-ctx.Done(): |
|||
case <-time.After(timeout + 100*time.Millisecond): |
|||
errorf("ctx should have timed out") |
|||
} |
|||
checkValues("after timeout") |
|||
} else { |
|||
cancel := cancels[rand.Intn(len(cancels))] |
|||
cancel() |
|||
select { |
|||
case <-ctx.Done(): |
|||
default: |
|||
errorf("ctx should be canceled") |
|||
} |
|||
checkValues("after cancel") |
|||
} |
|||
} |
|||
|
|||
func TestCancelRemoves(t *testing.T) { |
|||
checkChildren := func(when string, ctx Context, want int) { |
|||
if got := len(ctx.(*cancelCtx).children); got != want { |
|||
t.Errorf("%s: context has %d children, want %d", when, got, want) |
|||
} |
|||
} |
|||
|
|||
ctx, _ := WithCancel(Background()) |
|||
checkChildren("after creation", ctx, 0) |
|||
_, cancel := WithCancel(ctx) |
|||
checkChildren("with WithCancel child ", ctx, 1) |
|||
cancel() |
|||
checkChildren("after cancelling WithCancel child", ctx, 0) |
|||
|
|||
ctx, _ = WithCancel(Background()) |
|||
checkChildren("after creation", ctx, 0) |
|||
_, cancel = WithTimeout(ctx, 60*time.Minute) |
|||
checkChildren("with WithTimeout child ", ctx, 1) |
|||
cancel() |
|||
checkChildren("after cancelling WithTimeout child", ctx, 0) |
|||
} |
@ -0,0 +1,74 @@ |
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build go1.7
|
|||
|
|||
// Package ctxhttp provides helper functions for performing context-aware HTTP requests.
|
|||
package ctxhttp // import "golang.org/x/net/context/ctxhttp"
|
|||
|
|||
import ( |
|||
"io" |
|||
"net/http" |
|||
"net/url" |
|||
"strings" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
// Do sends an HTTP request with the provided http.Client and returns
|
|||
// an HTTP response.
|
|||
//
|
|||
// If the client is nil, http.DefaultClient is used.
|
|||
//
|
|||
// The provided ctx must be non-nil. If it is canceled or times out,
|
|||
// ctx.Err() will be returned.
|
|||
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { |
|||
if client == nil { |
|||
client = http.DefaultClient |
|||
} |
|||
resp, err := client.Do(req.WithContext(ctx)) |
|||
// If we got an error, and the context has been canceled,
|
|||
// the context's error is probably more useful.
|
|||
if err != nil { |
|||
select { |
|||
case <-ctx.Done(): |
|||
err = ctx.Err() |
|||
default: |
|||
} |
|||
} |
|||
return resp, err |
|||
} |
|||
|
|||
// Get issues a GET request via the Do function.
|
|||
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { |
|||
req, err := http.NewRequest("GET", url, nil) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return Do(ctx, client, req) |
|||
} |
|||
|
|||
// Head issues a HEAD request via the Do function.
|
|||
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { |
|||
req, err := http.NewRequest("HEAD", url, nil) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return Do(ctx, client, req) |
|||
} |
|||
|
|||
// Post issues a POST request via the Do function.
|
|||
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { |
|||
req, err := http.NewRequest("POST", url, body) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
req.Header.Set("Content-Type", bodyType) |
|||
return Do(ctx, client, req) |
|||
} |
|||
|
|||
// PostForm issues a POST request via the Do function.
|
|||
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { |
|||
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) |
|||
} |
@ -0,0 +1,28 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build !plan9,go1.7
|
|||
|
|||
package ctxhttp |
|||
|
|||
import ( |
|||
"io" |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"testing" |
|||
|
|||
"context" |
|||
) |
|||
|
|||
func TestGo17Context(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
io.WriteString(w, "ok") |
|||
})) |
|||
ctx := context.Background() |
|||
resp, err := Get(ctx, http.DefaultClient, ts.URL) |
|||
if resp == nil || err != nil { |
|||
t.Fatalf("error received from client: %v %v", err, resp) |
|||
} |
|||
resp.Body.Close() |
|||
} |
@ -0,0 +1,147 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build !go1.7
|
|||
|
|||
package ctxhttp // import "golang.org/x/net/context/ctxhttp"
|
|||
|
|||
import ( |
|||
"io" |
|||
"net/http" |
|||
"net/url" |
|||
"strings" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
func nop() {} |
|||
|
|||
var ( |
|||
testHookContextDoneBeforeHeaders = nop |
|||
testHookDoReturned = nop |
|||
testHookDidBodyClose = nop |
|||
) |
|||
|
|||
// Do sends an HTTP request with the provided http.Client and returns an HTTP response.
|
|||
// If the client is nil, http.DefaultClient is used.
|
|||
// If the context is canceled or times out, ctx.Err() will be returned.
|
|||
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { |
|||
if client == nil { |
|||
client = http.DefaultClient |
|||
} |
|||
|
|||
// TODO(djd): Respect any existing value of req.Cancel.
|
|||
cancel := make(chan struct{}) |
|||
req.Cancel = cancel |
|||
|
|||
type responseAndError struct { |
|||
resp *http.Response |
|||
err error |
|||
} |
|||
result := make(chan responseAndError, 1) |
|||
|
|||
// Make local copies of test hooks closed over by goroutines below.
|
|||
// Prevents data races in tests.
|
|||
testHookDoReturned := testHookDoReturned |
|||
testHookDidBodyClose := testHookDidBodyClose |
|||
|
|||
go func() { |
|||
resp, err := client.Do(req) |
|||
testHookDoReturned() |
|||
result <- responseAndError{resp, err} |
|||
}() |
|||
|
|||
var resp *http.Response |
|||
|
|||
select { |
|||
case <-ctx.Done(): |
|||
testHookContextDoneBeforeHeaders() |
|||
close(cancel) |
|||
// Clean up after the goroutine calling client.Do:
|
|||
go func() { |
|||
if r := <-result; r.resp != nil { |
|||
testHookDidBodyClose() |
|||
r.resp.Body.Close() |
|||
} |
|||
}() |
|||
return nil, ctx.Err() |
|||
case r := <-result: |
|||
var err error |
|||
resp, err = r.resp, r.err |
|||
if err != nil { |
|||
return resp, err |
|||
} |
|||
} |
|||
|
|||
c := make(chan struct{}) |
|||
go func() { |
|||
select { |
|||
case <-ctx.Done(): |
|||
close(cancel) |
|||
case <-c: |
|||
// The response's Body is closed.
|
|||
} |
|||
}() |
|||
resp.Body = ¬ifyingReader{resp.Body, c} |
|||
|
|||
return resp, nil |
|||
} |
|||
|
|||
// Get issues a GET request via the Do function.
|
|||
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { |
|||
req, err := http.NewRequest("GET", url, nil) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return Do(ctx, client, req) |
|||
} |
|||
|
|||
// Head issues a HEAD request via the Do function.
|
|||
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { |
|||
req, err := http.NewRequest("HEAD", url, nil) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return Do(ctx, client, req) |
|||
} |
|||
|
|||
// Post issues a POST request via the Do function.
|
|||
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { |
|||
req, err := http.NewRequest("POST", url, body) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
req.Header.Set("Content-Type", bodyType) |
|||
return Do(ctx, client, req) |
|||
} |
|||
|
|||
// PostForm issues a POST request via the Do function.
|
|||
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { |
|||
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) |
|||
} |
|||
|
|||
// notifyingReader is an io.ReadCloser that closes the notify channel after
|
|||
// Close is called or a Read fails on the underlying ReadCloser.
|
|||
type notifyingReader struct { |
|||
io.ReadCloser |
|||
notify chan<- struct{} |
|||
} |
|||
|
|||
func (r *notifyingReader) Read(p []byte) (int, error) { |
|||
n, err := r.ReadCloser.Read(p) |
|||
if err != nil && r.notify != nil { |
|||
close(r.notify) |
|||
r.notify = nil |
|||
} |
|||
return n, err |
|||
} |
|||
|
|||
func (r *notifyingReader) Close() error { |
|||
err := r.ReadCloser.Close() |
|||
if r.notify != nil { |
|||
close(r.notify) |
|||
r.notify = nil |
|||
} |
|||
return err |
|||
} |
@ -0,0 +1,79 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build !plan9,!go1.7
|
|||
|
|||
package ctxhttp |
|||
|
|||
import ( |
|||
"net" |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"sync" |
|||
"testing" |
|||
"time" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
// golang.org/issue/14065
|
|||
func TestClosesResponseBodyOnCancel(t *testing.T) { |
|||
defer func() { testHookContextDoneBeforeHeaders = nop }() |
|||
defer func() { testHookDoReturned = nop }() |
|||
defer func() { testHookDidBodyClose = nop }() |
|||
|
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) |
|||
defer ts.Close() |
|||
|
|||
ctx, cancel := context.WithCancel(context.Background()) |
|||
|
|||
// closed when Do enters select case <-ctx.Done()
|
|||
enteredDonePath := make(chan struct{}) |
|||
|
|||
testHookContextDoneBeforeHeaders = func() { |
|||
close(enteredDonePath) |
|||
} |
|||
|
|||
testHookDoReturned = func() { |
|||
// We now have the result (the Flush'd headers) at least,
|
|||
// so we can cancel the request.
|
|||
cancel() |
|||
|
|||
// But block the client.Do goroutine from sending
|
|||
// until Do enters into the <-ctx.Done() path, since
|
|||
// otherwise if both channels are readable, select
|
|||
// picks a random one.
|
|||
<-enteredDonePath |
|||
} |
|||
|
|||
sawBodyClose := make(chan struct{}) |
|||
testHookDidBodyClose = func() { close(sawBodyClose) } |
|||
|
|||
tr := &http.Transport{} |
|||
defer tr.CloseIdleConnections() |
|||
c := &http.Client{Transport: tr} |
|||
req, _ := http.NewRequest("GET", ts.URL, nil) |
|||
_, doErr := Do(ctx, c, req) |
|||
|
|||
select { |
|||
case <-sawBodyClose: |
|||
case <-time.After(5 * time.Second): |
|||
t.Fatal("timeout waiting for body to close") |
|||
} |
|||
|
|||
if doErr != ctx.Err() { |
|||
t.Errorf("Do error = %v; want %v", doErr, ctx.Err()) |
|||
} |
|||
} |
|||
|
|||
type noteCloseConn struct { |
|||
net.Conn |
|||
onceClose sync.Once |
|||
closefn func() |
|||
} |
|||
|
|||
func (c *noteCloseConn) Close() error { |
|||
c.onceClose.Do(c.closefn) |
|||
return c.Conn.Close() |
|||
} |
@ -0,0 +1,105 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build !plan9
|
|||
|
|||
package ctxhttp |
|||
|
|||
import ( |
|||
"io" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"testing" |
|||
"time" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
const ( |
|||
requestDuration = 100 * time.Millisecond |
|||
requestBody = "ok" |
|||
) |
|||
|
|||
func okHandler(w http.ResponseWriter, r *http.Request) { |
|||
time.Sleep(requestDuration) |
|||
io.WriteString(w, requestBody) |
|||
} |
|||
|
|||
func TestNoTimeout(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(okHandler)) |
|||
defer ts.Close() |
|||
|
|||
ctx := context.Background() |
|||
res, err := Get(ctx, nil, ts.URL) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
defer res.Body.Close() |
|||
slurp, err := ioutil.ReadAll(res.Body) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if string(slurp) != requestBody { |
|||
t.Errorf("body = %q; want %q", slurp, requestBody) |
|||
} |
|||
} |
|||
|
|||
func TestCancelBeforeHeaders(t *testing.T) { |
|||
ctx, cancel := context.WithCancel(context.Background()) |
|||
|
|||
blockServer := make(chan struct{}) |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
cancel() |
|||
<-blockServer |
|||
io.WriteString(w, requestBody) |
|||
})) |
|||
defer ts.Close() |
|||
defer close(blockServer) |
|||
|
|||
res, err := Get(ctx, nil, ts.URL) |
|||
if err == nil { |
|||
res.Body.Close() |
|||
t.Fatal("Get returned unexpected nil error") |
|||
} |
|||
if err != context.Canceled { |
|||
t.Errorf("err = %v; want %v", err, context.Canceled) |
|||
} |
|||
} |
|||
|
|||
func TestCancelAfterHangingRequest(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
w.WriteHeader(http.StatusOK) |
|||
w.(http.Flusher).Flush() |
|||
<-w.(http.CloseNotifier).CloseNotify() |
|||
})) |
|||
defer ts.Close() |
|||
|
|||
ctx, cancel := context.WithCancel(context.Background()) |
|||
resp, err := Get(ctx, nil, ts.URL) |
|||
if err != nil { |
|||
t.Fatalf("unexpected error in Get: %v", err) |
|||
} |
|||
|
|||
// Cancel befer reading the body.
|
|||
// Reading Request.Body should fail, since the request was
|
|||
// canceled before anything was written.
|
|||
cancel() |
|||
|
|||
done := make(chan struct{}) |
|||
|
|||
go func() { |
|||
b, err := ioutil.ReadAll(resp.Body) |
|||
if len(b) != 0 || err == nil { |
|||
t.Errorf(`Read got (%q, %v); want ("", error)`, b, err) |
|||
} |
|||
close(done) |
|||
}() |
|||
|
|||
select { |
|||
case <-time.After(1 * time.Second): |
|||
t.Errorf("Test timed out") |
|||
case <-done: |
|||
} |
|||
} |
@ -0,0 +1,72 @@ |
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build go1.7
|
|||
|
|||
package context |
|||
|
|||
import ( |
|||
"context" // standard library's context, as of Go 1.7
|
|||
"time" |
|||
) |
|||
|
|||
var ( |
|||
todo = context.TODO() |
|||
background = context.Background() |
|||
) |
|||
|
|||
// Canceled is the error returned by Context.Err when the context is canceled.
|
|||
var Canceled = context.Canceled |
|||
|
|||
// DeadlineExceeded is the error returned by Context.Err when the context's
|
|||
// deadline passes.
|
|||
var DeadlineExceeded = context.DeadlineExceeded |
|||
|
|||
// WithCancel returns a copy of parent with a new Done channel. The returned
|
|||
// context's Done channel is closed when the returned cancel function is called
|
|||
// or when the parent context's Done channel is closed, whichever happens first.
|
|||
//
|
|||
// Canceling this context releases resources associated with it, so code should
|
|||
// call cancel as soon as the operations running in this Context complete.
|
|||
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { |
|||
ctx, f := context.WithCancel(parent) |
|||
return ctx, CancelFunc(f) |
|||
} |
|||
|
|||
// WithDeadline returns a copy of the parent context with the deadline adjusted
|
|||
// to be no later than d. If the parent's deadline is already earlier than d,
|
|||
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
|
|||
// context's Done channel is closed when the deadline expires, when the returned
|
|||
// cancel function is called, or when the parent context's Done channel is
|
|||
// closed, whichever happens first.
|
|||
//
|
|||
// Canceling this context releases resources associated with it, so code should
|
|||
// call cancel as soon as the operations running in this Context complete.
|
|||
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { |
|||
ctx, f := context.WithDeadline(parent, deadline) |
|||
return ctx, CancelFunc(f) |
|||
} |
|||
|
|||
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
|
|||
//
|
|||
// Canceling this context releases resources associated with it, so code should
|
|||
// call cancel as soon as the operations running in this Context complete:
|
|||
//
|
|||
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
|
|||
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
|||
// defer cancel() // releases resources if slowOperation completes before timeout elapses
|
|||
// return slowOperation(ctx)
|
|||
// }
|
|||
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { |
|||
return WithDeadline(parent, time.Now().Add(timeout)) |
|||
} |
|||
|
|||
// WithValue returns a copy of parent in which the value associated with key is
|
|||
// val.
|
|||
//
|
|||
// Use context Values only for request-scoped data that transits processes and
|
|||
// APIs, not for passing optional parameters to functions.
|
|||
func WithValue(parent Context, key interface{}, val interface{}) Context { |
|||
return context.WithValue(parent, key, val) |
|||
} |
@ -0,0 +1,300 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build !go1.7
|
|||
|
|||
package context |
|||
|
|||
import ( |
|||
"errors" |
|||
"fmt" |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
|
|||
// struct{}, since vars of this type must have distinct addresses.
|
|||
type emptyCtx int |
|||
|
|||
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { |
|||
return |
|||
} |
|||
|
|||
func (*emptyCtx) Done() <-chan struct{} { |
|||
return nil |
|||
} |
|||
|
|||
func (*emptyCtx) Err() error { |
|||
return nil |
|||
} |
|||
|
|||
func (*emptyCtx) Value(key interface{}) interface{} { |
|||
return nil |
|||
} |
|||
|
|||
func (e *emptyCtx) String() string { |
|||
switch e { |
|||
case background: |
|||
return "context.Background" |
|||
case todo: |
|||
return "context.TODO" |
|||
} |
|||
return "unknown empty Context" |
|||
} |
|||
|
|||
var ( |
|||
background = new(emptyCtx) |
|||
todo = new(emptyCtx) |
|||
) |
|||
|
|||
// Canceled is the error returned by Context.Err when the context is canceled.
|
|||
var Canceled = errors.New("context canceled") |
|||
|
|||
// DeadlineExceeded is the error returned by Context.Err when the context's
|
|||
// deadline passes.
|
|||
var DeadlineExceeded = errors.New("context deadline exceeded") |
|||
|
|||
// WithCancel returns a copy of parent with a new Done channel. The returned
|
|||
// context's Done channel is closed when the returned cancel function is called
|
|||
// or when the parent context's Done channel is closed, whichever happens first.
|
|||
//
|
|||
// Canceling this context releases resources associated with it, so code should
|
|||
// call cancel as soon as the operations running in this Context complete.
|
|||
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { |
|||
c := newCancelCtx(parent) |
|||
propagateCancel(parent, c) |
|||
return c, func() { c.cancel(true, Canceled) } |
|||
} |
|||
|
|||
// newCancelCtx returns an initialized cancelCtx.
|
|||
func newCancelCtx(parent Context) *cancelCtx { |
|||
return &cancelCtx{ |
|||
Context: parent, |
|||
done: make(chan struct{}), |
|||
} |
|||
} |
|||
|
|||
// propagateCancel arranges for child to be canceled when parent is.
|
|||
func propagateCancel(parent Context, child canceler) { |
|||
if parent.Done() == nil { |
|||
return // parent is never canceled
|
|||
} |
|||
if p, ok := parentCancelCtx(parent); ok { |
|||
p.mu.Lock() |
|||
if p.err != nil { |
|||
// parent has already been canceled
|
|||
child.cancel(false, p.err) |
|||
} else { |
|||
if p.children == nil { |
|||
p.children = make(map[canceler]bool) |
|||
} |
|||
p.children[child] = true |
|||
} |
|||
p.mu.Unlock() |
|||
} else { |
|||
go func() { |
|||
select { |
|||
case <-parent.Done(): |
|||
child.cancel(false, parent.Err()) |
|||
case <-child.Done(): |
|||
} |
|||
}() |
|||
} |
|||
} |
|||
|
|||
// parentCancelCtx follows a chain of parent references until it finds a
|
|||
// *cancelCtx. This function understands how each of the concrete types in this
|
|||
// package represents its parent.
|
|||
func parentCancelCtx(parent Context) (*cancelCtx, bool) { |
|||
for { |
|||
switch c := parent.(type) { |
|||
case *cancelCtx: |
|||
return c, true |
|||
case *timerCtx: |
|||
return c.cancelCtx, true |
|||
case *valueCtx: |
|||
parent = c.Context |
|||
default: |
|||
return nil, false |
|||
} |
|||
} |
|||
} |
|||
|
|||
// removeChild removes a context from its parent.
|
|||
func removeChild(parent Context, child canceler) { |
|||
p, ok := parentCancelCtx(parent) |
|||
if !ok { |
|||
return |
|||
} |
|||
p.mu.Lock() |
|||
if p.children != nil { |
|||
delete(p.children, child) |
|||
} |
|||
p.mu.Unlock() |
|||
} |
|||
|
|||
// A canceler is a context type that can be canceled directly. The
|
|||
// implementations are *cancelCtx and *timerCtx.
|
|||
type canceler interface { |
|||
cancel(removeFromParent bool, err error) |
|||
Done() <-chan struct{} |
|||
} |
|||
|
|||
// A cancelCtx can be canceled. When canceled, it also cancels any children
|
|||
// that implement canceler.
|
|||
type cancelCtx struct { |
|||
Context |
|||
|
|||
done chan struct{} // closed by the first cancel call.
|
|||
|
|||
mu sync.Mutex |
|||
children map[canceler]bool // set to nil by the first cancel call
|
|||
err error // set to non-nil by the first cancel call
|
|||
} |
|||
|
|||
func (c *cancelCtx) Done() <-chan struct{} { |
|||
return c.done |
|||
} |
|||
|
|||
func (c *cancelCtx) Err() error { |
|||
c.mu.Lock() |
|||
defer c.mu.Unlock() |
|||
return c.err |
|||
} |
|||
|
|||
func (c *cancelCtx) String() string { |
|||
return fmt.Sprintf("%v.WithCancel", c.Context) |
|||
} |
|||
|
|||
// cancel closes c.done, cancels each of c's children, and, if
|
|||
// removeFromParent is true, removes c from its parent's children.
|
|||
func (c *cancelCtx) cancel(removeFromParent bool, err error) { |
|||
if err == nil { |
|||
panic("context: internal error: missing cancel error") |
|||
} |
|||
c.mu.Lock() |
|||
if c.err != nil { |
|||
c.mu.Unlock() |
|||
return // already canceled
|
|||
} |
|||
c.err = err |
|||
close(c.done) |
|||
for child := range c.children { |
|||
// NOTE: acquiring the child's lock while holding parent's lock.
|
|||
child.cancel(false, err) |
|||
} |
|||
c.children = nil |
|||
c.mu.Unlock() |
|||
|
|||
if removeFromParent { |
|||
removeChild(c.Context, c) |
|||
} |
|||
} |
|||
|
|||
// WithDeadline returns a copy of the parent context with the deadline adjusted
|
|||
// to be no later than d. If the parent's deadline is already earlier than d,
|
|||
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
|
|||
// context's Done channel is closed when the deadline expires, when the returned
|
|||
// cancel function is called, or when the parent context's Done channel is
|
|||
// closed, whichever happens first.
|
|||
//
|
|||
// Canceling this context releases resources associated with it, so code should
|
|||
// call cancel as soon as the operations running in this Context complete.
|
|||
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { |
|||
if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { |
|||
// The current deadline is already sooner than the new one.
|
|||
return WithCancel(parent) |
|||
} |
|||
c := &timerCtx{ |
|||
cancelCtx: newCancelCtx(parent), |
|||
deadline: deadline, |
|||
} |
|||
propagateCancel(parent, c) |
|||
d := deadline.Sub(time.Now()) |
|||
if d <= 0 { |
|||
c.cancel(true, DeadlineExceeded) // deadline has already passed
|
|||
return c, func() { c.cancel(true, Canceled) } |
|||
} |
|||
c.mu.Lock() |
|||
defer c.mu.Unlock() |
|||
if c.err == nil { |
|||
c.timer = time.AfterFunc(d, func() { |
|||
c.cancel(true, DeadlineExceeded) |
|||
}) |
|||
} |
|||
return c, func() { c.cancel(true, Canceled) } |
|||
} |
|||
|
|||
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
|
|||
// implement Done and Err. It implements cancel by stopping its timer then
|
|||
// delegating to cancelCtx.cancel.
|
|||
type timerCtx struct { |
|||
*cancelCtx |
|||
timer *time.Timer // Under cancelCtx.mu.
|
|||
|
|||
deadline time.Time |
|||
} |
|||
|
|||
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { |
|||
return c.deadline, true |
|||
} |
|||
|
|||
func (c *timerCtx) String() string { |
|||
return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) |
|||
} |
|||
|
|||
func (c *timerCtx) cancel(removeFromParent bool, err error) { |
|||
c.cancelCtx.cancel(false, err) |
|||
if removeFromParent { |
|||
// Remove this timerCtx from its parent cancelCtx's children.
|
|||
removeChild(c.cancelCtx.Context, c) |
|||
} |
|||
c.mu.Lock() |
|||
if c.timer != nil { |
|||
c.timer.Stop() |
|||
c.timer = nil |
|||
} |
|||
c.mu.Unlock() |
|||
} |
|||
|
|||
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
|
|||
//
|
|||
// Canceling this context releases resources associated with it, so code should
|
|||
// call cancel as soon as the operations running in this Context complete:
|
|||
//
|
|||
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
|
|||
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
|||
// defer cancel() // releases resources if slowOperation completes before timeout elapses
|
|||
// return slowOperation(ctx)
|
|||
// }
|
|||
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { |
|||
return WithDeadline(parent, time.Now().Add(timeout)) |
|||
} |
|||
|
|||
// WithValue returns a copy of parent in which the value associated with key is
|
|||
// val.
|
|||
//
|
|||
// Use context Values only for request-scoped data that transits processes and
|
|||
// APIs, not for passing optional parameters to functions.
|
|||
func WithValue(parent Context, key interface{}, val interface{}) Context { |
|||
return &valueCtx{parent, key, val} |
|||
} |
|||
|
|||
// A valueCtx carries a key-value pair. It implements Value for that key and
|
|||
// delegates all other calls to the embedded Context.
|
|||
type valueCtx struct { |
|||
Context |
|||
key, val interface{} |
|||
} |
|||
|
|||
func (c *valueCtx) String() string { |
|||
return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) |
|||
} |
|||
|
|||
func (c *valueCtx) Value(key interface{}) interface{} { |
|||
if c.key == key { |
|||
return c.val |
|||
} |
|||
return c.Context.Value(key) |
|||
} |
@ -0,0 +1,26 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package context_test |
|||
|
|||
import ( |
|||
"fmt" |
|||
"time" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
func ExampleWithTimeout() { |
|||
// Pass a context with a timeout to tell a blocking function that it
|
|||
// should abandon its work after the timeout elapses.
|
|||
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) |
|||
select { |
|||
case <-time.After(200 * time.Millisecond): |
|||
fmt.Println("overslept") |
|||
case <-ctx.Done(): |
|||
fmt.Println(ctx.Err()) // prints "context deadline exceeded"
|
|||
} |
|||
// Output:
|
|||
// context deadline exceeded
|
|||
} |
@ -0,0 +1,3 @@ |
|||
# This source code refers to The Go Authors for copyright purposes. |
|||
# The master list of authors is in the main Go distribution, |
|||
# visible at http://tip.golang.org/AUTHORS. |
@ -0,0 +1,31 @@ |
|||
# Contributing to Go |
|||
|
|||
Go is an open source project. |
|||
|
|||
It is the work of hundreds of contributors. We appreciate your help! |
|||
|
|||
|
|||
## Filing issues |
|||
|
|||
When [filing an issue](https://github.com/golang/oauth2/issues), make sure to answer these five questions: |
|||
|
|||
1. What version of Go are you using (`go version`)? |
|||
2. What operating system and processor architecture are you using? |
|||
3. What did you do? |
|||
4. What did you expect to see? |
|||
5. What did you see instead? |
|||
|
|||
General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker. |
|||
The gophers there will answer or ask you to file an issue if you've tripped over a bug. |
|||
|
|||
## Contributing code |
|||
|
|||
Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html) |
|||
before sending patches. |
|||
|
|||
**We do not accept GitHub pull requests** |
|||
(we use [Gerrit](https://code.google.com/p/gerrit/) instead for code review). |
|||
|
|||
Unless otherwise noted, the Go source files are distributed under |
|||
the BSD-style license found in the LICENSE file. |
|||
|
@ -0,0 +1,3 @@ |
|||
# This source code was written by the Go contributors. |
|||
# The master list of contributors is in the main Go distribution, |
|||
# visible at http://tip.golang.org/CONTRIBUTORS. |
@ -0,0 +1,27 @@ |
|||
Copyright (c) 2009 The oauth2 Authors. All rights reserved. |
|||
|
|||
Redistribution and use in source and binary forms, with or without |
|||
modification, are permitted provided that the following conditions are |
|||
met: |
|||
|
|||
* Redistributions of source code must retain the above copyright |
|||
notice, this list of conditions and the following disclaimer. |
|||
* Redistributions in binary form must reproduce the above |
|||
copyright notice, this list of conditions and the following disclaimer |
|||
in the documentation and/or other materials provided with the |
|||
distribution. |
|||
* Neither the name of Google Inc. nor the names of its |
|||
contributors may be used to endorse or promote products derived from |
|||
this software without specific prior written permission. |
|||
|
|||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
|||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
|||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
|||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
|||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
|||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
|||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
|||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
|||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
|||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
|||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
@ -0,0 +1,64 @@ |
|||
# OAuth2 for Go |
|||
|
|||
[![Build Status](https://travis-ci.org/golang/oauth2.svg?branch=master)](https://travis-ci.org/golang/oauth2) |
|||
|
|||
oauth2 package contains a client implementation for OAuth 2.0 spec. |
|||
|
|||
## Installation |
|||
|
|||
~~~~ |
|||
go get golang.org/x/oauth2 |
|||
~~~~ |
|||
|
|||
See godoc for further documentation and examples. |
|||
|
|||
* [godoc.org/golang.org/x/oauth2](http://godoc.org/golang.org/x/oauth2) |
|||
* [godoc.org/golang.org/x/oauth2/google](http://godoc.org/golang.org/x/oauth2/google) |
|||
|
|||
|
|||
## App Engine |
|||
|
|||
In change 96e89be (March 2015) we removed the `oauth2.Context2` type in favor |
|||
of the [`context.Context`](https://golang.org/x/net/context#Context) type from |
|||
the `golang.org/x/net/context` package |
|||
|
|||
This means its no longer possible to use the "Classic App Engine" |
|||
`appengine.Context` type with the `oauth2` package. (You're using |
|||
Classic App Engine if you import the package `"appengine"`.) |
|||
|
|||
To work around this, you may use the new `"google.golang.org/appengine"` |
|||
package. This package has almost the same API as the `"appengine"` package, |
|||
but it can be fetched with `go get` and used on "Managed VMs" and well as |
|||
Classic App Engine. |
|||
|
|||
See the [new `appengine` package's readme](https://github.com/golang/appengine#updating-a-go-app-engine-app) |
|||
for information on updating your app. |
|||
|
|||
If you don't want to update your entire app to use the new App Engine packages, |
|||
you may use both sets of packages in parallel, using only the new packages |
|||
with the `oauth2` package. |
|||
|
|||
import ( |
|||
"golang.org/x/net/context" |
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/google" |
|||
newappengine "google.golang.org/appengine" |
|||
newurlfetch "google.golang.org/appengine/urlfetch" |
|||
|
|||
"appengine" |
|||
) |
|||
|
|||
func handler(w http.ResponseWriter, r *http.Request) { |
|||
var c appengine.Context = appengine.NewContext(r) |
|||
c.Infof("Logging a message with the old package") |
|||
|
|||
var ctx context.Context = newappengine.NewContext(r) |
|||
client := &http.Client{ |
|||
Transport: &oauth2.Transport{ |
|||
Source: google.AppEngineTokenSource(ctx, "scope"), |
|||
Base: &newurlfetch.Transport{Context: ctx}, |
|||
}, |
|||
} |
|||
client.Get("...") |
|||
} |
|||
|
@ -0,0 +1,16 @@ |
|||
// Copyright 2015 The oauth2 Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package bitbucket provides constants for using OAuth2 to access Bitbucket.
|
|||
package bitbucket |
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// Endpoint is Bitbucket's OAuth 2.0 endpoint.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://bitbucket.org/site/oauth2/authorize", |
|||
TokenURL: "https://bitbucket.org/site/oauth2/access_token", |
|||
} |
@ -0,0 +1,25 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build appengine
|
|||
|
|||
// App Engine hooks.
|
|||
|
|||
package oauth2 |
|||
|
|||
import ( |
|||
"net/http" |
|||
|
|||
"golang.org/x/net/context" |
|||
"golang.org/x/oauth2/internal" |
|||
"google.golang.org/appengine/urlfetch" |
|||
) |
|||
|
|||
func init() { |
|||
internal.RegisterContextClientFunc(contextClientAppEngine) |
|||
} |
|||
|
|||
func contextClientAppEngine(ctx context.Context) (*http.Client, error) { |
|||
return urlfetch.Client(ctx), nil |
|||
} |
@ -0,0 +1,112 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package clientcredentials implements the OAuth2.0 "client credentials" token flow,
|
|||
// also known as the "two-legged OAuth 2.0".
|
|||
//
|
|||
// This should be used when the client is acting on its own behalf or when the client
|
|||
// is the resource owner. It may also be used when requesting access to protected
|
|||
// resources based on an authorization previously arranged with the authorization
|
|||
// server.
|
|||
//
|
|||
// See http://tools.ietf.org/html/draft-ietf-oauth-v2-31#section-4.4
|
|||
package clientcredentials // import "golang.org/x/oauth2/clientcredentials"
|
|||
|
|||
import ( |
|||
"net/http" |
|||
"net/url" |
|||
"strings" |
|||
|
|||
"golang.org/x/net/context" |
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/internal" |
|||
) |
|||
|
|||
// tokenFromInternal maps an *internal.Token struct into
|
|||
// an *oauth2.Token struct.
|
|||
func tokenFromInternal(t *internal.Token) *oauth2.Token { |
|||
if t == nil { |
|||
return nil |
|||
} |
|||
tk := &oauth2.Token{ |
|||
AccessToken: t.AccessToken, |
|||
TokenType: t.TokenType, |
|||
RefreshToken: t.RefreshToken, |
|||
Expiry: t.Expiry, |
|||
} |
|||
return tk.WithExtra(t.Raw) |
|||
} |
|||
|
|||
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token.
|
|||
// This token is then mapped from *internal.Token into an *oauth2.Token which is
|
|||
// returned along with an error.
|
|||
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*oauth2.Token, error) { |
|||
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.TokenURL, v) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return tokenFromInternal(tk), nil |
|||
} |
|||
|
|||
// Client Credentials Config describes a 2-legged OAuth2 flow, with both the
|
|||
// client application information and the server's endpoint URLs.
|
|||
type Config struct { |
|||
// ClientID is the application's ID.
|
|||
ClientID string |
|||
|
|||
// ClientSecret is the application's secret.
|
|||
ClientSecret string |
|||
|
|||
// TokenURL is the resource server's token endpoint
|
|||
// URL. This is a constant specific to each server.
|
|||
TokenURL string |
|||
|
|||
// Scope specifies optional requested permissions.
|
|||
Scopes []string |
|||
} |
|||
|
|||
// Token uses client credentials to retrieve a token.
|
|||
// The HTTP client to use is derived from the context.
|
|||
// If nil, http.DefaultClient is used.
|
|||
func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { |
|||
return retrieveToken(ctx, c, url.Values{ |
|||
"grant_type": {"client_credentials"}, |
|||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), |
|||
}) |
|||
} |
|||
|
|||
// Client returns an HTTP client using the provided token.
|
|||
// The token will auto-refresh as necessary. The underlying
|
|||
// HTTP transport will be obtained using the provided context.
|
|||
// The returned client and its Transport should not be modified.
|
|||
func (c *Config) Client(ctx context.Context) *http.Client { |
|||
return oauth2.NewClient(ctx, c.TokenSource(ctx)) |
|||
} |
|||
|
|||
// TokenSource returns a TokenSource that returns t until t expires,
|
|||
// automatically refreshing it as necessary using the provided context and the
|
|||
// client ID and client secret.
|
|||
//
|
|||
// Most users will use Config.Client instead.
|
|||
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { |
|||
source := &tokenSource{ |
|||
ctx: ctx, |
|||
conf: c, |
|||
} |
|||
return oauth2.ReuseTokenSource(nil, source) |
|||
} |
|||
|
|||
type tokenSource struct { |
|||
ctx context.Context |
|||
conf *Config |
|||
} |
|||
|
|||
// Token refreshes the token by using a new client credentials request.
|
|||
// tokens received this way do not include a refresh token
|
|||
func (c *tokenSource) Token() (*oauth2.Token, error) { |
|||
return retrieveToken(c.ctx, c.conf, url.Values{ |
|||
"grant_type": {"client_credentials"}, |
|||
"scope": internal.CondVal(strings.Join(c.conf.Scopes, " ")), |
|||
}) |
|||
} |
@ -0,0 +1,96 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package clientcredentials |
|||
|
|||
import ( |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"testing" |
|||
|
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
func newConf(url string) *Config { |
|||
return &Config{ |
|||
ClientID: "CLIENT_ID", |
|||
ClientSecret: "CLIENT_SECRET", |
|||
Scopes: []string{"scope1", "scope2"}, |
|||
TokenURL: url + "/token", |
|||
} |
|||
} |
|||
|
|||
type mockTransport struct { |
|||
rt func(req *http.Request) (resp *http.Response, err error) |
|||
} |
|||
|
|||
func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { |
|||
return t.rt(req) |
|||
} |
|||
|
|||
func TestTokenRequest(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
if r.URL.String() != "/token" { |
|||
t.Errorf("authenticate client request URL = %q; want %q", r.URL, "/token") |
|||
} |
|||
headerAuth := r.Header.Get("Authorization") |
|||
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { |
|||
t.Errorf("Unexpected authorization header, %v is found.", headerAuth) |
|||
} |
|||
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { |
|||
t.Errorf("Content-Type header = %q; want %q", got, want) |
|||
} |
|||
body, err := ioutil.ReadAll(r.Body) |
|||
if err != nil { |
|||
r.Body.Close() |
|||
} |
|||
if err != nil { |
|||
t.Errorf("failed reading request body: %s.", err) |
|||
} |
|||
if string(body) != "client_id=CLIENT_ID&grant_type=client_credentials&scope=scope1+scope2" { |
|||
t.Errorf("payload = %q; want %q", string(body), "client_id=CLIENT_ID&grant_type=client_credentials&scope=scope1+scope2") |
|||
} |
|||
w.Header().Set("Content-Type", "application/x-www-form-urlencoded") |
|||
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
tok, err := conf.Token(oauth2.NoContext) |
|||
if err != nil { |
|||
t.Error(err) |
|||
} |
|||
if !tok.Valid() { |
|||
t.Fatalf("token invalid. got: %#v", tok) |
|||
} |
|||
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { |
|||
t.Errorf("Access token = %q; want %q", tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c") |
|||
} |
|||
if tok.TokenType != "bearer" { |
|||
t.Errorf("token type = %q; want %q", tok.TokenType, "bearer") |
|||
} |
|||
} |
|||
|
|||
func TestTokenRefreshRequest(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
if r.URL.String() == "/somethingelse" { |
|||
return |
|||
} |
|||
if r.URL.String() != "/token" { |
|||
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) |
|||
} |
|||
headerContentType := r.Header.Get("Content-Type") |
|||
if headerContentType != "application/x-www-form-urlencoded" { |
|||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
|||
} |
|||
body, _ := ioutil.ReadAll(r.Body) |
|||
if string(body) != "client_id=CLIENT_ID&grant_type=client_credentials&scope=scope1+scope2" { |
|||
t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) |
|||
} |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
c := conf.Client(oauth2.NoContext) |
|||
c.Get(ts.URL + "/somethingelse") |
|||
} |
@ -0,0 +1,45 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package oauth2_test |
|||
|
|||
import ( |
|||
"fmt" |
|||
"log" |
|||
|
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
func ExampleConfig() { |
|||
conf := &oauth2.Config{ |
|||
ClientID: "YOUR_CLIENT_ID", |
|||
ClientSecret: "YOUR_CLIENT_SECRET", |
|||
Scopes: []string{"SCOPE1", "SCOPE2"}, |
|||
Endpoint: oauth2.Endpoint{ |
|||
AuthURL: "https://provider.com/o/oauth2/auth", |
|||
TokenURL: "https://provider.com/o/oauth2/token", |
|||
}, |
|||
} |
|||
|
|||
// Redirect user to consent page to ask for permission
|
|||
// for the scopes specified above.
|
|||
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline) |
|||
fmt.Printf("Visit the URL for the auth dialog: %v", url) |
|||
|
|||
// Use the authorization code that is pushed to the redirect
|
|||
// URL. Exchange will do the handshake to retrieve the
|
|||
// initial access token. The HTTP Client returned by
|
|||
// conf.Client will refresh the token as necessary.
|
|||
var code string |
|||
if _, err := fmt.Scan(&code); err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
tok, err := conf.Exchange(oauth2.NoContext, code) |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
|
|||
client := conf.Client(oauth2.NoContext, tok) |
|||
client.Get("...") |
|||
} |
@ -0,0 +1,16 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package facebook provides constants for using OAuth2 to access Facebook.
|
|||
package facebook // import "golang.org/x/oauth2/facebook"
|
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// Endpoint is Facebook's OAuth 2.0 endpoint.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://www.facebook.com/dialog/oauth", |
|||
TokenURL: "https://graph.facebook.com/oauth/access_token", |
|||
} |
@ -0,0 +1,16 @@ |
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package fitbit provides constants for using OAuth2 to access the Fitbit API.
|
|||
package fitbit // import "golang.org/x/oauth2/fitbit"
|
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// Endpoint is the Fitbit API's OAuth 2.0 endpoint.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://www.fitbit.com/oauth2/authorize", |
|||
TokenURL: "https://api.fitbit.com/oauth2/token", |
|||
} |
@ -0,0 +1,16 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package github provides constants for using OAuth2 to access Github.
|
|||
package github // import "golang.org/x/oauth2/github"
|
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// Endpoint is Github's OAuth 2.0 endpoint.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://github.com/login/oauth/authorize", |
|||
TokenURL: "https://github.com/login/oauth/access_token", |
|||
} |
@ -0,0 +1,86 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package google |
|||
|
|||
import ( |
|||
"sort" |
|||
"strings" |
|||
"sync" |
|||
"time" |
|||
|
|||
"golang.org/x/net/context" |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// Set at init time by appenginevm_hook.go. If true, we are on App Engine Managed VMs.
|
|||
var appengineVM bool |
|||
|
|||
// Set at init time by appengine_hook.go. If nil, we're not on App Engine.
|
|||
var appengineTokenFunc func(c context.Context, scopes ...string) (token string, expiry time.Time, err error) |
|||
|
|||
// AppEngineTokenSource returns a token source that fetches tokens
|
|||
// issued to the current App Engine application's service account.
|
|||
// If you are implementing a 3-legged OAuth 2.0 flow on App Engine
|
|||
// that involves user accounts, see oauth2.Config instead.
|
|||
//
|
|||
// The provided context must have come from appengine.NewContext.
|
|||
func AppEngineTokenSource(ctx context.Context, scope ...string) oauth2.TokenSource { |
|||
if appengineTokenFunc == nil { |
|||
panic("google: AppEngineTokenSource can only be used on App Engine.") |
|||
} |
|||
scopes := append([]string{}, scope...) |
|||
sort.Strings(scopes) |
|||
return &appEngineTokenSource{ |
|||
ctx: ctx, |
|||
scopes: scopes, |
|||
key: strings.Join(scopes, " "), |
|||
} |
|||
} |
|||
|
|||
// aeTokens helps the fetched tokens to be reused until their expiration.
|
|||
var ( |
|||
aeTokensMu sync.Mutex |
|||
aeTokens = make(map[string]*tokenLock) // key is space-separated scopes
|
|||
) |
|||
|
|||
type tokenLock struct { |
|||
mu sync.Mutex // guards t; held while fetching or updating t
|
|||
t *oauth2.Token |
|||
} |
|||
|
|||
type appEngineTokenSource struct { |
|||
ctx context.Context |
|||
scopes []string |
|||
key string // to aeTokens map; space-separated scopes
|
|||
} |
|||
|
|||
func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) { |
|||
if appengineTokenFunc == nil { |
|||
panic("google: AppEngineTokenSource can only be used on App Engine.") |
|||
} |
|||
|
|||
aeTokensMu.Lock() |
|||
tok, ok := aeTokens[ts.key] |
|||
if !ok { |
|||
tok = &tokenLock{} |
|||
aeTokens[ts.key] = tok |
|||
} |
|||
aeTokensMu.Unlock() |
|||
|
|||
tok.mu.Lock() |
|||
defer tok.mu.Unlock() |
|||
if tok.t.Valid() { |
|||
return tok.t, nil |
|||
} |
|||
access, exp, err := appengineTokenFunc(ts.ctx, ts.scopes...) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
tok.t = &oauth2.Token{ |
|||
AccessToken: access, |
|||
Expiry: exp, |
|||
} |
|||
return tok.t, nil |
|||
} |
@ -0,0 +1,13 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build appengine
|
|||
|
|||
package google |
|||
|
|||
import "google.golang.org/appengine" |
|||
|
|||
func init() { |
|||
appengineTokenFunc = appengine.AccessToken |
|||
} |
@ -0,0 +1,14 @@ |
|||
// Copyright 2015 The oauth2 Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build appenginevm
|
|||
|
|||
package google |
|||
|
|||
import "google.golang.org/appengine" |
|||
|
|||
func init() { |
|||
appengineVM = true |
|||
appengineTokenFunc = appengine.AccessToken |
|||
} |
@ -0,0 +1,155 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package google |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"errors" |
|||
"fmt" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"os" |
|||
"path/filepath" |
|||
"runtime" |
|||
|
|||
"golang.org/x/net/context" |
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/jwt" |
|||
"google.golang.org/cloud/compute/metadata" |
|||
) |
|||
|
|||
// DefaultClient returns an HTTP Client that uses the
|
|||
// DefaultTokenSource to obtain authentication credentials.
|
|||
//
|
|||
// This client should be used when developing services
|
|||
// that run on Google App Engine or Google Compute Engine
|
|||
// and use "Application Default Credentials."
|
|||
//
|
|||
// For more details, see:
|
|||
// https://developers.google.com/accounts/docs/application-default-credentials
|
|||
//
|
|||
func DefaultClient(ctx context.Context, scope ...string) (*http.Client, error) { |
|||
ts, err := DefaultTokenSource(ctx, scope...) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return oauth2.NewClient(ctx, ts), nil |
|||
} |
|||
|
|||
// DefaultTokenSource is a token source that uses
|
|||
// "Application Default Credentials".
|
|||
//
|
|||
// It looks for credentials in the following places,
|
|||
// preferring the first location found:
|
|||
//
|
|||
// 1. A JSON file whose path is specified by the
|
|||
// GOOGLE_APPLICATION_CREDENTIALS environment variable.
|
|||
// 2. A JSON file in a location known to the gcloud command-line tool.
|
|||
// On Windows, this is %APPDATA%/gcloud/application_default_credentials.json.
|
|||
// On other systems, $HOME/.config/gcloud/application_default_credentials.json.
|
|||
// 3. On Google App Engine it uses the appengine.AccessToken function.
|
|||
// 4. On Google Compute Engine and Google App Engine Managed VMs, it fetches
|
|||
// credentials from the metadata server.
|
|||
// (In this final case any provided scopes are ignored.)
|
|||
//
|
|||
// For more details, see:
|
|||
// https://developers.google.com/accounts/docs/application-default-credentials
|
|||
//
|
|||
func DefaultTokenSource(ctx context.Context, scope ...string) (oauth2.TokenSource, error) { |
|||
// First, try the environment variable.
|
|||
const envVar = "GOOGLE_APPLICATION_CREDENTIALS" |
|||
if filename := os.Getenv(envVar); filename != "" { |
|||
ts, err := tokenSourceFromFile(ctx, filename, scope) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("google: error getting credentials using %v environment variable: %v", envVar, err) |
|||
} |
|||
return ts, nil |
|||
} |
|||
|
|||
// Second, try a well-known file.
|
|||
filename := wellKnownFile() |
|||
_, err := os.Stat(filename) |
|||
if err == nil { |
|||
ts, err2 := tokenSourceFromFile(ctx, filename, scope) |
|||
if err2 == nil { |
|||
return ts, nil |
|||
} |
|||
err = err2 |
|||
} else if os.IsNotExist(err) { |
|||
err = nil // ignore this error
|
|||
} |
|||
if err != nil { |
|||
return nil, fmt.Errorf("google: error getting credentials using well-known file (%v): %v", filename, err) |
|||
} |
|||
|
|||
// Third, if we're on Google App Engine use those credentials.
|
|||
if appengineTokenFunc != nil && !appengineVM { |
|||
return AppEngineTokenSource(ctx, scope...), nil |
|||
} |
|||
|
|||
// Fourth, if we're on Google Compute Engine use the metadata server.
|
|||
if metadata.OnGCE() { |
|||
return ComputeTokenSource(""), nil |
|||
} |
|||
|
|||
// None are found; return helpful error.
|
|||
const url = "https://developers.google.com/accounts/docs/application-default-credentials" |
|||
return nil, fmt.Errorf("google: could not find default credentials. See %v for more information.", url) |
|||
} |
|||
|
|||
func wellKnownFile() string { |
|||
const f = "application_default_credentials.json" |
|||
if runtime.GOOS == "windows" { |
|||
return filepath.Join(os.Getenv("APPDATA"), "gcloud", f) |
|||
} |
|||
return filepath.Join(guessUnixHomeDir(), ".config", "gcloud", f) |
|||
} |
|||
|
|||
func tokenSourceFromFile(ctx context.Context, filename string, scopes []string) (oauth2.TokenSource, error) { |
|||
b, err := ioutil.ReadFile(filename) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
var d struct { |
|||
// Common fields
|
|||
Type string |
|||
ClientID string `json:"client_id"` |
|||
|
|||
// User Credential fields
|
|||
ClientSecret string `json:"client_secret"` |
|||
RefreshToken string `json:"refresh_token"` |
|||
|
|||
// Service Account fields
|
|||
ClientEmail string `json:"client_email"` |
|||
PrivateKeyID string `json:"private_key_id"` |
|||
PrivateKey string `json:"private_key"` |
|||
} |
|||
if err := json.Unmarshal(b, &d); err != nil { |
|||
return nil, err |
|||
} |
|||
switch d.Type { |
|||
case "authorized_user": |
|||
cfg := &oauth2.Config{ |
|||
ClientID: d.ClientID, |
|||
ClientSecret: d.ClientSecret, |
|||
Scopes: append([]string{}, scopes...), // copy
|
|||
Endpoint: Endpoint, |
|||
} |
|||
tok := &oauth2.Token{RefreshToken: d.RefreshToken} |
|||
return cfg.TokenSource(ctx, tok), nil |
|||
case "service_account": |
|||
cfg := &jwt.Config{ |
|||
Email: d.ClientEmail, |
|||
PrivateKey: []byte(d.PrivateKey), |
|||
Scopes: append([]string{}, scopes...), // copy
|
|||
TokenURL: JWTTokenURL, |
|||
} |
|||
return cfg.TokenSource(ctx), nil |
|||
case "": |
|||
return nil, errors.New("missing 'type' field in credentials") |
|||
default: |
|||
return nil, fmt.Errorf("unknown credential type: %q", d.Type) |
|||
} |
|||
} |
@ -0,0 +1,150 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// +build appenginevm appengine
|
|||
|
|||
package google_test |
|||
|
|||
import ( |
|||
"fmt" |
|||
"io/ioutil" |
|||
"log" |
|||
"net/http" |
|||
|
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/google" |
|||
"golang.org/x/oauth2/jwt" |
|||
"google.golang.org/appengine" |
|||
"google.golang.org/appengine/urlfetch" |
|||
) |
|||
|
|||
func ExampleDefaultClient() { |
|||
client, err := google.DefaultClient(oauth2.NoContext, |
|||
"https://www.googleapis.com/auth/devstorage.full_control") |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
client.Get("...") |
|||
} |
|||
|
|||
func Example_webServer() { |
|||
// Your credentials should be obtained from the Google
|
|||
// Developer Console (https://console.developers.google.com).
|
|||
conf := &oauth2.Config{ |
|||
ClientID: "YOUR_CLIENT_ID", |
|||
ClientSecret: "YOUR_CLIENT_SECRET", |
|||
RedirectURL: "YOUR_REDIRECT_URL", |
|||
Scopes: []string{ |
|||
"https://www.googleapis.com/auth/bigquery", |
|||
"https://www.googleapis.com/auth/blogger", |
|||
}, |
|||
Endpoint: google.Endpoint, |
|||
} |
|||
// Redirect user to Google's consent page to ask for permission
|
|||
// for the scopes specified above.
|
|||
url := conf.AuthCodeURL("state") |
|||
fmt.Printf("Visit the URL for the auth dialog: %v", url) |
|||
|
|||
// Handle the exchange code to initiate a transport.
|
|||
tok, err := conf.Exchange(oauth2.NoContext, "authorization-code") |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
client := conf.Client(oauth2.NoContext, tok) |
|||
client.Get("...") |
|||
} |
|||
|
|||
func ExampleJWTConfigFromJSON() { |
|||
// Your credentials should be obtained from the Google
|
|||
// Developer Console (https://console.developers.google.com).
|
|||
// Navigate to your project, then see the "Credentials" page
|
|||
// under "APIs & Auth".
|
|||
// To create a service account client, click "Create new Client ID",
|
|||
// select "Service Account", and click "Create Client ID". A JSON
|
|||
// key file will then be downloaded to your computer.
|
|||
data, err := ioutil.ReadFile("/path/to/your-project-key.json") |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
conf, err := google.JWTConfigFromJSON(data, "https://www.googleapis.com/auth/bigquery") |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
// Initiate an http.Client. The following GET request will be
|
|||
// authorized and authenticated on the behalf of
|
|||
// your service account.
|
|||
client := conf.Client(oauth2.NoContext) |
|||
client.Get("...") |
|||
} |
|||
|
|||
func ExampleSDKConfig() { |
|||
// The credentials will be obtained from the first account that
|
|||
// has been authorized with `gcloud auth login`.
|
|||
conf, err := google.NewSDKConfig("") |
|||
if err != nil { |
|||
log.Fatal(err) |
|||
} |
|||
// Initiate an http.Client. The following GET request will be
|
|||
// authorized and authenticated on the behalf of the SDK user.
|
|||
client := conf.Client(oauth2.NoContext) |
|||
client.Get("...") |
|||
} |
|||
|
|||
func Example_serviceAccount() { |
|||
// Your credentials should be obtained from the Google
|
|||
// Developer Console (https://console.developers.google.com).
|
|||
conf := &jwt.Config{ |
|||
Email: "xxx@developer.gserviceaccount.com", |
|||
// The contents of your RSA private key or your PEM file
|
|||
// that contains a private key.
|
|||
// If you have a p12 file instead, you
|
|||
// can use `openssl` to export the private key into a pem file.
|
|||
//
|
|||
// $ openssl pkcs12 -in key.p12 -passin pass:notasecret -out key.pem -nodes
|
|||
//
|
|||
// The field only supports PEM containers with no passphrase.
|
|||
// The openssl command will convert p12 keys to passphrase-less PEM containers.
|
|||
PrivateKey: []byte("-----BEGIN RSA PRIVATE KEY-----..."), |
|||
Scopes: []string{ |
|||
"https://www.googleapis.com/auth/bigquery", |
|||
"https://www.googleapis.com/auth/blogger", |
|||
}, |
|||
TokenURL: google.JWTTokenURL, |
|||
// If you would like to impersonate a user, you can
|
|||
// create a transport with a subject. The following GET
|
|||
// request will be made on the behalf of user@example.com.
|
|||
// Optional.
|
|||
Subject: "user@example.com", |
|||
} |
|||
// Initiate an http.Client, the following GET request will be
|
|||
// authorized and authenticated on the behalf of user@example.com.
|
|||
client := conf.Client(oauth2.NoContext) |
|||
client.Get("...") |
|||
} |
|||
|
|||
func ExampleAppEngineTokenSource() { |
|||
var req *http.Request // from the ServeHTTP handler
|
|||
ctx := appengine.NewContext(req) |
|||
client := &http.Client{ |
|||
Transport: &oauth2.Transport{ |
|||
Source: google.AppEngineTokenSource(ctx, "https://www.googleapis.com/auth/bigquery"), |
|||
Base: &urlfetch.Transport{ |
|||
Context: ctx, |
|||
}, |
|||
}, |
|||
} |
|||
client.Get("...") |
|||
} |
|||
|
|||
func ExampleComputeTokenSource() { |
|||
client := &http.Client{ |
|||
Transport: &oauth2.Transport{ |
|||
// Fetch from Google Compute Engine's metadata server to retrieve
|
|||
// an access token for the provided account.
|
|||
// If no account is specified, "default" is used.
|
|||
Source: google.ComputeTokenSource(""), |
|||
}, |
|||
} |
|||
client.Get("...") |
|||
} |
@ -0,0 +1,149 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package google provides support for making OAuth2 authorized and
|
|||
// authenticated HTTP requests to Google APIs.
|
|||
// It supports the Web server flow, client-side credentials, service accounts,
|
|||
// Google Compute Engine service accounts, and Google App Engine service
|
|||
// accounts.
|
|||
//
|
|||
// For more information, please read
|
|||
// https://developers.google.com/accounts/docs/OAuth2
|
|||
// and
|
|||
// https://developers.google.com/accounts/docs/application-default-credentials.
|
|||
package google // import "golang.org/x/oauth2/google"
|
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"errors" |
|||
"fmt" |
|||
"strings" |
|||
"time" |
|||
|
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/jwt" |
|||
"google.golang.org/cloud/compute/metadata" |
|||
) |
|||
|
|||
// Endpoint is Google's OAuth 2.0 endpoint.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://accounts.google.com/o/oauth2/auth", |
|||
TokenURL: "https://accounts.google.com/o/oauth2/token", |
|||
} |
|||
|
|||
// JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow.
|
|||
const JWTTokenURL = "https://accounts.google.com/o/oauth2/token" |
|||
|
|||
// ConfigFromJSON uses a Google Developers Console client_credentials.json
|
|||
// file to construct a config.
|
|||
// client_credentials.json can be downloaded from
|
|||
// https://console.developers.google.com, under "Credentials". Download the Web
|
|||
// application credentials in the JSON format and provide the contents of the
|
|||
// file as jsonKey.
|
|||
func ConfigFromJSON(jsonKey []byte, scope ...string) (*oauth2.Config, error) { |
|||
type cred struct { |
|||
ClientID string `json:"client_id"` |
|||
ClientSecret string `json:"client_secret"` |
|||
RedirectURIs []string `json:"redirect_uris"` |
|||
AuthURI string `json:"auth_uri"` |
|||
TokenURI string `json:"token_uri"` |
|||
} |
|||
var j struct { |
|||
Web *cred `json:"web"` |
|||
Installed *cred `json:"installed"` |
|||
} |
|||
if err := json.Unmarshal(jsonKey, &j); err != nil { |
|||
return nil, err |
|||
} |
|||
var c *cred |
|||
switch { |
|||
case j.Web != nil: |
|||
c = j.Web |
|||
case j.Installed != nil: |
|||
c = j.Installed |
|||
default: |
|||
return nil, fmt.Errorf("oauth2/google: no credentials found") |
|||
} |
|||
if len(c.RedirectURIs) < 1 { |
|||
return nil, errors.New("oauth2/google: missing redirect URL in the client_credentials.json") |
|||
} |
|||
return &oauth2.Config{ |
|||
ClientID: c.ClientID, |
|||
ClientSecret: c.ClientSecret, |
|||
RedirectURL: c.RedirectURIs[0], |
|||
Scopes: scope, |
|||
Endpoint: oauth2.Endpoint{ |
|||
AuthURL: c.AuthURI, |
|||
TokenURL: c.TokenURI, |
|||
}, |
|||
}, nil |
|||
} |
|||
|
|||
// JWTConfigFromJSON uses a Google Developers service account JSON key file to read
|
|||
// the credentials that authorize and authenticate the requests.
|
|||
// Create a service account on "Credentials" for your project at
|
|||
// https://console.developers.google.com to download a JSON key file.
|
|||
func JWTConfigFromJSON(jsonKey []byte, scope ...string) (*jwt.Config, error) { |
|||
var key struct { |
|||
Email string `json:"client_email"` |
|||
PrivateKey string `json:"private_key"` |
|||
PrivateKeyID string `json:"private_key_id"` |
|||
} |
|||
if err := json.Unmarshal(jsonKey, &key); err != nil { |
|||
return nil, err |
|||
} |
|||
config := &jwt.Config{ |
|||
Email: key.Email, |
|||
PrivateKey: []byte(key.PrivateKey), |
|||
PrivateKeyID: key.PrivateKeyID, |
|||
Scopes: scope, |
|||
TokenURL: JWTTokenURL, |
|||
} |
|||
return config, nil |
|||
} |
|||
|
|||
// ComputeTokenSource returns a token source that fetches access tokens
|
|||
// from Google Compute Engine (GCE)'s metadata server. It's only valid to use
|
|||
// this token source if your program is running on a GCE instance.
|
|||
// If no account is specified, "default" is used.
|
|||
// Further information about retrieving access tokens from the GCE metadata
|
|||
// server can be found at https://cloud.google.com/compute/docs/authentication.
|
|||
func ComputeTokenSource(account string) oauth2.TokenSource { |
|||
return oauth2.ReuseTokenSource(nil, computeSource{account: account}) |
|||
} |
|||
|
|||
type computeSource struct { |
|||
account string |
|||
} |
|||
|
|||
func (cs computeSource) Token() (*oauth2.Token, error) { |
|||
if !metadata.OnGCE() { |
|||
return nil, errors.New("oauth2/google: can't get a token from the metadata service; not running on GCE") |
|||
} |
|||
acct := cs.account |
|||
if acct == "" { |
|||
acct = "default" |
|||
} |
|||
tokenJSON, err := metadata.Get("instance/service-accounts/" + acct + "/token") |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
var res struct { |
|||
AccessToken string `json:"access_token"` |
|||
ExpiresInSec int `json:"expires_in"` |
|||
TokenType string `json:"token_type"` |
|||
} |
|||
err = json.NewDecoder(strings.NewReader(tokenJSON)).Decode(&res) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("oauth2/google: invalid token JSON from metadata: %v", err) |
|||
} |
|||
if res.ExpiresInSec == 0 || res.AccessToken == "" { |
|||
return nil, fmt.Errorf("oauth2/google: incomplete token received from metadata") |
|||
} |
|||
return &oauth2.Token{ |
|||
AccessToken: res.AccessToken, |
|||
TokenType: res.TokenType, |
|||
Expiry: time.Now().Add(time.Duration(res.ExpiresInSec) * time.Second), |
|||
}, nil |
|||
} |
@ -0,0 +1,97 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package google |
|||
|
|||
import ( |
|||
"strings" |
|||
"testing" |
|||
) |
|||
|
|||
var webJSONKey = []byte(` |
|||
{ |
|||
"web": { |
|||
"auth_uri": "https://google.com/o/oauth2/auth", |
|||
"client_secret": "3Oknc4jS_wA2r9i", |
|||
"token_uri": "https://google.com/o/oauth2/token", |
|||
"client_email": "222-nprqovg5k43uum874cs9osjt2koe97g8@developer.gserviceaccount.com", |
|||
"redirect_uris": ["https://www.example.com/oauth2callback"], |
|||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/222-nprqovg5k43uum874cs9osjt2koe97g8@developer.gserviceaccount.com", |
|||
"client_id": "222-nprqovg5k43uum874cs9osjt2koe97g8.apps.googleusercontent.com", |
|||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", |
|||
"javascript_origins": ["https://www.example.com"] |
|||
} |
|||
}`) |
|||
|
|||
var installedJSONKey = []byte(`{ |
|||
"installed": { |
|||
"client_id": "222-installed.apps.googleusercontent.com", |
|||
"redirect_uris": ["https://www.example.com/oauth2callback"] |
|||
} |
|||
}`) |
|||
|
|||
var jwtJSONKey = []byte(`{ |
|||
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b", |
|||
"private_key": "super secret key", |
|||
"client_email": "gopher@developer.gserviceaccount.com", |
|||
"client_id": "gopher.apps.googleusercontent.com", |
|||
"type": "service_account" |
|||
}`) |
|||
|
|||
func TestConfigFromJSON(t *testing.T) { |
|||
conf, err := ConfigFromJSON(webJSONKey, "scope1", "scope2") |
|||
if err != nil { |
|||
t.Error(err) |
|||
} |
|||
if got, want := conf.ClientID, "222-nprqovg5k43uum874cs9osjt2koe97g8.apps.googleusercontent.com"; got != want { |
|||
t.Errorf("ClientID = %q; want %q", got, want) |
|||
} |
|||
if got, want := conf.ClientSecret, "3Oknc4jS_wA2r9i"; got != want { |
|||
t.Errorf("ClientSecret = %q; want %q", got, want) |
|||
} |
|||
if got, want := conf.RedirectURL, "https://www.example.com/oauth2callback"; got != want { |
|||
t.Errorf("RedictURL = %q; want %q", got, want) |
|||
} |
|||
if got, want := strings.Join(conf.Scopes, ","), "scope1,scope2"; got != want { |
|||
t.Errorf("Scopes = %q; want %q", got, want) |
|||
} |
|||
if got, want := conf.Endpoint.AuthURL, "https://google.com/o/oauth2/auth"; got != want { |
|||
t.Errorf("AuthURL = %q; want %q", got, want) |
|||
} |
|||
if got, want := conf.Endpoint.TokenURL, "https://google.com/o/oauth2/token"; got != want { |
|||
t.Errorf("TokenURL = %q; want %q", got, want) |
|||
} |
|||
} |
|||
|
|||
func TestConfigFromJSON_Installed(t *testing.T) { |
|||
conf, err := ConfigFromJSON(installedJSONKey) |
|||
if err != nil { |
|||
t.Error(err) |
|||
} |
|||
if got, want := conf.ClientID, "222-installed.apps.googleusercontent.com"; got != want { |
|||
t.Errorf("ClientID = %q; want %q", got, want) |
|||
} |
|||
} |
|||
|
|||
func TestJWTConfigFromJSON(t *testing.T) { |
|||
conf, err := JWTConfigFromJSON(jwtJSONKey, "scope1", "scope2") |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if got, want := conf.Email, "gopher@developer.gserviceaccount.com"; got != want { |
|||
t.Errorf("Email = %q, want %q", got, want) |
|||
} |
|||
if got, want := string(conf.PrivateKey), "super secret key"; got != want { |
|||
t.Errorf("PrivateKey = %q, want %q", got, want) |
|||
} |
|||
if got, want := conf.PrivateKeyID, "268f54e43a1af97cfc71731688434f45aca15c8b"; got != want { |
|||
t.Errorf("PrivateKeyID = %q, want %q", got, want) |
|||
} |
|||
if got, want := strings.Join(conf.Scopes, ","), "scope1,scope2"; got != want { |
|||
t.Errorf("Scopes = %q; want %q", got, want) |
|||
} |
|||
if got, want := conf.TokenURL, "https://accounts.google.com/o/oauth2/token"; got != want { |
|||
t.Errorf("TokenURL = %q; want %q", got, want) |
|||
} |
|||
} |
@ -0,0 +1,74 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package google |
|||
|
|||
import ( |
|||
"crypto/rsa" |
|||
"fmt" |
|||
"time" |
|||
|
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/internal" |
|||
"golang.org/x/oauth2/jws" |
|||
) |
|||
|
|||
// JWTAccessTokenSourceFromJSON uses a Google Developers service account JSON
|
|||
// key file to read the credentials that authorize and authenticate the
|
|||
// requests, and returns a TokenSource that does not use any OAuth2 flow but
|
|||
// instead creates a JWT and sends that as the access token.
|
|||
// The audience is typically a URL that specifies the scope of the credentials.
|
|||
//
|
|||
// Note that this is not a standard OAuth flow, but rather an
|
|||
// optimization supported by a few Google services.
|
|||
// Unless you know otherwise, you should use JWTConfigFromJSON instead.
|
|||
func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.TokenSource, error) { |
|||
cfg, err := JWTConfigFromJSON(jsonKey) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("google: could not parse JSON key: %v", err) |
|||
} |
|||
pk, err := internal.ParseKey(cfg.PrivateKey) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("google: could not parse key: %v", err) |
|||
} |
|||
ts := &jwtAccessTokenSource{ |
|||
email: cfg.Email, |
|||
audience: audience, |
|||
pk: pk, |
|||
pkID: cfg.PrivateKeyID, |
|||
} |
|||
tok, err := ts.Token() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return oauth2.ReuseTokenSource(tok, ts), nil |
|||
} |
|||
|
|||
type jwtAccessTokenSource struct { |
|||
email, audience string |
|||
pk *rsa.PrivateKey |
|||
pkID string |
|||
} |
|||
|
|||
func (ts *jwtAccessTokenSource) Token() (*oauth2.Token, error) { |
|||
iat := time.Now() |
|||
exp := iat.Add(time.Hour) |
|||
cs := &jws.ClaimSet{ |
|||
Iss: ts.email, |
|||
Sub: ts.email, |
|||
Aud: ts.audience, |
|||
Iat: iat.Unix(), |
|||
Exp: exp.Unix(), |
|||
} |
|||
hdr := &jws.Header{ |
|||
Algorithm: "RS256", |
|||
Typ: "JWT", |
|||
KeyID: string(ts.pkID), |
|||
} |
|||
msg, err := jws.Encode(hdr, cs, ts.pk) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("google: could not encode JWT: %v", err) |
|||
} |
|||
return &oauth2.Token{AccessToken: msg, TokenType: "Bearer", Expiry: exp}, nil |
|||
} |
@ -0,0 +1,91 @@ |
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package google |
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto/rand" |
|||
"crypto/rsa" |
|||
"crypto/x509" |
|||
"encoding/base64" |
|||
"encoding/json" |
|||
"encoding/pem" |
|||
"strings" |
|||
"testing" |
|||
"time" |
|||
|
|||
"golang.org/x/oauth2/jws" |
|||
) |
|||
|
|||
func TestJWTAccessTokenSourceFromJSON(t *testing.T) { |
|||
// Generate a key we can use in the test data.
|
|||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
// Encode the key and substitute into our example JSON.
|
|||
enc := pem.EncodeToMemory(&pem.Block{ |
|||
Type: "PRIVATE KEY", |
|||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey), |
|||
}) |
|||
enc, err = json.Marshal(string(enc)) |
|||
if err != nil { |
|||
t.Fatalf("json.Marshal: %v", err) |
|||
} |
|||
jsonKey := bytes.Replace(jwtJSONKey, []byte(`"super secret key"`), enc, 1) |
|||
|
|||
ts, err := JWTAccessTokenSourceFromJSON(jsonKey, "audience") |
|||
if err != nil { |
|||
t.Fatalf("JWTAccessTokenSourceFromJSON: %v\nJSON: %s", err, string(jsonKey)) |
|||
} |
|||
|
|||
tok, err := ts.Token() |
|||
if err != nil { |
|||
t.Fatalf("Token: %v", err) |
|||
} |
|||
|
|||
if got, want := tok.TokenType, "Bearer"; got != want { |
|||
t.Errorf("TokenType = %q, want %q", got, want) |
|||
} |
|||
if got := tok.Expiry; tok.Expiry.Before(time.Now()) { |
|||
t.Errorf("Expiry = %v, should not be expired", got) |
|||
} |
|||
|
|||
err = jws.Verify(tok.AccessToken, &privateKey.PublicKey) |
|||
if err != nil { |
|||
t.Errorf("jws.Verify on AccessToken: %v", err) |
|||
} |
|||
|
|||
claim, err := jws.Decode(tok.AccessToken) |
|||
if err != nil { |
|||
t.Fatalf("jws.Decode on AccessToken: %v", err) |
|||
} |
|||
|
|||
if got, want := claim.Iss, "gopher@developer.gserviceaccount.com"; got != want { |
|||
t.Errorf("Iss = %q, want %q", got, want) |
|||
} |
|||
if got, want := claim.Sub, "gopher@developer.gserviceaccount.com"; got != want { |
|||
t.Errorf("Sub = %q, want %q", got, want) |
|||
} |
|||
if got, want := claim.Aud, "audience"; got != want { |
|||
t.Errorf("Aud = %q, want %q", got, want) |
|||
} |
|||
|
|||
// Finally, check the header private key.
|
|||
parts := strings.Split(tok.AccessToken, ".") |
|||
hdrJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) |
|||
if err != nil { |
|||
t.Fatalf("base64 DecodeString: %v\nString: %q", err, parts[0]) |
|||
} |
|||
var hdr jws.Header |
|||
if err := json.Unmarshal([]byte(hdrJSON), &hdr); err != nil { |
|||
t.Fatalf("json.Unmarshal: %v (%q)", err, hdrJSON) |
|||
} |
|||
|
|||
if got, want := hdr.KeyID, "268f54e43a1af97cfc71731688434f45aca15c8b"; got != want { |
|||
t.Errorf("Header KeyID = %q, want %q", got, want) |
|||
} |
|||
} |
@ -0,0 +1,168 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package google |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"errors" |
|||
"fmt" |
|||
"net/http" |
|||
"os" |
|||
"os/user" |
|||
"path/filepath" |
|||
"runtime" |
|||
"strings" |
|||
"time" |
|||
|
|||
"golang.org/x/net/context" |
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/internal" |
|||
) |
|||
|
|||
type sdkCredentials struct { |
|||
Data []struct { |
|||
Credential struct { |
|||
ClientID string `json:"client_id"` |
|||
ClientSecret string `json:"client_secret"` |
|||
AccessToken string `json:"access_token"` |
|||
RefreshToken string `json:"refresh_token"` |
|||
TokenExpiry *time.Time `json:"token_expiry"` |
|||
} `json:"credential"` |
|||
Key struct { |
|||
Account string `json:"account"` |
|||
Scope string `json:"scope"` |
|||
} `json:"key"` |
|||
} |
|||
} |
|||
|
|||
// An SDKConfig provides access to tokens from an account already
|
|||
// authorized via the Google Cloud SDK.
|
|||
type SDKConfig struct { |
|||
conf oauth2.Config |
|||
initialToken *oauth2.Token |
|||
} |
|||
|
|||
// NewSDKConfig creates an SDKConfig for the given Google Cloud SDK
|
|||
// account. If account is empty, the account currently active in
|
|||
// Google Cloud SDK properties is used.
|
|||
// Google Cloud SDK credentials must be created by running `gcloud auth`
|
|||
// before using this function.
|
|||
// The Google Cloud SDK is available at https://cloud.google.com/sdk/.
|
|||
func NewSDKConfig(account string) (*SDKConfig, error) { |
|||
configPath, err := sdkConfigPath() |
|||
if err != nil { |
|||
return nil, fmt.Errorf("oauth2/google: error getting SDK config path: %v", err) |
|||
} |
|||
credentialsPath := filepath.Join(configPath, "credentials") |
|||
f, err := os.Open(credentialsPath) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("oauth2/google: failed to load SDK credentials: %v", err) |
|||
} |
|||
defer f.Close() |
|||
|
|||
var c sdkCredentials |
|||
if err := json.NewDecoder(f).Decode(&c); err != nil { |
|||
return nil, fmt.Errorf("oauth2/google: failed to decode SDK credentials from %q: %v", credentialsPath, err) |
|||
} |
|||
if len(c.Data) == 0 { |
|||
return nil, fmt.Errorf("oauth2/google: no credentials found in %q, run `gcloud auth login` to create one", credentialsPath) |
|||
} |
|||
if account == "" { |
|||
propertiesPath := filepath.Join(configPath, "properties") |
|||
f, err := os.Open(propertiesPath) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("oauth2/google: failed to load SDK properties: %v", err) |
|||
} |
|||
defer f.Close() |
|||
ini, err := internal.ParseINI(f) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("oauth2/google: failed to parse SDK properties %q: %v", propertiesPath, err) |
|||
} |
|||
core, ok := ini["core"] |
|||
if !ok { |
|||
return nil, fmt.Errorf("oauth2/google: failed to find [core] section in %v", ini) |
|||
} |
|||
active, ok := core["account"] |
|||
if !ok { |
|||
return nil, fmt.Errorf("oauth2/google: failed to find %q attribute in %v", "account", core) |
|||
} |
|||
account = active |
|||
} |
|||
|
|||
for _, d := range c.Data { |
|||
if account == "" || d.Key.Account == account { |
|||
if d.Credential.AccessToken == "" && d.Credential.RefreshToken == "" { |
|||
return nil, fmt.Errorf("oauth2/google: no token available for account %q", account) |
|||
} |
|||
var expiry time.Time |
|||
if d.Credential.TokenExpiry != nil { |
|||
expiry = *d.Credential.TokenExpiry |
|||
} |
|||
return &SDKConfig{ |
|||
conf: oauth2.Config{ |
|||
ClientID: d.Credential.ClientID, |
|||
ClientSecret: d.Credential.ClientSecret, |
|||
Scopes: strings.Split(d.Key.Scope, " "), |
|||
Endpoint: Endpoint, |
|||
RedirectURL: "oob", |
|||
}, |
|||
initialToken: &oauth2.Token{ |
|||
AccessToken: d.Credential.AccessToken, |
|||
RefreshToken: d.Credential.RefreshToken, |
|||
Expiry: expiry, |
|||
}, |
|||
}, nil |
|||
} |
|||
} |
|||
return nil, fmt.Errorf("oauth2/google: no such credentials for account %q", account) |
|||
} |
|||
|
|||
// Client returns an HTTP client using Google Cloud SDK credentials to
|
|||
// authorize requests. The token will auto-refresh as necessary. The
|
|||
// underlying http.RoundTripper will be obtained using the provided
|
|||
// context. The returned client and its Transport should not be
|
|||
// modified.
|
|||
func (c *SDKConfig) Client(ctx context.Context) *http.Client { |
|||
return &http.Client{ |
|||
Transport: &oauth2.Transport{ |
|||
Source: c.TokenSource(ctx), |
|||
}, |
|||
} |
|||
} |
|||
|
|||
// TokenSource returns an oauth2.TokenSource that retrieve tokens from
|
|||
// Google Cloud SDK credentials using the provided context.
|
|||
// It will returns the current access token stored in the credentials,
|
|||
// and refresh it when it expires, but it won't update the credentials
|
|||
// with the new access token.
|
|||
func (c *SDKConfig) TokenSource(ctx context.Context) oauth2.TokenSource { |
|||
return c.conf.TokenSource(ctx, c.initialToken) |
|||
} |
|||
|
|||
// Scopes are the OAuth 2.0 scopes the current account is authorized for.
|
|||
func (c *SDKConfig) Scopes() []string { |
|||
return c.conf.Scopes |
|||
} |
|||
|
|||
// sdkConfigPath tries to guess where the gcloud config is located.
|
|||
// It can be overridden during tests.
|
|||
var sdkConfigPath = func() (string, error) { |
|||
if runtime.GOOS == "windows" { |
|||
return filepath.Join(os.Getenv("APPDATA"), "gcloud"), nil |
|||
} |
|||
homeDir := guessUnixHomeDir() |
|||
if homeDir == "" { |
|||
return "", errors.New("unable to get current user home directory: os/user lookup failed; $HOME is empty") |
|||
} |
|||
return filepath.Join(homeDir, ".config", "gcloud"), nil |
|||
} |
|||
|
|||
func guessUnixHomeDir() string { |
|||
usr, err := user.Current() |
|||
if err == nil { |
|||
return usr.HomeDir |
|||
} |
|||
return os.Getenv("HOME") |
|||
} |
@ -0,0 +1,46 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package google |
|||
|
|||
import "testing" |
|||
|
|||
func TestSDKConfig(t *testing.T) { |
|||
sdkConfigPath = func() (string, error) { |
|||
return "testdata/gcloud", nil |
|||
} |
|||
|
|||
tests := []struct { |
|||
account string |
|||
accessToken string |
|||
err bool |
|||
}{ |
|||
{"", "bar_access_token", false}, |
|||
{"foo@example.com", "foo_access_token", false}, |
|||
{"bar@example.com", "bar_access_token", false}, |
|||
{"baz@serviceaccount.example.com", "", true}, |
|||
} |
|||
for _, tt := range tests { |
|||
c, err := NewSDKConfig(tt.account) |
|||
if got, want := err != nil, tt.err; got != want { |
|||
if !tt.err { |
|||
t.Errorf("got %v, want nil", err) |
|||
} else { |
|||
t.Errorf("got nil, want error") |
|||
} |
|||
continue |
|||
} |
|||
if err != nil { |
|||
continue |
|||
} |
|||
tok := c.initialToken |
|||
if tok == nil { |
|||
t.Errorf("got nil, want %q", tt.accessToken) |
|||
continue |
|||
} |
|||
if tok.AccessToken != tt.accessToken { |
|||
t.Errorf("got %q, want %q", tok.AccessToken, tt.accessToken) |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,122 @@ |
|||
{ |
|||
"data": [ |
|||
{ |
|||
"credential": { |
|||
"_class": "OAuth2Credentials", |
|||
"_module": "oauth2client.client", |
|||
"access_token": "foo_access_token", |
|||
"client_id": "foo_client_id", |
|||
"client_secret": "foo_client_secret", |
|||
"id_token": { |
|||
"at_hash": "foo_at_hash", |
|||
"aud": "foo_aud", |
|||
"azp": "foo_azp", |
|||
"cid": "foo_cid", |
|||
"email": "foo@example.com", |
|||
"email_verified": true, |
|||
"exp": 1420573614, |
|||
"iat": 1420569714, |
|||
"id": "1337", |
|||
"iss": "accounts.google.com", |
|||
"sub": "1337", |
|||
"token_hash": "foo_token_hash", |
|||
"verified_email": true |
|||
}, |
|||
"invalid": false, |
|||
"refresh_token": "foo_refresh_token", |
|||
"revoke_uri": "https://accounts.google.com/o/oauth2/revoke", |
|||
"token_expiry": "2015-01-09T00:51:51Z", |
|||
"token_response": { |
|||
"access_token": "foo_access_token", |
|||
"expires_in": 3600, |
|||
"id_token": "foo_id_token", |
|||
"token_type": "Bearer" |
|||
}, |
|||
"token_uri": "https://accounts.google.com/o/oauth2/token", |
|||
"user_agent": "Cloud SDK Command Line Tool" |
|||
}, |
|||
"key": { |
|||
"account": "foo@example.com", |
|||
"clientId": "foo_client_id", |
|||
"scope": "https://www.googleapis.com/auth/appengine.admin https://www.googleapis.com/auth/bigquery https://www.googleapis.com/auth/compute https://www.googleapis.com/auth/devstorage.full_control https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/ndev.cloudman https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/sqlservice.admin https://www.googleapis.com/auth/prediction https://www.googleapis.com/auth/projecthosting", |
|||
"type": "google-cloud-sdk" |
|||
} |
|||
}, |
|||
{ |
|||
"credential": { |
|||
"_class": "OAuth2Credentials", |
|||
"_module": "oauth2client.client", |
|||
"access_token": "bar_access_token", |
|||
"client_id": "bar_client_id", |
|||
"client_secret": "bar_client_secret", |
|||
"id_token": { |
|||
"at_hash": "bar_at_hash", |
|||
"aud": "bar_aud", |
|||
"azp": "bar_azp", |
|||
"cid": "bar_cid", |
|||
"email": "bar@example.com", |
|||
"email_verified": true, |
|||
"exp": 1420573614, |
|||
"iat": 1420569714, |
|||
"id": "1337", |
|||
"iss": "accounts.google.com", |
|||
"sub": "1337", |
|||
"token_hash": "bar_token_hash", |
|||
"verified_email": true |
|||
}, |
|||
"invalid": false, |
|||
"refresh_token": "bar_refresh_token", |
|||
"revoke_uri": "https://accounts.google.com/o/oauth2/revoke", |
|||
"token_expiry": "2015-01-09T00:51:51Z", |
|||
"token_response": { |
|||
"access_token": "bar_access_token", |
|||
"expires_in": 3600, |
|||
"id_token": "bar_id_token", |
|||
"token_type": "Bearer" |
|||
}, |
|||
"token_uri": "https://accounts.google.com/o/oauth2/token", |
|||
"user_agent": "Cloud SDK Command Line Tool" |
|||
}, |
|||
"key": { |
|||
"account": "bar@example.com", |
|||
"clientId": "bar_client_id", |
|||
"scope": "https://www.googleapis.com/auth/appengine.admin https://www.googleapis.com/auth/bigquery https://www.googleapis.com/auth/compute https://www.googleapis.com/auth/devstorage.full_control https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/ndev.cloudman https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/sqlservice.admin https://www.googleapis.com/auth/prediction https://www.googleapis.com/auth/projecthosting", |
|||
"type": "google-cloud-sdk" |
|||
} |
|||
}, |
|||
{ |
|||
"credential": { |
|||
"_class": "ServiceAccountCredentials", |
|||
"_kwargs": {}, |
|||
"_module": "oauth2client.client", |
|||
"_private_key_id": "00000000000000000000000000000000", |
|||
"_private_key_pkcs8_text": "-----BEGIN RSA PRIVATE KEY-----\nMIICWwIBAAKBgQCt3fpiynPSaUhWSIKMGV331zudwJ6GkGmvQtwsoK2S2LbvnSwU\nNxgj4fp08kIDR5p26wF4+t/HrKydMwzftXBfZ9UmLVJgRdSswmS5SmChCrfDS5OE\nvFFcN5+6w1w8/Nu657PF/dse8T0bV95YrqyoR0Osy8WHrUOMSIIbC3hRuwIDAQAB\nAoGAJrGE/KFjn0sQ7yrZ6sXmdLawrM3mObo/2uI9T60+k7SpGbBX0/Pi6nFrJMWZ\nTVONG7P3Mu5aCPzzuVRYJB0j8aldSfzABTY3HKoWCczqw1OztJiEseXGiYz4QOyr\nYU3qDyEpdhS6q6wcoLKGH+hqRmz6pcSEsc8XzOOu7s4xW8kCQQDkc75HjhbarCnd\nJJGMe3U76+6UGmdK67ltZj6k6xoB5WbTNChY9TAyI2JC+ppYV89zv3ssj4L+02u3\nHIHFGxsHAkEAwtU1qYb1tScpchPobnYUFiVKJ7KA8EZaHVaJJODW/cghTCV7BxcJ\nbgVvlmk4lFKn3lPKAgWw7PdQsBTVBUcCrQJATPwoIirizrv3u5soJUQxZIkENAqV\nxmybZx9uetrzP7JTrVbFRf0SScMcyN90hdLJiQL8+i4+gaszgFht7sNMnwJAAbfj\nq0UXcauQwALQ7/h2oONfTg5S+MuGC/AxcXPSMZbMRGGoPh3D5YaCv27aIuS/ukQ+\n6dmm/9AGlCb64fsIWQJAPaokbjIifo+LwC5gyK73Mc4t8nAOSZDenzd/2f6TCq76\nS1dcnKiPxaED7W/y6LJiuBT2rbZiQ2L93NJpFZD/UA==\n-----END RSA PRIVATE KEY-----\n", |
|||
"_revoke_uri": "https://accounts.google.com/o/oauth2/revoke", |
|||
"_scopes": "https://www.googleapis.com/auth/appengine.admin https://www.googleapis.com/auth/bigquery https://www.googleapis.com/auth/compute https://www.googleapis.com/auth/devstorage.full_control https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/ndev.cloudman https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/sqlservice.admin https://www.googleapis.com/auth/prediction https://www.googleapis.com/auth/projecthosting", |
|||
"_service_account_email": "baz@serviceaccount.example.com", |
|||
"_service_account_id": "baz.serviceaccount.example.com", |
|||
"_token_uri": "https://accounts.google.com/o/oauth2/token", |
|||
"_user_agent": "Cloud SDK Command Line Tool", |
|||
"access_token": null, |
|||
"assertion_type": null, |
|||
"client_id": null, |
|||
"client_secret": null, |
|||
"id_token": null, |
|||
"invalid": false, |
|||
"refresh_token": null, |
|||
"revoke_uri": "https://accounts.google.com/o/oauth2/revoke", |
|||
"service_account_name": "baz@serviceaccount.example.com", |
|||
"token_expiry": null, |
|||
"token_response": null, |
|||
"user_agent": "Cloud SDK Command Line Tool" |
|||
}, |
|||
"key": { |
|||
"account": "baz@serviceaccount.example.com", |
|||
"clientId": "baz_client_id", |
|||
"scope": "https://www.googleapis.com/auth/appengine.admin https://www.googleapis.com/auth/bigquery https://www.googleapis.com/auth/compute https://www.googleapis.com/auth/devstorage.full_control https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/ndev.cloudman https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/sqlservice.admin https://www.googleapis.com/auth/prediction https://www.googleapis.com/auth/projecthosting", |
|||
"type": "google-cloud-sdk" |
|||
} |
|||
} |
|||
], |
|||
"file_version": 1 |
|||
} |
@ -0,0 +1,2 @@ |
|||
[core] |
|||
account = bar@example.com |
@ -0,0 +1,60 @@ |
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package hipchat provides constants for using OAuth2 to access HipChat.
|
|||
package hipchat // import "golang.org/x/oauth2/hipchat"
|
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"errors" |
|||
|
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/clientcredentials" |
|||
) |
|||
|
|||
// Endpoint is HipChat's OAuth 2.0 endpoint.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://www.hipchat.com/users/authorize", |
|||
TokenURL: "https://api.hipchat.com/v2/oauth/token", |
|||
} |
|||
|
|||
// ServerEndpoint returns a new oauth2.Endpoint for a HipChat Server instance
|
|||
// running on the given domain or host.
|
|||
func ServerEndpoint(host string) oauth2.Endpoint { |
|||
return oauth2.Endpoint{ |
|||
AuthURL: "https://" + host + "/users/authorize", |
|||
TokenURL: "https://" + host + "/v2/oauth/token", |
|||
} |
|||
} |
|||
|
|||
// ClientCredentialsConfigFromCaps generates a Config from a HipChat API
|
|||
// capabilities descriptor. It does not verify the scopes against the
|
|||
// capabilities document at this time.
|
|||
//
|
|||
// For more information see: https://www.hipchat.com/docs/apiv2/method/get_capabilities
|
|||
func ClientCredentialsConfigFromCaps(capsJSON []byte, clientID, clientSecret string, scopes ...string) (*clientcredentials.Config, error) { |
|||
var caps struct { |
|||
Caps struct { |
|||
Endpoint struct { |
|||
TokenURL string `json:"tokenUrl"` |
|||
} `json:"oauth2Provider"` |
|||
} `json:"capabilities"` |
|||
} |
|||
|
|||
if err := json.Unmarshal(capsJSON, &caps); err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
// Verify required fields.
|
|||
if caps.Caps.Endpoint.TokenURL == "" { |
|||
return nil, errors.New("oauth2/hipchat: missing OAuth2 token URL in the capabilities descriptor JSON") |
|||
} |
|||
|
|||
return &clientcredentials.Config{ |
|||
ClientID: clientID, |
|||
ClientSecret: clientSecret, |
|||
Scopes: scopes, |
|||
TokenURL: caps.Caps.Endpoint.TokenURL, |
|||
}, nil |
|||
} |
@ -0,0 +1,76 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package internal contains support packages for oauth2 package.
|
|||
package internal |
|||
|
|||
import ( |
|||
"bufio" |
|||
"crypto/rsa" |
|||
"crypto/x509" |
|||
"encoding/pem" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"strings" |
|||
) |
|||
|
|||
// ParseKey converts the binary contents of a private key file
|
|||
// to an *rsa.PrivateKey. It detects whether the private key is in a
|
|||
// PEM container or not. If so, it extracts the the private key
|
|||
// from PEM container before conversion. It only supports PEM
|
|||
// containers with no passphrase.
|
|||
func ParseKey(key []byte) (*rsa.PrivateKey, error) { |
|||
block, _ := pem.Decode(key) |
|||
if block != nil { |
|||
key = block.Bytes |
|||
} |
|||
parsedKey, err := x509.ParsePKCS8PrivateKey(key) |
|||
if err != nil { |
|||
parsedKey, err = x509.ParsePKCS1PrivateKey(key) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err) |
|||
} |
|||
} |
|||
parsed, ok := parsedKey.(*rsa.PrivateKey) |
|||
if !ok { |
|||
return nil, errors.New("private key is invalid") |
|||
} |
|||
return parsed, nil |
|||
} |
|||
|
|||
func ParseINI(ini io.Reader) (map[string]map[string]string, error) { |
|||
result := map[string]map[string]string{ |
|||
"": map[string]string{}, // root section
|
|||
} |
|||
scanner := bufio.NewScanner(ini) |
|||
currentSection := "" |
|||
for scanner.Scan() { |
|||
line := strings.TrimSpace(scanner.Text()) |
|||
if strings.HasPrefix(line, ";") { |
|||
// comment.
|
|||
continue |
|||
} |
|||
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { |
|||
currentSection = strings.TrimSpace(line[1 : len(line)-1]) |
|||
result[currentSection] = map[string]string{} |
|||
continue |
|||
} |
|||
parts := strings.SplitN(line, "=", 2) |
|||
if len(parts) == 2 && parts[0] != "" { |
|||
result[currentSection][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) |
|||
} |
|||
} |
|||
if err := scanner.Err(); err != nil { |
|||
return nil, fmt.Errorf("error scanning ini: %v", err) |
|||
} |
|||
return result, nil |
|||
} |
|||
|
|||
func CondVal(v string) []string { |
|||
if v == "" { |
|||
return nil |
|||
} |
|||
return []string{v} |
|||
} |
@ -0,0 +1,62 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package internal contains support packages for oauth2 package.
|
|||
package internal |
|||
|
|||
import ( |
|||
"reflect" |
|||
"strings" |
|||
"testing" |
|||
) |
|||
|
|||
func TestParseINI(t *testing.T) { |
|||
tests := []struct { |
|||
ini string |
|||
want map[string]map[string]string |
|||
}{ |
|||
{ |
|||
`root = toor |
|||
[foo] |
|||
bar = hop |
|||
ini = nin |
|||
`, |
|||
map[string]map[string]string{ |
|||
"": map[string]string{"root": "toor"}, |
|||
"foo": map[string]string{"bar": "hop", "ini": "nin"}, |
|||
}, |
|||
}, |
|||
{ |
|||
`[empty] |
|||
[section] |
|||
empty= |
|||
`, |
|||
map[string]map[string]string{ |
|||
"": map[string]string{}, |
|||
"empty": map[string]string{}, |
|||
"section": map[string]string{"empty": ""}, |
|||
}, |
|||
}, |
|||
{ |
|||
`ignore |
|||
[invalid |
|||
=stuff |
|||
;comment=true |
|||
`, |
|||
map[string]map[string]string{ |
|||
"": map[string]string{}, |
|||
}, |
|||
}, |
|||
} |
|||
for _, tt := range tests { |
|||
result, err := ParseINI(strings.NewReader(tt.ini)) |
|||
if err != nil { |
|||
t.Errorf("ParseINI(%q) error %v, want: no error", tt.ini, err) |
|||
continue |
|||
} |
|||
if !reflect.DeepEqual(result, tt.want) { |
|||
t.Errorf("ParseINI(%q) = %#v, want: %#v", tt.ini, result, tt.want) |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,225 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package internal contains support packages for oauth2 package.
|
|||
package internal |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"fmt" |
|||
"io" |
|||
"io/ioutil" |
|||
"mime" |
|||
"net/http" |
|||
"net/url" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
// Token represents the crendentials used to authorize
|
|||
// the requests to access protected resources on the OAuth 2.0
|
|||
// provider's backend.
|
|||
//
|
|||
// This type is a mirror of oauth2.Token and exists to break
|
|||
// an otherwise-circular dependency. Other internal packages
|
|||
// should convert this Token into an oauth2.Token before use.
|
|||
type Token struct { |
|||
// AccessToken is the token that authorizes and authenticates
|
|||
// the requests.
|
|||
AccessToken string |
|||
|
|||
// TokenType is the type of token.
|
|||
// The Type method returns either this or "Bearer", the default.
|
|||
TokenType string |
|||
|
|||
// RefreshToken is a token that's used by the application
|
|||
// (as opposed to the user) to refresh the access token
|
|||
// if it expires.
|
|||
RefreshToken string |
|||
|
|||
// Expiry is the optional expiration time of the access token.
|
|||
//
|
|||
// If zero, TokenSource implementations will reuse the same
|
|||
// token forever and RefreshToken or equivalent
|
|||
// mechanisms for that TokenSource will not be used.
|
|||
Expiry time.Time |
|||
|
|||
// Raw optionally contains extra metadata from the server
|
|||
// when updating a token.
|
|||
Raw interface{} |
|||
} |
|||
|
|||
// tokenJSON is the struct representing the HTTP response from OAuth2
|
|||
// providers returning a token in JSON form.
|
|||
type tokenJSON struct { |
|||
AccessToken string `json:"access_token"` |
|||
TokenType string `json:"token_type"` |
|||
RefreshToken string `json:"refresh_token"` |
|||
ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
|
|||
Expires expirationTime `json:"expires"` // broken Facebook spelling of expires_in
|
|||
} |
|||
|
|||
func (e *tokenJSON) expiry() (t time.Time) { |
|||
if v := e.ExpiresIn; v != 0 { |
|||
return time.Now().Add(time.Duration(v) * time.Second) |
|||
} |
|||
if v := e.Expires; v != 0 { |
|||
return time.Now().Add(time.Duration(v) * time.Second) |
|||
} |
|||
return |
|||
} |
|||
|
|||
type expirationTime int32 |
|||
|
|||
func (e *expirationTime) UnmarshalJSON(b []byte) error { |
|||
var n json.Number |
|||
err := json.Unmarshal(b, &n) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
i, err := n.Int64() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
*e = expirationTime(i) |
|||
return nil |
|||
} |
|||
|
|||
var brokenAuthHeaderProviders = []string{ |
|||
"https://accounts.google.com/", |
|||
"https://api.dropbox.com/", |
|||
"https://api.dropboxapi.com/", |
|||
"https://api.instagram.com/", |
|||
"https://api.netatmo.net/", |
|||
"https://api.odnoklassniki.ru/", |
|||
"https://api.pushbullet.com/", |
|||
"https://api.soundcloud.com/", |
|||
"https://api.twitch.tv/", |
|||
"https://app.box.com/", |
|||
"https://connect.stripe.com/", |
|||
"https://login.microsoftonline.com/", |
|||
"https://login.salesforce.com/", |
|||
"https://oauth.sandbox.trainingpeaks.com/", |
|||
"https://oauth.trainingpeaks.com/", |
|||
"https://oauth.vk.com/", |
|||
"https://openapi.baidu.com/", |
|||
"https://slack.com/", |
|||
"https://test-sandbox.auth.corp.google.com", |
|||
"https://test.salesforce.com/", |
|||
"https://user.gini.net/", |
|||
"https://www.douban.com/", |
|||
"https://www.googleapis.com/", |
|||
"https://www.linkedin.com/", |
|||
"https://www.strava.com/oauth/", |
|||
"https://www.wunderlist.com/oauth/", |
|||
"https://api.patreon.com/", |
|||
} |
|||
|
|||
func RegisterBrokenAuthHeaderProvider(tokenURL string) { |
|||
brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL) |
|||
} |
|||
|
|||
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
|
|||
// implements the OAuth2 spec correctly
|
|||
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
|||
// In summary:
|
|||
// - Reddit only accepts client secret in the Authorization header
|
|||
// - Dropbox accepts either it in URL param or Auth header, but not both.
|
|||
// - Google only accepts URL param (not spec compliant?), not Auth header
|
|||
// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
|
|||
func providerAuthHeaderWorks(tokenURL string) bool { |
|||
for _, s := range brokenAuthHeaderProviders { |
|||
if strings.HasPrefix(tokenURL, s) { |
|||
// Some sites fail to implement the OAuth2 spec fully.
|
|||
return false |
|||
} |
|||
} |
|||
|
|||
// Assume the provider implements the spec properly
|
|||
// otherwise. We can add more exceptions as they're
|
|||
// discovered. We will _not_ be adding configurable hooks
|
|||
// to this package to let users select server bugs.
|
|||
return true |
|||
} |
|||
|
|||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) { |
|||
hc, err := ContextClient(ctx) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
v.Set("client_id", clientID) |
|||
bustedAuth := !providerAuthHeaderWorks(tokenURL) |
|||
if bustedAuth && clientSecret != "" { |
|||
v.Set("client_secret", clientSecret) |
|||
} |
|||
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode())) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
|||
if !bustedAuth { |
|||
req.SetBasicAuth(clientID, clientSecret) |
|||
} |
|||
r, err := hc.Do(req) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
defer r.Body.Close() |
|||
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) |
|||
} |
|||
if code := r.StatusCode; code < 200 || code > 299 { |
|||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) |
|||
} |
|||
|
|||
var token *Token |
|||
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) |
|||
switch content { |
|||
case "application/x-www-form-urlencoded", "text/plain": |
|||
vals, err := url.ParseQuery(string(body)) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
token = &Token{ |
|||
AccessToken: vals.Get("access_token"), |
|||
TokenType: vals.Get("token_type"), |
|||
RefreshToken: vals.Get("refresh_token"), |
|||
Raw: vals, |
|||
} |
|||
e := vals.Get("expires_in") |
|||
if e == "" { |
|||
// TODO(jbd): Facebook's OAuth2 implementation is broken and
|
|||
// returns expires_in field in expires. Remove the fallback to expires,
|
|||
// when Facebook fixes their implementation.
|
|||
e = vals.Get("expires") |
|||
} |
|||
expires, _ := strconv.Atoi(e) |
|||
if expires != 0 { |
|||
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) |
|||
} |
|||
default: |
|||
var tj tokenJSON |
|||
if err = json.Unmarshal(body, &tj); err != nil { |
|||
return nil, err |
|||
} |
|||
token = &Token{ |
|||
AccessToken: tj.AccessToken, |
|||
TokenType: tj.TokenType, |
|||
RefreshToken: tj.RefreshToken, |
|||
Expiry: tj.expiry(), |
|||
Raw: make(map[string]interface{}), |
|||
} |
|||
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
|
|||
} |
|||
// Don't overwrite `RefreshToken` with an empty value
|
|||
// if this was a token refreshing request.
|
|||
if token.RefreshToken == "" { |
|||
token.RefreshToken = v.Get("refresh_token") |
|||
} |
|||
return token, nil |
|||
} |
@ -0,0 +1,36 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package internal contains support packages for oauth2 package.
|
|||
package internal |
|||
|
|||
import ( |
|||
"fmt" |
|||
"testing" |
|||
) |
|||
|
|||
func TestRegisterBrokenAuthHeaderProvider(t *testing.T) { |
|||
RegisterBrokenAuthHeaderProvider("https://aaa.com/") |
|||
tokenURL := "https://aaa.com/token" |
|||
if providerAuthHeaderWorks(tokenURL) { |
|||
t.Errorf("URL: %s is a broken provider", tokenURL) |
|||
} |
|||
} |
|||
|
|||
func Test_providerAuthHeaderWorks(t *testing.T) { |
|||
for _, p := range brokenAuthHeaderProviders { |
|||
if providerAuthHeaderWorks(p) { |
|||
t.Errorf("URL: %s not found in list", p) |
|||
} |
|||
p := fmt.Sprintf("%ssomesuffix", p) |
|||
if providerAuthHeaderWorks(p) { |
|||
t.Errorf("URL: %s not found in list", p) |
|||
} |
|||
} |
|||
p := "https://api.not-in-the-list-example.com/" |
|||
if !providerAuthHeaderWorks(p) { |
|||
t.Errorf("URL: %s found in list", p) |
|||
} |
|||
|
|||
} |
@ -0,0 +1,69 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package internal contains support packages for oauth2 package.
|
|||
package internal |
|||
|
|||
import ( |
|||
"net/http" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
// HTTPClient is the context key to use with golang.org/x/net/context's
|
|||
// WithValue function to associate an *http.Client value with a context.
|
|||
var HTTPClient ContextKey |
|||
|
|||
// ContextKey is just an empty struct. It exists so HTTPClient can be
|
|||
// an immutable public variable with a unique type. It's immutable
|
|||
// because nobody else can create a ContextKey, being unexported.
|
|||
type ContextKey struct{} |
|||
|
|||
// ContextClientFunc is a func which tries to return an *http.Client
|
|||
// given a Context value. If it returns an error, the search stops
|
|||
// with that error. If it returns (nil, nil), the search continues
|
|||
// down the list of registered funcs.
|
|||
type ContextClientFunc func(context.Context) (*http.Client, error) |
|||
|
|||
var contextClientFuncs []ContextClientFunc |
|||
|
|||
func RegisterContextClientFunc(fn ContextClientFunc) { |
|||
contextClientFuncs = append(contextClientFuncs, fn) |
|||
} |
|||
|
|||
func ContextClient(ctx context.Context) (*http.Client, error) { |
|||
if ctx != nil { |
|||
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { |
|||
return hc, nil |
|||
} |
|||
} |
|||
for _, fn := range contextClientFuncs { |
|||
c, err := fn(ctx) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if c != nil { |
|||
return c, nil |
|||
} |
|||
} |
|||
return http.DefaultClient, nil |
|||
} |
|||
|
|||
func ContextTransport(ctx context.Context) http.RoundTripper { |
|||
hc, err := ContextClient(ctx) |
|||
// This is a rare error case (somebody using nil on App Engine).
|
|||
if err != nil { |
|||
return ErrorTransport{err} |
|||
} |
|||
return hc.Transport |
|||
} |
|||
|
|||
// ErrorTransport returns the specified error on RoundTrip.
|
|||
// This RoundTripper should be used in rare error cases where
|
|||
// error handling can be postponed to response handling time.
|
|||
type ErrorTransport struct{ Err error } |
|||
|
|||
func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) { |
|||
return nil, t.Err |
|||
} |
@ -0,0 +1,38 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package internal |
|||
|
|||
import ( |
|||
"net/http" |
|||
"testing" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
func TestContextClient(t *testing.T) { |
|||
rc := &http.Client{} |
|||
RegisterContextClientFunc(func(context.Context) (*http.Client, error) { |
|||
return rc, nil |
|||
}) |
|||
|
|||
c := &http.Client{} |
|||
ctx := context.WithValue(context.Background(), HTTPClient, c) |
|||
|
|||
hc, err := ContextClient(ctx) |
|||
if err != nil { |
|||
t.Fatalf("want valid client; got err = %v", err) |
|||
} |
|||
if hc != c { |
|||
t.Fatalf("want context client = %p; got = %p", c, hc) |
|||
} |
|||
|
|||
hc, err = ContextClient(context.TODO()) |
|||
if err != nil { |
|||
t.Fatalf("want valid client; got err = %v", err) |
|||
} |
|||
if hc != rc { |
|||
t.Fatalf("want registered client = %p; got = %p", c, hc) |
|||
} |
|||
} |
@ -0,0 +1,174 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package jws provides encoding and decoding utilities for
|
|||
// signed JWS messages.
|
|||
package jws // import "golang.org/x/oauth2/jws"
|
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto" |
|||
"crypto/rand" |
|||
"crypto/rsa" |
|||
"crypto/sha256" |
|||
"encoding/base64" |
|||
"encoding/json" |
|||
"errors" |
|||
"fmt" |
|||
"strings" |
|||
"time" |
|||
) |
|||
|
|||
// ClaimSet contains information about the JWT signature including the
|
|||
// permissions being requested (scopes), the target of the token, the issuer,
|
|||
// the time the token was issued, and the lifetime of the token.
|
|||
type ClaimSet struct { |
|||
Iss string `json:"iss"` // email address of the client_id of the application making the access token request
|
|||
Scope string `json:"scope,omitempty"` // space-delimited list of the permissions the application requests
|
|||
Aud string `json:"aud"` // descriptor of the intended target of the assertion (Optional).
|
|||
Exp int64 `json:"exp"` // the expiration time of the assertion (seconds since Unix epoch)
|
|||
Iat int64 `json:"iat"` // the time the assertion was issued (seconds since Unix epoch)
|
|||
Typ string `json:"typ,omitempty"` // token type (Optional).
|
|||
|
|||
// Email for which the application is requesting delegated access (Optional).
|
|||
Sub string `json:"sub,omitempty"` |
|||
|
|||
// The old name of Sub. Client keeps setting Prn to be
|
|||
// complaint with legacy OAuth 2.0 providers. (Optional)
|
|||
Prn string `json:"prn,omitempty"` |
|||
|
|||
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
|
|||
// This array is marshalled using custom code (see (c *ClaimSet) encode()).
|
|||
PrivateClaims map[string]interface{} `json:"-"` |
|||
} |
|||
|
|||
func (c *ClaimSet) encode() (string, error) { |
|||
// Reverting time back for machines whose time is not perfectly in sync.
|
|||
// If client machine's time is in the future according
|
|||
// to Google servers, an access token will not be issued.
|
|||
now := time.Now().Add(-10 * time.Second) |
|||
if c.Iat == 0 { |
|||
c.Iat = now.Unix() |
|||
} |
|||
if c.Exp == 0 { |
|||
c.Exp = now.Add(time.Hour).Unix() |
|||
} |
|||
if c.Exp < c.Iat { |
|||
return "", fmt.Errorf("jws: invalid Exp = %v; must be later than Iat = %v", c.Exp, c.Iat) |
|||
} |
|||
|
|||
b, err := json.Marshal(c) |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
|
|||
if len(c.PrivateClaims) == 0 { |
|||
return base64.RawURLEncoding.EncodeToString(b), nil |
|||
} |
|||
|
|||
// Marshal private claim set and then append it to b.
|
|||
prv, err := json.Marshal(c.PrivateClaims) |
|||
if err != nil { |
|||
return "", fmt.Errorf("jws: invalid map of private claims %v", c.PrivateClaims) |
|||
} |
|||
|
|||
// Concatenate public and private claim JSON objects.
|
|||
if !bytes.HasSuffix(b, []byte{'}'}) { |
|||
return "", fmt.Errorf("jws: invalid JSON %s", b) |
|||
} |
|||
if !bytes.HasPrefix(prv, []byte{'{'}) { |
|||
return "", fmt.Errorf("jws: invalid JSON %s", prv) |
|||
} |
|||
b[len(b)-1] = ',' // Replace closing curly brace with a comma.
|
|||
b = append(b, prv[1:]...) // Append private claims.
|
|||
return base64.RawURLEncoding.EncodeToString(b), nil |
|||
} |
|||
|
|||
// Header represents the header for the signed JWS payloads.
|
|||
type Header struct { |
|||
// The algorithm used for signature.
|
|||
Algorithm string `json:"alg"` |
|||
|
|||
// Represents the token type.
|
|||
Typ string `json:"typ"` |
|||
|
|||
// The optional hint of which key is being used.
|
|||
KeyID string `json:"kid,omitempty"` |
|||
} |
|||
|
|||
func (h *Header) encode() (string, error) { |
|||
b, err := json.Marshal(h) |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
return base64.RawURLEncoding.EncodeToString(b), nil |
|||
} |
|||
|
|||
// Decode decodes a claim set from a JWS payload.
|
|||
func Decode(payload string) (*ClaimSet, error) { |
|||
// decode returned id token to get expiry
|
|||
s := strings.Split(payload, ".") |
|||
if len(s) < 2 { |
|||
// TODO(jbd): Provide more context about the error.
|
|||
return nil, errors.New("jws: invalid token received") |
|||
} |
|||
decoded, err := base64.RawURLEncoding.DecodeString(s[1]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
c := &ClaimSet{} |
|||
err = json.NewDecoder(bytes.NewBuffer(decoded)).Decode(c) |
|||
return c, err |
|||
} |
|||
|
|||
// Signer returns a signature for the given data.
|
|||
type Signer func(data []byte) (sig []byte, err error) |
|||
|
|||
// EncodeWithSigner encodes a header and claim set with the provided signer.
|
|||
func EncodeWithSigner(header *Header, c *ClaimSet, sg Signer) (string, error) { |
|||
head, err := header.encode() |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
cs, err := c.encode() |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
ss := fmt.Sprintf("%s.%s", head, cs) |
|||
sig, err := sg([]byte(ss)) |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
return fmt.Sprintf("%s.%s", ss, base64.RawURLEncoding.EncodeToString(sig)), nil |
|||
} |
|||
|
|||
// Encode encodes a signed JWS with provided header and claim set.
|
|||
// This invokes EncodeWithSigner using crypto/rsa.SignPKCS1v15 with the given RSA private key.
|
|||
func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) { |
|||
sg := func(data []byte) (sig []byte, err error) { |
|||
h := sha256.New() |
|||
h.Write(data) |
|||
return rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h.Sum(nil)) |
|||
} |
|||
return EncodeWithSigner(header, c, sg) |
|||
} |
|||
|
|||
// Verify tests whether the provided JWT token's signature was produced by the private key
|
|||
// associated with the supplied public key.
|
|||
func Verify(token string, key *rsa.PublicKey) error { |
|||
parts := strings.Split(token, ".") |
|||
if len(parts) != 3 { |
|||
return errors.New("jws: invalid token received, token must have 3 parts") |
|||
} |
|||
|
|||
signedContent := parts[0] + "." + parts[1] |
|||
signatureString, err := base64.RawURLEncoding.DecodeString(parts[2]) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
h := sha256.New() |
|||
h.Write([]byte(signedContent)) |
|||
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), []byte(signatureString)) |
|||
} |
@ -0,0 +1,46 @@ |
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package jws |
|||
|
|||
import ( |
|||
"crypto/rand" |
|||
"crypto/rsa" |
|||
"testing" |
|||
) |
|||
|
|||
func TestSignAndVerify(t *testing.T) { |
|||
header := &Header{ |
|||
Algorithm: "RS256", |
|||
Typ: "JWT", |
|||
} |
|||
payload := &ClaimSet{ |
|||
Iss: "http://google.com/", |
|||
Aud: "", |
|||
Exp: 3610, |
|||
Iat: 10, |
|||
} |
|||
|
|||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
token, err := Encode(header, payload, privateKey) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
err = Verify(token, &privateKey.PublicKey) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
} |
|||
|
|||
func TestVerifyFailsOnMalformedClaim(t *testing.T) { |
|||
err := Verify("abc.def", nil) |
|||
if err == nil { |
|||
t.Error("Improperly formed JWT should fail.") |
|||
} |
|||
} |
@ -0,0 +1,31 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package jwt_test |
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/jwt" |
|||
) |
|||
|
|||
func ExampleJWTConfig() { |
|||
conf := &jwt.Config{ |
|||
Email: "xxx@developer.com", |
|||
// The contents of your RSA private key or your PEM file
|
|||
// that contains a private key.
|
|||
// If you have a p12 file instead, you
|
|||
// can use `openssl` to export the private key into a pem file.
|
|||
//
|
|||
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
|||
//
|
|||
// It only supports PEM containers with no passphrase.
|
|||
PrivateKey: []byte("-----BEGIN RSA PRIVATE KEY-----..."), |
|||
Subject: "user@example.com", |
|||
TokenURL: "https://provider.com/o/oauth2/token", |
|||
} |
|||
// Initiate an http.Client, the following GET request will be
|
|||
// authorized and authenticated on the behalf of user@example.com.
|
|||
client := conf.Client(oauth2.NoContext) |
|||
client.Get("...") |
|||
} |
@ -0,0 +1,157 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package jwt implements the OAuth 2.0 JSON Web Token flow, commonly
|
|||
// known as "two-legged OAuth 2.0".
|
|||
//
|
|||
// See: https://tools.ietf.org/html/draft-ietf-oauth-jwt-bearer-12
|
|||
package jwt |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"fmt" |
|||
"io" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/url" |
|||
"strings" |
|||
"time" |
|||
|
|||
"golang.org/x/net/context" |
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/internal" |
|||
"golang.org/x/oauth2/jws" |
|||
) |
|||
|
|||
var ( |
|||
defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" |
|||
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"} |
|||
) |
|||
|
|||
// Config is the configuration for using JWT to fetch tokens,
|
|||
// commonly known as "two-legged OAuth 2.0".
|
|||
type Config struct { |
|||
// Email is the OAuth client identifier used when communicating with
|
|||
// the configured OAuth provider.
|
|||
Email string |
|||
|
|||
// PrivateKey contains the contents of an RSA private key or the
|
|||
// contents of a PEM file that contains a private key. The provided
|
|||
// private key is used to sign JWT payloads.
|
|||
// PEM containers with a passphrase are not supported.
|
|||
// Use the following command to convert a PKCS 12 file into a PEM.
|
|||
//
|
|||
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
|||
//
|
|||
PrivateKey []byte |
|||
|
|||
// PrivateKeyID contains an optional hint indicating which key is being
|
|||
// used.
|
|||
PrivateKeyID string |
|||
|
|||
// Subject is the optional user to impersonate.
|
|||
Subject string |
|||
|
|||
// Scopes optionally specifies a list of requested permission scopes.
|
|||
Scopes []string |
|||
|
|||
// TokenURL is the endpoint required to complete the 2-legged JWT flow.
|
|||
TokenURL string |
|||
|
|||
// Expires optionally specifies how long the token is valid for.
|
|||
Expires time.Duration |
|||
} |
|||
|
|||
// TokenSource returns a JWT TokenSource using the configuration
|
|||
// in c and the HTTP client from the provided context.
|
|||
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { |
|||
return oauth2.ReuseTokenSource(nil, jwtSource{ctx, c}) |
|||
} |
|||
|
|||
// Client returns an HTTP client wrapping the context's
|
|||
// HTTP transport and adding Authorization headers with tokens
|
|||
// obtained from c.
|
|||
//
|
|||
// The returned client and its Transport should not be modified.
|
|||
func (c *Config) Client(ctx context.Context) *http.Client { |
|||
return oauth2.NewClient(ctx, c.TokenSource(ctx)) |
|||
} |
|||
|
|||
// jwtSource is a source that always does a signed JWT request for a token.
|
|||
// It should typically be wrapped with a reuseTokenSource.
|
|||
type jwtSource struct { |
|||
ctx context.Context |
|||
conf *Config |
|||
} |
|||
|
|||
func (js jwtSource) Token() (*oauth2.Token, error) { |
|||
pk, err := internal.ParseKey(js.conf.PrivateKey) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
hc := oauth2.NewClient(js.ctx, nil) |
|||
claimSet := &jws.ClaimSet{ |
|||
Iss: js.conf.Email, |
|||
Scope: strings.Join(js.conf.Scopes, " "), |
|||
Aud: js.conf.TokenURL, |
|||
} |
|||
if subject := js.conf.Subject; subject != "" { |
|||
claimSet.Sub = subject |
|||
// prn is the old name of sub. Keep setting it
|
|||
// to be compatible with legacy OAuth 2.0 providers.
|
|||
claimSet.Prn = subject |
|||
} |
|||
if t := js.conf.Expires; t > 0 { |
|||
claimSet.Exp = time.Now().Add(t).Unix() |
|||
} |
|||
payload, err := jws.Encode(defaultHeader, claimSet, pk) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
v := url.Values{} |
|||
v.Set("grant_type", defaultGrantType) |
|||
v.Set("assertion", payload) |
|||
resp, err := hc.PostForm(js.conf.TokenURL, v) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) |
|||
} |
|||
defer resp.Body.Close() |
|||
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) |
|||
} |
|||
if c := resp.StatusCode; c < 200 || c > 299 { |
|||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body) |
|||
} |
|||
// tokenRes is the JSON response body.
|
|||
var tokenRes struct { |
|||
AccessToken string `json:"access_token"` |
|||
TokenType string `json:"token_type"` |
|||
IDToken string `json:"id_token"` |
|||
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
|
|||
} |
|||
if err := json.Unmarshal(body, &tokenRes); err != nil { |
|||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) |
|||
} |
|||
token := &oauth2.Token{ |
|||
AccessToken: tokenRes.AccessToken, |
|||
TokenType: tokenRes.TokenType, |
|||
} |
|||
raw := make(map[string]interface{}) |
|||
json.Unmarshal(body, &raw) // no error checks for optional fields
|
|||
token = token.WithExtra(raw) |
|||
|
|||
if secs := tokenRes.ExpiresIn; secs > 0 { |
|||
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) |
|||
} |
|||
if v := tokenRes.IDToken; v != "" { |
|||
// decode returned id token to get expiry
|
|||
claimSet, err := jws.Decode(v) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err) |
|||
} |
|||
token.Expiry = time.Unix(claimSet.Exp, 0) |
|||
} |
|||
return token, nil |
|||
} |
@ -0,0 +1,134 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package jwt |
|||
|
|||
import ( |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"testing" |
|||
|
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- |
|||
MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE |
|||
DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY |
|||
fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK |
|||
1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr |
|||
k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9 |
|||
/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt |
|||
3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn |
|||
2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3 |
|||
nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK |
|||
6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf |
|||
5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e |
|||
DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1 |
|||
M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g |
|||
z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y |
|||
1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK |
|||
J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U |
|||
f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx |
|||
QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA |
|||
cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr |
|||
Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw |
|||
5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg |
|||
KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84 |
|||
OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd |
|||
mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ |
|||
5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg== |
|||
-----END RSA PRIVATE KEY-----`) |
|||
|
|||
func TestJWTFetch_JSONResponse(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Write([]byte(`{ |
|||
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", |
|||
"scope": "user", |
|||
"token_type": "bearer", |
|||
"expires_in": 3600 |
|||
}`)) |
|||
})) |
|||
defer ts.Close() |
|||
|
|||
conf := &Config{ |
|||
Email: "aaa@xxx.com", |
|||
PrivateKey: dummyPrivateKey, |
|||
TokenURL: ts.URL, |
|||
} |
|||
tok, err := conf.TokenSource(oauth2.NoContext).Token() |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if !tok.Valid() { |
|||
t.Errorf("Token invalid") |
|||
} |
|||
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { |
|||
t.Errorf("Unexpected access token, %#v", tok.AccessToken) |
|||
} |
|||
if tok.TokenType != "bearer" { |
|||
t.Errorf("Unexpected token type, %#v", tok.TokenType) |
|||
} |
|||
if tok.Expiry.IsZero() { |
|||
t.Errorf("Unexpected token expiry, %#v", tok.Expiry) |
|||
} |
|||
scope := tok.Extra("scope") |
|||
if scope != "user" { |
|||
t.Errorf("Unexpected value for scope: %v", scope) |
|||
} |
|||
} |
|||
|
|||
func TestJWTFetch_BadResponse(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) |
|||
})) |
|||
defer ts.Close() |
|||
|
|||
conf := &Config{ |
|||
Email: "aaa@xxx.com", |
|||
PrivateKey: dummyPrivateKey, |
|||
TokenURL: ts.URL, |
|||
} |
|||
tok, err := conf.TokenSource(oauth2.NoContext).Token() |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if tok == nil { |
|||
t.Fatalf("token is nil") |
|||
} |
|||
if tok.Valid() { |
|||
t.Errorf("token is valid. want invalid.") |
|||
} |
|||
if tok.AccessToken != "" { |
|||
t.Errorf("Unexpected non-empty access token %q.", tok.AccessToken) |
|||
} |
|||
if want := "bearer"; tok.TokenType != want { |
|||
t.Errorf("TokenType = %q; want %q", tok.TokenType, want) |
|||
} |
|||
scope := tok.Extra("scope") |
|||
if want := "user"; scope != want { |
|||
t.Errorf("token scope = %q; want %q", scope, want) |
|||
} |
|||
} |
|||
|
|||
func TestJWTFetch_BadResponseType(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) |
|||
})) |
|||
defer ts.Close() |
|||
conf := &Config{ |
|||
Email: "aaa@xxx.com", |
|||
PrivateKey: dummyPrivateKey, |
|||
TokenURL: ts.URL, |
|||
} |
|||
tok, err := conf.TokenSource(oauth2.NoContext).Token() |
|||
if err == nil { |
|||
t.Error("got a token; expected error") |
|||
if tok.AccessToken != "" { |
|||
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,16 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package linkedin provides constants for using OAuth2 to access LinkedIn.
|
|||
package linkedin // import "golang.org/x/oauth2/linkedin"
|
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// Endpoint is LinkedIn's OAuth 2.0 endpoint.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://www.linkedin.com/uas/oauth2/authorization", |
|||
TokenURL: "https://www.linkedin.com/uas/oauth2/accessToken", |
|||
} |
@ -0,0 +1,16 @@ |
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package microsoft provides constants for using OAuth2 to access Windows Live ID.
|
|||
package microsoft // import "golang.org/x/oauth2/microsoft"
|
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// LiveConnectEndpoint is Windows's Live ID OAuth 2.0 endpoint.
|
|||
var LiveConnectEndpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://login.live.com/oauth20_authorize.srf", |
|||
TokenURL: "https://login.live.com/oauth20_token.srf", |
|||
} |
@ -0,0 +1,337 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package oauth2 provides support for making
|
|||
// OAuth2 authorized and authenticated HTTP requests.
|
|||
// It can additionally grant authorization with Bearer JWT.
|
|||
package oauth2 // import "golang.org/x/oauth2"
|
|||
|
|||
import ( |
|||
"bytes" |
|||
"errors" |
|||
"net/http" |
|||
"net/url" |
|||
"strings" |
|||
"sync" |
|||
|
|||
"golang.org/x/net/context" |
|||
"golang.org/x/oauth2/internal" |
|||
) |
|||
|
|||
// NoContext is the default context you should supply if not using
|
|||
// your own context.Context (see https://golang.org/x/net/context).
|
|||
var NoContext = context.TODO() |
|||
|
|||
// RegisterBrokenAuthHeaderProvider registers an OAuth2 server
|
|||
// identified by the tokenURL prefix as an OAuth2 implementation
|
|||
// which doesn't support the HTTP Basic authentication
|
|||
// scheme to authenticate with the authorization server.
|
|||
// Once a server is registered, credentials (client_id and client_secret)
|
|||
// will be passed as query parameters rather than being present
|
|||
// in the Authorization header.
|
|||
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
|||
func RegisterBrokenAuthHeaderProvider(tokenURL string) { |
|||
internal.RegisterBrokenAuthHeaderProvider(tokenURL) |
|||
} |
|||
|
|||
// Config describes a typical 3-legged OAuth2 flow, with both the
|
|||
// client application information and the server's endpoint URLs.
|
|||
type Config struct { |
|||
// ClientID is the application's ID.
|
|||
ClientID string |
|||
|
|||
// ClientSecret is the application's secret.
|
|||
ClientSecret string |
|||
|
|||
// Endpoint contains the resource server's token endpoint
|
|||
// URLs. These are constants specific to each server and are
|
|||
// often available via site-specific packages, such as
|
|||
// google.Endpoint or github.Endpoint.
|
|||
Endpoint Endpoint |
|||
|
|||
// RedirectURL is the URL to redirect users going through
|
|||
// the OAuth flow, after the resource owner's URLs.
|
|||
RedirectURL string |
|||
|
|||
// Scope specifies optional requested permissions.
|
|||
Scopes []string |
|||
} |
|||
|
|||
// A TokenSource is anything that can return a token.
|
|||
type TokenSource interface { |
|||
// Token returns a token or an error.
|
|||
// Token must be safe for concurrent use by multiple goroutines.
|
|||
// The returned Token must not be modified.
|
|||
Token() (*Token, error) |
|||
} |
|||
|
|||
// Endpoint contains the OAuth 2.0 provider's authorization and token
|
|||
// endpoint URLs.
|
|||
type Endpoint struct { |
|||
AuthURL string |
|||
TokenURL string |
|||
} |
|||
|
|||
var ( |
|||
// AccessTypeOnline and AccessTypeOffline are options passed
|
|||
// to the Options.AuthCodeURL method. They modify the
|
|||
// "access_type" field that gets sent in the URL returned by
|
|||
// AuthCodeURL.
|
|||
//
|
|||
// Online is the default if neither is specified. If your
|
|||
// application needs to refresh access tokens when the user
|
|||
// is not present at the browser, then use offline. This will
|
|||
// result in your application obtaining a refresh token the
|
|||
// first time your application exchanges an authorization
|
|||
// code for a user.
|
|||
AccessTypeOnline AuthCodeOption = SetAuthURLParam("access_type", "online") |
|||
AccessTypeOffline AuthCodeOption = SetAuthURLParam("access_type", "offline") |
|||
|
|||
// ApprovalForce forces the users to view the consent dialog
|
|||
// and confirm the permissions request at the URL returned
|
|||
// from AuthCodeURL, even if they've already done so.
|
|||
ApprovalForce AuthCodeOption = SetAuthURLParam("approval_prompt", "force") |
|||
) |
|||
|
|||
// An AuthCodeOption is passed to Config.AuthCodeURL.
|
|||
type AuthCodeOption interface { |
|||
setValue(url.Values) |
|||
} |
|||
|
|||
type setParam struct{ k, v string } |
|||
|
|||
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) } |
|||
|
|||
// SetAuthURLParam builds an AuthCodeOption which passes key/value parameters
|
|||
// to a provider's authorization endpoint.
|
|||
func SetAuthURLParam(key, value string) AuthCodeOption { |
|||
return setParam{key, value} |
|||
} |
|||
|
|||
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
|||
// that asks for permissions for the required scopes explicitly.
|
|||
//
|
|||
// State is a token to protect the user from CSRF attacks. You must
|
|||
// always provide a non-zero string and validate that it matches the
|
|||
// the state query parameter on your redirect callback.
|
|||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
|||
//
|
|||
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
|||
// as ApprovalForce.
|
|||
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { |
|||
var buf bytes.Buffer |
|||
buf.WriteString(c.Endpoint.AuthURL) |
|||
v := url.Values{ |
|||
"response_type": {"code"}, |
|||
"client_id": {c.ClientID}, |
|||
"redirect_uri": internal.CondVal(c.RedirectURL), |
|||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), |
|||
"state": internal.CondVal(state), |
|||
} |
|||
for _, opt := range opts { |
|||
opt.setValue(v) |
|||
} |
|||
if strings.Contains(c.Endpoint.AuthURL, "?") { |
|||
buf.WriteByte('&') |
|||
} else { |
|||
buf.WriteByte('?') |
|||
} |
|||
buf.WriteString(v.Encode()) |
|||
return buf.String() |
|||
} |
|||
|
|||
// PasswordCredentialsToken converts a resource owner username and password
|
|||
// pair into a token.
|
|||
//
|
|||
// Per the RFC, this grant type should only be used "when there is a high
|
|||
// degree of trust between the resource owner and the client (e.g., the client
|
|||
// is part of the device operating system or a highly privileged application),
|
|||
// and when other authorization grant types are not available."
|
|||
// See https://tools.ietf.org/html/rfc6749#section-4.3 for more info.
|
|||
//
|
|||
// The HTTP client to use is derived from the context.
|
|||
// If nil, http.DefaultClient is used.
|
|||
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) { |
|||
return retrieveToken(ctx, c, url.Values{ |
|||
"grant_type": {"password"}, |
|||
"username": {username}, |
|||
"password": {password}, |
|||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), |
|||
}) |
|||
} |
|||
|
|||
// Exchange converts an authorization code into a token.
|
|||
//
|
|||
// It is used after a resource provider redirects the user back
|
|||
// to the Redirect URI (the URL obtained from AuthCodeURL).
|
|||
//
|
|||
// The HTTP client to use is derived from the context.
|
|||
// If a client is not provided via the context, http.DefaultClient is used.
|
|||
//
|
|||
// The code will be in the *http.Request.FormValue("code"). Before
|
|||
// calling Exchange, be sure to validate FormValue("state").
|
|||
func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) { |
|||
return retrieveToken(ctx, c, url.Values{ |
|||
"grant_type": {"authorization_code"}, |
|||
"code": {code}, |
|||
"redirect_uri": internal.CondVal(c.RedirectURL), |
|||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), |
|||
}) |
|||
} |
|||
|
|||
// Client returns an HTTP client using the provided token.
|
|||
// The token will auto-refresh as necessary. The underlying
|
|||
// HTTP transport will be obtained using the provided context.
|
|||
// The returned client and its Transport should not be modified.
|
|||
func (c *Config) Client(ctx context.Context, t *Token) *http.Client { |
|||
return NewClient(ctx, c.TokenSource(ctx, t)) |
|||
} |
|||
|
|||
// TokenSource returns a TokenSource that returns t until t expires,
|
|||
// automatically refreshing it as necessary using the provided context.
|
|||
//
|
|||
// Most users will use Config.Client instead.
|
|||
func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { |
|||
tkr := &tokenRefresher{ |
|||
ctx: ctx, |
|||
conf: c, |
|||
} |
|||
if t != nil { |
|||
tkr.refreshToken = t.RefreshToken |
|||
} |
|||
return &reuseTokenSource{ |
|||
t: t, |
|||
new: tkr, |
|||
} |
|||
} |
|||
|
|||
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
|
|||
// HTTP requests to renew a token using a RefreshToken.
|
|||
type tokenRefresher struct { |
|||
ctx context.Context // used to get HTTP requests
|
|||
conf *Config |
|||
refreshToken string |
|||
} |
|||
|
|||
// WARNING: Token is not safe for concurrent access, as it
|
|||
// updates the tokenRefresher's refreshToken field.
|
|||
// Within this package, it is used by reuseTokenSource which
|
|||
// synchronizes calls to this method with its own mutex.
|
|||
func (tf *tokenRefresher) Token() (*Token, error) { |
|||
if tf.refreshToken == "" { |
|||
return nil, errors.New("oauth2: token expired and refresh token is not set") |
|||
} |
|||
|
|||
tk, err := retrieveToken(tf.ctx, tf.conf, url.Values{ |
|||
"grant_type": {"refresh_token"}, |
|||
"refresh_token": {tf.refreshToken}, |
|||
}) |
|||
|
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if tf.refreshToken != tk.RefreshToken { |
|||
tf.refreshToken = tk.RefreshToken |
|||
} |
|||
return tk, err |
|||
} |
|||
|
|||
// reuseTokenSource is a TokenSource that holds a single token in memory
|
|||
// and validates its expiry before each call to retrieve it with
|
|||
// Token. If it's expired, it will be auto-refreshed using the
|
|||
// new TokenSource.
|
|||
type reuseTokenSource struct { |
|||
new TokenSource // called when t is expired.
|
|||
|
|||
mu sync.Mutex // guards t
|
|||
t *Token |
|||
} |
|||
|
|||
// Token returns the current token if it's still valid, else will
|
|||
// refresh the current token (using r.Context for HTTP client
|
|||
// information) and return the new one.
|
|||
func (s *reuseTokenSource) Token() (*Token, error) { |
|||
s.mu.Lock() |
|||
defer s.mu.Unlock() |
|||
if s.t.Valid() { |
|||
return s.t, nil |
|||
} |
|||
t, err := s.new.Token() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
s.t = t |
|||
return t, nil |
|||
} |
|||
|
|||
// StaticTokenSource returns a TokenSource that always returns the same token.
|
|||
// Because the provided token t is never refreshed, StaticTokenSource is only
|
|||
// useful for tokens that never expire.
|
|||
func StaticTokenSource(t *Token) TokenSource { |
|||
return staticTokenSource{t} |
|||
} |
|||
|
|||
// staticTokenSource is a TokenSource that always returns the same Token.
|
|||
type staticTokenSource struct { |
|||
t *Token |
|||
} |
|||
|
|||
func (s staticTokenSource) Token() (*Token, error) { |
|||
return s.t, nil |
|||
} |
|||
|
|||
// HTTPClient is the context key to use with golang.org/x/net/context's
|
|||
// WithValue function to associate an *http.Client value with a context.
|
|||
var HTTPClient internal.ContextKey |
|||
|
|||
// NewClient creates an *http.Client from a Context and TokenSource.
|
|||
// The returned client is not valid beyond the lifetime of the context.
|
|||
//
|
|||
// As a special case, if src is nil, a non-OAuth2 client is returned
|
|||
// using the provided context. This exists to support related OAuth2
|
|||
// packages.
|
|||
func NewClient(ctx context.Context, src TokenSource) *http.Client { |
|||
if src == nil { |
|||
c, err := internal.ContextClient(ctx) |
|||
if err != nil { |
|||
return &http.Client{Transport: internal.ErrorTransport{Err: err}} |
|||
} |
|||
return c |
|||
} |
|||
return &http.Client{ |
|||
Transport: &Transport{ |
|||
Base: internal.ContextTransport(ctx), |
|||
Source: ReuseTokenSource(nil, src), |
|||
}, |
|||
} |
|||
} |
|||
|
|||
// ReuseTokenSource returns a TokenSource which repeatedly returns the
|
|||
// same token as long as it's valid, starting with t.
|
|||
// When its cached token is invalid, a new token is obtained from src.
|
|||
//
|
|||
// ReuseTokenSource is typically used to reuse tokens from a cache
|
|||
// (such as a file on disk) between runs of a program, rather than
|
|||
// obtaining new tokens unnecessarily.
|
|||
//
|
|||
// The initial token t may be nil, in which case the TokenSource is
|
|||
// wrapped in a caching version if it isn't one already. This also
|
|||
// means it's always safe to wrap ReuseTokenSource around any other
|
|||
// TokenSource without adverse effects.
|
|||
func ReuseTokenSource(t *Token, src TokenSource) TokenSource { |
|||
// Don't wrap a reuseTokenSource in itself. That would work,
|
|||
// but cause an unnecessary number of mutex operations.
|
|||
// Just build the equivalent one.
|
|||
if rt, ok := src.(*reuseTokenSource); ok { |
|||
if t == nil { |
|||
// Just use it directly.
|
|||
return rt |
|||
} |
|||
src = rt.new |
|||
} |
|||
return &reuseTokenSource{ |
|||
t: t, |
|||
new: src, |
|||
} |
|||
} |
@ -0,0 +1,458 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package oauth2 |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"errors" |
|||
"fmt" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"net/url" |
|||
"reflect" |
|||
"strconv" |
|||
"testing" |
|||
"time" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
type mockTransport struct { |
|||
rt func(req *http.Request) (resp *http.Response, err error) |
|||
} |
|||
|
|||
func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { |
|||
return t.rt(req) |
|||
} |
|||
|
|||
func newConf(url string) *Config { |
|||
return &Config{ |
|||
ClientID: "CLIENT_ID", |
|||
ClientSecret: "CLIENT_SECRET", |
|||
RedirectURL: "REDIRECT_URL", |
|||
Scopes: []string{"scope1", "scope2"}, |
|||
Endpoint: Endpoint{ |
|||
AuthURL: url + "/auth", |
|||
TokenURL: url + "/token", |
|||
}, |
|||
} |
|||
} |
|||
|
|||
func TestAuthCodeURL(t *testing.T) { |
|||
conf := newConf("server") |
|||
url := conf.AuthCodeURL("foo", AccessTypeOffline, ApprovalForce) |
|||
if url != "server/auth?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" { |
|||
t.Errorf("Auth code URL doesn't match the expected, found: %v", url) |
|||
} |
|||
} |
|||
|
|||
func TestAuthCodeURL_CustomParam(t *testing.T) { |
|||
conf := newConf("server") |
|||
param := SetAuthURLParam("foo", "bar") |
|||
url := conf.AuthCodeURL("baz", param) |
|||
if url != "server/auth?client_id=CLIENT_ID&foo=bar&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=baz" { |
|||
t.Errorf("Auth code URL doesn't match the expected, found: %v", url) |
|||
} |
|||
} |
|||
|
|||
func TestAuthCodeURL_Optional(t *testing.T) { |
|||
conf := &Config{ |
|||
ClientID: "CLIENT_ID", |
|||
Endpoint: Endpoint{ |
|||
AuthURL: "/auth-url", |
|||
TokenURL: "/token-url", |
|||
}, |
|||
} |
|||
url := conf.AuthCodeURL("") |
|||
if url != "/auth-url?client_id=CLIENT_ID&response_type=code" { |
|||
t.Fatalf("Auth code URL doesn't match the expected, found: %v", url) |
|||
} |
|||
} |
|||
|
|||
func TestExchangeRequest(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
if r.URL.String() != "/token" { |
|||
t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) |
|||
} |
|||
headerAuth := r.Header.Get("Authorization") |
|||
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { |
|||
t.Errorf("Unexpected authorization header, %v is found.", headerAuth) |
|||
} |
|||
headerContentType := r.Header.Get("Content-Type") |
|||
if headerContentType != "application/x-www-form-urlencoded" { |
|||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
|||
} |
|||
body, err := ioutil.ReadAll(r.Body) |
|||
if err != nil { |
|||
t.Errorf("Failed reading request body: %s.", err) |
|||
} |
|||
if string(body) != "client_id=CLIENT_ID&code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL&scope=scope1+scope2" { |
|||
t.Errorf("Unexpected exchange payload, %v is found.", string(body)) |
|||
} |
|||
w.Header().Set("Content-Type", "application/x-www-form-urlencoded") |
|||
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
tok, err := conf.Exchange(NoContext, "exchange-code") |
|||
if err != nil { |
|||
t.Error(err) |
|||
} |
|||
if !tok.Valid() { |
|||
t.Fatalf("Token invalid. Got: %#v", tok) |
|||
} |
|||
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { |
|||
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) |
|||
} |
|||
if tok.TokenType != "bearer" { |
|||
t.Errorf("Unexpected token type, %#v.", tok.TokenType) |
|||
} |
|||
scope := tok.Extra("scope") |
|||
if scope != "user" { |
|||
t.Errorf("Unexpected value for scope: %v", scope) |
|||
} |
|||
} |
|||
|
|||
func TestExchangeRequest_JSONResponse(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
if r.URL.String() != "/token" { |
|||
t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) |
|||
} |
|||
headerAuth := r.Header.Get("Authorization") |
|||
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { |
|||
t.Errorf("Unexpected authorization header, %v is found.", headerAuth) |
|||
} |
|||
headerContentType := r.Header.Get("Content-Type") |
|||
if headerContentType != "application/x-www-form-urlencoded" { |
|||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
|||
} |
|||
body, err := ioutil.ReadAll(r.Body) |
|||
if err != nil { |
|||
t.Errorf("Failed reading request body: %s.", err) |
|||
} |
|||
if string(body) != "client_id=CLIENT_ID&code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL&scope=scope1+scope2" { |
|||
t.Errorf("Unexpected exchange payload, %v is found.", string(body)) |
|||
} |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`)) |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
tok, err := conf.Exchange(NoContext, "exchange-code") |
|||
if err != nil { |
|||
t.Error(err) |
|||
} |
|||
if !tok.Valid() { |
|||
t.Fatalf("Token invalid. Got: %#v", tok) |
|||
} |
|||
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { |
|||
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) |
|||
} |
|||
if tok.TokenType != "bearer" { |
|||
t.Errorf("Unexpected token type, %#v.", tok.TokenType) |
|||
} |
|||
scope := tok.Extra("scope") |
|||
if scope != "user" { |
|||
t.Errorf("Unexpected value for scope: %v", scope) |
|||
} |
|||
expiresIn := tok.Extra("expires_in") |
|||
if expiresIn != float64(86400) { |
|||
t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn) |
|||
} |
|||
} |
|||
|
|||
func TestExtraValueRetrieval(t *testing.T) { |
|||
values := url.Values{} |
|||
|
|||
kvmap := map[string]string{ |
|||
"scope": "user", "token_type": "bearer", "expires_in": "86400.92", |
|||
"server_time": "1443571905.5606415", "referer_ip": "10.0.0.1", |
|||
"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400", |
|||
"untrimmed": " untrimmed ", |
|||
} |
|||
|
|||
for key, value := range kvmap { |
|||
values.Set(key, value) |
|||
} |
|||
|
|||
tok := Token{ |
|||
raw: values, |
|||
} |
|||
|
|||
scope := tok.Extra("scope") |
|||
if scope != "user" { |
|||
t.Errorf("Unexpected scope %v wanted \"user\"", scope) |
|||
} |
|||
serverTime := tok.Extra("server_time") |
|||
if serverTime != 1443571905.5606415 { |
|||
t.Errorf("Unexpected non-float64 value for server_time: %v", serverTime) |
|||
} |
|||
refererIp := tok.Extra("referer_ip") |
|||
if refererIp != "10.0.0.1" { |
|||
t.Errorf("Unexpected non-string value for referer_ip: %v", refererIp) |
|||
} |
|||
expires_in := tok.Extra("expires_in") |
|||
if expires_in != 86400.92 { |
|||
t.Errorf("Unexpected value for expires_in, wanted 86400 got %v", expires_in) |
|||
} |
|||
requestId := tok.Extra("request_id") |
|||
if requestId != int64(86400) { |
|||
t.Errorf("Unexpected non-int64 value for request_id: %v", requestId) |
|||
} |
|||
untrimmed := tok.Extra("untrimmed") |
|||
if untrimmed != " untrimmed " { |
|||
t.Errorf("Unexpected value for untrimmed, got %q expected \" untrimmed \"", untrimmed) |
|||
} |
|||
} |
|||
|
|||
const day = 24 * time.Hour |
|||
|
|||
func TestExchangeRequest_JSONResponse_Expiry(t *testing.T) { |
|||
seconds := int32(day.Seconds()) |
|||
jsonNumberType := reflect.TypeOf(json.Number("0")) |
|||
for _, c := range []struct { |
|||
expires string |
|||
expect error |
|||
}{ |
|||
{fmt.Sprintf(`"expires_in": %d`, seconds), nil}, |
|||
{fmt.Sprintf(`"expires_in": "%d"`, seconds), nil}, // PayPal case
|
|||
{fmt.Sprintf(`"expires": %d`, seconds), nil}, // Facebook case
|
|||
{`"expires": false`, &json.UnmarshalTypeError{Value: "bool", Type: jsonNumberType}}, // wrong type
|
|||
{`"expires": {}`, &json.UnmarshalTypeError{Value: "object", Type: jsonNumberType}}, // wrong type
|
|||
{`"expires": "zzz"`, &strconv.NumError{Func: "ParseInt", Num: "zzz", Err: strconv.ErrSyntax}}, // wrong value
|
|||
} { |
|||
testExchangeRequest_JSONResponse_expiry(t, c.expires, c.expect) |
|||
} |
|||
} |
|||
|
|||
func testExchangeRequest_JSONResponse_expiry(t *testing.T, exp string, expect error) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Write([]byte(fmt.Sprintf(`{"access_token": "90d", "scope": "user", "token_type": "bearer", %s}`, exp))) |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
t1 := time.Now().Add(day) |
|||
tok, err := conf.Exchange(NoContext, "exchange-code") |
|||
t2 := time.Now().Add(day) |
|||
// Do a fmt.Sprint comparison so either side can be
|
|||
// nil. fmt.Sprint just stringifies them to "<nil>", and no
|
|||
// non-nil expected error ever stringifies as "<nil>", so this
|
|||
// isn't terribly disgusting. We do this because Go 1.4 and
|
|||
// Go 1.5 return a different deep value for
|
|||
// json.UnmarshalTypeError. In Go 1.5, the
|
|||
// json.UnmarshalTypeError contains a new field with a new
|
|||
// non-zero value. Rather than ignore it here with reflect or
|
|||
// add new files and +build tags, just look at the strings.
|
|||
if fmt.Sprint(err) != fmt.Sprint(expect) { |
|||
t.Errorf("Error = %v; want %v", err, expect) |
|||
} |
|||
if err != nil { |
|||
return |
|||
} |
|||
if !tok.Valid() { |
|||
t.Fatalf("Token invalid. Got: %#v", tok) |
|||
} |
|||
expiry := tok.Expiry |
|||
if expiry.Before(t1) || expiry.After(t2) { |
|||
t.Errorf("Unexpected value for Expiry: %v (shold be between %v and %v)", expiry, t1, t2) |
|||
} |
|||
} |
|||
|
|||
func TestExchangeRequest_BadResponse(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
tok, err := conf.Exchange(NoContext, "code") |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
if tok.AccessToken != "" { |
|||
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) |
|||
} |
|||
} |
|||
|
|||
func TestExchangeRequest_BadResponseType(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
_, err := conf.Exchange(NoContext, "exchange-code") |
|||
if err == nil { |
|||
t.Error("expected error from invalid access_token type") |
|||
} |
|||
} |
|||
|
|||
func TestExchangeRequest_NonBasicAuth(t *testing.T) { |
|||
tr := &mockTransport{ |
|||
rt: func(r *http.Request) (w *http.Response, err error) { |
|||
headerAuth := r.Header.Get("Authorization") |
|||
if headerAuth != "" { |
|||
t.Errorf("Unexpected authorization header, %v is found.", headerAuth) |
|||
} |
|||
return nil, errors.New("no response") |
|||
}, |
|||
} |
|||
c := &http.Client{Transport: tr} |
|||
conf := &Config{ |
|||
ClientID: "CLIENT_ID", |
|||
Endpoint: Endpoint{ |
|||
AuthURL: "https://accounts.google.com/auth", |
|||
TokenURL: "https://accounts.google.com/token", |
|||
}, |
|||
} |
|||
|
|||
ctx := context.WithValue(context.Background(), HTTPClient, c) |
|||
conf.Exchange(ctx, "code") |
|||
} |
|||
|
|||
func TestPasswordCredentialsTokenRequest(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
defer r.Body.Close() |
|||
expected := "/token" |
|||
if r.URL.String() != expected { |
|||
t.Errorf("URL = %q; want %q", r.URL, expected) |
|||
} |
|||
headerAuth := r.Header.Get("Authorization") |
|||
expected = "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" |
|||
if headerAuth != expected { |
|||
t.Errorf("Authorization header = %q; want %q", headerAuth, expected) |
|||
} |
|||
headerContentType := r.Header.Get("Content-Type") |
|||
expected = "application/x-www-form-urlencoded" |
|||
if headerContentType != expected { |
|||
t.Errorf("Content-Type header = %q; want %q", headerContentType, expected) |
|||
} |
|||
body, err := ioutil.ReadAll(r.Body) |
|||
if err != nil { |
|||
t.Errorf("Failed reading request body: %s.", err) |
|||
} |
|||
expected = "client_id=CLIENT_ID&grant_type=password&password=password1&scope=scope1+scope2&username=user1" |
|||
if string(body) != expected { |
|||
t.Errorf("res.Body = %q; want %q", string(body), expected) |
|||
} |
|||
w.Header().Set("Content-Type", "application/x-www-form-urlencoded") |
|||
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
tok, err := conf.PasswordCredentialsToken(NoContext, "user1", "password1") |
|||
if err != nil { |
|||
t.Error(err) |
|||
} |
|||
if !tok.Valid() { |
|||
t.Fatalf("Token invalid. Got: %#v", tok) |
|||
} |
|||
expected := "90d64460d14870c08c81352a05dedd3465940a7c" |
|||
if tok.AccessToken != expected { |
|||
t.Errorf("AccessToken = %q; want %q", tok.AccessToken, expected) |
|||
} |
|||
expected = "bearer" |
|||
if tok.TokenType != expected { |
|||
t.Errorf("TokenType = %q; want %q", tok.TokenType, expected) |
|||
} |
|||
} |
|||
|
|||
func TestTokenRefreshRequest(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
if r.URL.String() == "/somethingelse" { |
|||
return |
|||
} |
|||
if r.URL.String() != "/token" { |
|||
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) |
|||
} |
|||
headerContentType := r.Header.Get("Content-Type") |
|||
if headerContentType != "application/x-www-form-urlencoded" { |
|||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
|||
} |
|||
body, _ := ioutil.ReadAll(r.Body) |
|||
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { |
|||
t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) |
|||
} |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
c := conf.Client(NoContext, &Token{RefreshToken: "REFRESH_TOKEN"}) |
|||
c.Get(ts.URL + "/somethingelse") |
|||
} |
|||
|
|||
func TestFetchWithNoRefreshToken(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
if r.URL.String() == "/somethingelse" { |
|||
return |
|||
} |
|||
if r.URL.String() != "/token" { |
|||
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) |
|||
} |
|||
headerContentType := r.Header.Get("Content-Type") |
|||
if headerContentType != "application/x-www-form-urlencoded" { |
|||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
|||
} |
|||
body, _ := ioutil.ReadAll(r.Body) |
|||
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { |
|||
t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) |
|||
} |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
c := conf.Client(NoContext, nil) |
|||
_, err := c.Get(ts.URL + "/somethingelse") |
|||
if err == nil { |
|||
t.Errorf("Fetch should return an error if no refresh token is set") |
|||
} |
|||
} |
|||
|
|||
func TestRefreshToken_RefreshTokenReplacement(t *testing.T) { |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Write([]byte(`{"access_token":"ACCESS TOKEN", "scope": "user", "token_type": "bearer", "refresh_token": "NEW REFRESH TOKEN"}`)) |
|||
return |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
tkr := tokenRefresher{ |
|||
conf: conf, |
|||
ctx: NoContext, |
|||
refreshToken: "OLD REFRESH TOKEN", |
|||
} |
|||
tk, err := tkr.Token() |
|||
if err != nil { |
|||
t.Errorf("Unexpected refreshToken error returned: %v", err) |
|||
return |
|||
} |
|||
if tk.RefreshToken != tkr.refreshToken { |
|||
t.Errorf("tokenRefresher.refresh_token = %s; want %s", tkr.refreshToken, tk.RefreshToken) |
|||
} |
|||
} |
|||
|
|||
func TestConfigClientWithToken(t *testing.T) { |
|||
tok := &Token{ |
|||
AccessToken: "abc123", |
|||
} |
|||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
if got, want := r.Header.Get("Authorization"), fmt.Sprintf("Bearer %s", tok.AccessToken); got != want { |
|||
t.Errorf("Authorization header = %q; want %q", got, want) |
|||
} |
|||
return |
|||
})) |
|||
defer ts.Close() |
|||
conf := newConf(ts.URL) |
|||
|
|||
c := conf.Client(NoContext, tok) |
|||
req, err := http.NewRequest("GET", ts.URL, nil) |
|||
if err != nil { |
|||
t.Error(err) |
|||
} |
|||
_, err = c.Do(req) |
|||
if err != nil { |
|||
t.Error(err) |
|||
} |
|||
} |
@ -0,0 +1,16 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package odnoklassniki provides constants for using OAuth2 to access Odnoklassniki.
|
|||
package odnoklassniki // import "golang.org/x/oauth2/odnoklassniki"
|
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// Endpoint is Odnoklassniki's OAuth 2.0 endpoint.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://www.odnoklassniki.ru/oauth/authorize", |
|||
TokenURL: "https://api.odnoklassniki.ru/oauth/token.do", |
|||
} |
@ -0,0 +1,22 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package paypal provides constants for using OAuth2 to access PayPal.
|
|||
package paypal // import "golang.org/x/oauth2/paypal"
|
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// Endpoint is PayPal's OAuth 2.0 endpoint in live (production) environment.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://www.paypal.com/webapps/auth/protocol/openidconnect/v1/authorize", |
|||
TokenURL: "https://api.paypal.com/v1/identity/openidconnect/tokenservice", |
|||
} |
|||
|
|||
// SandboxEndpoint is PayPal's OAuth 2.0 endpoint in sandbox (testing) environment.
|
|||
var SandboxEndpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://www.sandbox.paypal.com/webapps/auth/protocol/openidconnect/v1/authorize", |
|||
TokenURL: "https://api.sandbox.paypal.com/v1/identity/openidconnect/tokenservice", |
|||
} |
@ -0,0 +1,16 @@ |
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package slack provides constants for using OAuth2 to access Slack.
|
|||
package slack // import "golang.org/x/oauth2/slack"
|
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// Endpoint is Slack's OAuth 2.0 endpoint.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://slack.com/oauth/authorize", |
|||
TokenURL: "https://slack.com/api/oauth.access", |
|||
} |
@ -0,0 +1,158 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package oauth2 |
|||
|
|||
import ( |
|||
"net/http" |
|||
"net/url" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
|
|||
"golang.org/x/net/context" |
|||
"golang.org/x/oauth2/internal" |
|||
) |
|||
|
|||
// expiryDelta determines how earlier a token should be considered
|
|||
// expired than its actual expiration time. It is used to avoid late
|
|||
// expirations due to client-server time mismatches.
|
|||
const expiryDelta = 10 * time.Second |
|||
|
|||
// Token represents the crendentials used to authorize
|
|||
// the requests to access protected resources on the OAuth 2.0
|
|||
// provider's backend.
|
|||
//
|
|||
// Most users of this package should not access fields of Token
|
|||
// directly. They're exported mostly for use by related packages
|
|||
// implementing derivative OAuth2 flows.
|
|||
type Token struct { |
|||
// AccessToken is the token that authorizes and authenticates
|
|||
// the requests.
|
|||
AccessToken string `json:"access_token"` |
|||
|
|||
// TokenType is the type of token.
|
|||
// The Type method returns either this or "Bearer", the default.
|
|||
TokenType string `json:"token_type,omitempty"` |
|||
|
|||
// RefreshToken is a token that's used by the application
|
|||
// (as opposed to the user) to refresh the access token
|
|||
// if it expires.
|
|||
RefreshToken string `json:"refresh_token,omitempty"` |
|||
|
|||
// Expiry is the optional expiration time of the access token.
|
|||
//
|
|||
// If zero, TokenSource implementations will reuse the same
|
|||
// token forever and RefreshToken or equivalent
|
|||
// mechanisms for that TokenSource will not be used.
|
|||
Expiry time.Time `json:"expiry,omitempty"` |
|||
|
|||
// raw optionally contains extra metadata from the server
|
|||
// when updating a token.
|
|||
raw interface{} |
|||
} |
|||
|
|||
// Type returns t.TokenType if non-empty, else "Bearer".
|
|||
func (t *Token) Type() string { |
|||
if strings.EqualFold(t.TokenType, "bearer") { |
|||
return "Bearer" |
|||
} |
|||
if strings.EqualFold(t.TokenType, "mac") { |
|||
return "MAC" |
|||
} |
|||
if strings.EqualFold(t.TokenType, "basic") { |
|||
return "Basic" |
|||
} |
|||
if t.TokenType != "" { |
|||
return t.TokenType |
|||
} |
|||
return "Bearer" |
|||
} |
|||
|
|||
// SetAuthHeader sets the Authorization header to r using the access
|
|||
// token in t.
|
|||
//
|
|||
// This method is unnecessary when using Transport or an HTTP Client
|
|||
// returned by this package.
|
|||
func (t *Token) SetAuthHeader(r *http.Request) { |
|||
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken) |
|||
} |
|||
|
|||
// WithExtra returns a new Token that's a clone of t, but using the
|
|||
// provided raw extra map. This is only intended for use by packages
|
|||
// implementing derivative OAuth2 flows.
|
|||
func (t *Token) WithExtra(extra interface{}) *Token { |
|||
t2 := new(Token) |
|||
*t2 = *t |
|||
t2.raw = extra |
|||
return t2 |
|||
} |
|||
|
|||
// Extra returns an extra field.
|
|||
// Extra fields are key-value pairs returned by the server as a
|
|||
// part of the token retrieval response.
|
|||
func (t *Token) Extra(key string) interface{} { |
|||
if raw, ok := t.raw.(map[string]interface{}); ok { |
|||
return raw[key] |
|||
} |
|||
|
|||
vals, ok := t.raw.(url.Values) |
|||
if !ok { |
|||
return nil |
|||
} |
|||
|
|||
v := vals.Get(key) |
|||
switch s := strings.TrimSpace(v); strings.Count(s, ".") { |
|||
case 0: // Contains no "."; try to parse as int
|
|||
if i, err := strconv.ParseInt(s, 10, 64); err == nil { |
|||
return i |
|||
} |
|||
case 1: // Contains a single "."; try to parse as float
|
|||
if f, err := strconv.ParseFloat(s, 64); err == nil { |
|||
return f |
|||
} |
|||
} |
|||
|
|||
return v |
|||
} |
|||
|
|||
// expired reports whether the token is expired.
|
|||
// t must be non-nil.
|
|||
func (t *Token) expired() bool { |
|||
if t.Expiry.IsZero() { |
|||
return false |
|||
} |
|||
return t.Expiry.Add(-expiryDelta).Before(time.Now()) |
|||
} |
|||
|
|||
// Valid reports whether t is non-nil, has an AccessToken, and is not expired.
|
|||
func (t *Token) Valid() bool { |
|||
return t != nil && t.AccessToken != "" && !t.expired() |
|||
} |
|||
|
|||
// tokenFromInternal maps an *internal.Token struct into
|
|||
// a *Token struct.
|
|||
func tokenFromInternal(t *internal.Token) *Token { |
|||
if t == nil { |
|||
return nil |
|||
} |
|||
return &Token{ |
|||
AccessToken: t.AccessToken, |
|||
TokenType: t.TokenType, |
|||
RefreshToken: t.RefreshToken, |
|||
Expiry: t.Expiry, |
|||
raw: t.Raw, |
|||
} |
|||
} |
|||
|
|||
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token.
|
|||
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
|
|||
// with an error..
|
|||
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { |
|||
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return tokenFromInternal(tk), nil |
|||
} |
@ -0,0 +1,72 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package oauth2 |
|||
|
|||
import ( |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
func TestTokenExtra(t *testing.T) { |
|||
type testCase struct { |
|||
key string |
|||
val interface{} |
|||
want interface{} |
|||
} |
|||
const key = "extra-key" |
|||
cases := []testCase{ |
|||
{key: key, val: "abc", want: "abc"}, |
|||
{key: key, val: 123, want: 123}, |
|||
{key: key, val: "", want: ""}, |
|||
{key: "other-key", val: "def", want: nil}, |
|||
} |
|||
for _, tc := range cases { |
|||
extra := make(map[string]interface{}) |
|||
extra[tc.key] = tc.val |
|||
tok := &Token{raw: extra} |
|||
if got, want := tok.Extra(key), tc.want; got != want { |
|||
t.Errorf("Extra(%q) = %q; want %q", key, got, want) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestTokenExpiry(t *testing.T) { |
|||
now := time.Now() |
|||
cases := []struct { |
|||
name string |
|||
tok *Token |
|||
want bool |
|||
}{ |
|||
{name: "12 seconds", tok: &Token{Expiry: now.Add(12 * time.Second)}, want: false}, |
|||
{name: "10 seconds", tok: &Token{Expiry: now.Add(expiryDelta)}, want: true}, |
|||
{name: "-1 hour", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, want: true}, |
|||
} |
|||
for _, tc := range cases { |
|||
if got, want := tc.tok.expired(), tc.want; got != want { |
|||
t.Errorf("expired (%q) = %v; want %v", tc.name, got, want) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestTokenTypeMethod(t *testing.T) { |
|||
cases := []struct { |
|||
name string |
|||
tok *Token |
|||
want string |
|||
}{ |
|||
{name: "bearer-mixed_case", tok: &Token{TokenType: "beAREr"}, want: "Bearer"}, |
|||
{name: "default-bearer", tok: &Token{}, want: "Bearer"}, |
|||
{name: "basic", tok: &Token{TokenType: "basic"}, want: "Basic"}, |
|||
{name: "basic-capitalized", tok: &Token{TokenType: "Basic"}, want: "Basic"}, |
|||
{name: "mac", tok: &Token{TokenType: "mac"}, want: "MAC"}, |
|||
{name: "mac-caps", tok: &Token{TokenType: "MAC"}, want: "MAC"}, |
|||
{name: "mac-mixed_case", tok: &Token{TokenType: "mAc"}, want: "MAC"}, |
|||
} |
|||
for _, tc := range cases { |
|||
if got, want := tc.tok.Type(), tc.want; got != want { |
|||
t.Errorf("TokenType(%q) = %v; want %v", tc.name, got, want) |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,132 @@ |
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
package oauth2 |
|||
|
|||
import ( |
|||
"errors" |
|||
"io" |
|||
"net/http" |
|||
"sync" |
|||
) |
|||
|
|||
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
|
|||
// wrapping a base RoundTripper and adding an Authorization header
|
|||
// with a token from the supplied Sources.
|
|||
//
|
|||
// Transport is a low-level mechanism. Most code will use the
|
|||
// higher-level Config.Client method instead.
|
|||
type Transport struct { |
|||
// Source supplies the token to add to outgoing requests'
|
|||
// Authorization headers.
|
|||
Source TokenSource |
|||
|
|||
// Base is the base RoundTripper used to make HTTP requests.
|
|||
// If nil, http.DefaultTransport is used.
|
|||
Base http.RoundTripper |
|||
|
|||
mu sync.Mutex // guards modReq
|
|||
modReq map[*http.Request]*http.Request // original -> modified
|
|||
} |
|||
|
|||
// RoundTrip authorizes and authenticates the request with an
|
|||
// access token. If no token exists or token is expired,
|
|||
// tries to refresh/fetch a new token.
|
|||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { |
|||
if t.Source == nil { |
|||
return nil, errors.New("oauth2: Transport's Source is nil") |
|||
} |
|||
token, err := t.Source.Token() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
req2 := cloneRequest(req) // per RoundTripper contract
|
|||
token.SetAuthHeader(req2) |
|||
t.setModReq(req, req2) |
|||
res, err := t.base().RoundTrip(req2) |
|||
if err != nil { |
|||
t.setModReq(req, nil) |
|||
return nil, err |
|||
} |
|||
res.Body = &onEOFReader{ |
|||
rc: res.Body, |
|||
fn: func() { t.setModReq(req, nil) }, |
|||
} |
|||
return res, nil |
|||
} |
|||
|
|||
// CancelRequest cancels an in-flight request by closing its connection.
|
|||
func (t *Transport) CancelRequest(req *http.Request) { |
|||
type canceler interface { |
|||
CancelRequest(*http.Request) |
|||
} |
|||
if cr, ok := t.base().(canceler); ok { |
|||
t.mu.Lock() |
|||
modReq := t.modReq[req] |
|||
delete(t.modReq, req) |
|||
t.mu.Unlock() |
|||
cr.CancelRequest(modReq) |
|||
} |
|||
} |
|||
|
|||
func (t *Transport) base() http.RoundTripper { |
|||
if t.Base != nil { |
|||
return t.Base |
|||
} |
|||
return http.DefaultTransport |
|||
} |
|||
|
|||
func (t *Transport) setModReq(orig, mod *http.Request) { |
|||
t.mu.Lock() |
|||
defer t.mu.Unlock() |
|||
if t.modReq == nil { |
|||
t.modReq = make(map[*http.Request]*http.Request) |
|||
} |
|||
if mod == nil { |
|||
delete(t.modReq, orig) |
|||
} else { |
|||
t.modReq[orig] = mod |
|||
} |
|||
} |
|||
|
|||
// cloneRequest returns a clone of the provided *http.Request.
|
|||
// The clone is a shallow copy of the struct and its Header map.
|
|||
func cloneRequest(r *http.Request) *http.Request { |
|||
// shallow copy of the struct
|
|||
r2 := new(http.Request) |
|||
*r2 = *r |
|||
// deep copy of the Header
|
|||
r2.Header = make(http.Header, len(r.Header)) |
|||
for k, s := range r.Header { |
|||
r2.Header[k] = append([]string(nil), s...) |
|||
} |
|||
return r2 |
|||
} |
|||
|
|||
type onEOFReader struct { |
|||
rc io.ReadCloser |
|||
fn func() |
|||
} |
|||
|
|||
func (r *onEOFReader) Read(p []byte) (n int, err error) { |
|||
n, err = r.rc.Read(p) |
|||
if err == io.EOF { |
|||
r.runFunc() |
|||
} |
|||
return |
|||
} |
|||
|
|||
func (r *onEOFReader) Close() error { |
|||
err := r.rc.Close() |
|||
r.runFunc() |
|||
return err |
|||
} |
|||
|
|||
func (r *onEOFReader) runFunc() { |
|||
if fn := r.fn; fn != nil { |
|||
fn() |
|||
r.fn = nil |
|||
} |
|||
} |
@ -0,0 +1,108 @@ |
|||
package oauth2 |
|||
|
|||
import ( |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
type tokenSource struct{ token *Token } |
|||
|
|||
func (t *tokenSource) Token() (*Token, error) { |
|||
return t.token, nil |
|||
} |
|||
|
|||
func TestTransportNilTokenSource(t *testing.T) { |
|||
tr := &Transport{} |
|||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {}) |
|||
defer server.Close() |
|||
client := &http.Client{Transport: tr} |
|||
res, err := client.Get(server.URL) |
|||
if err == nil { |
|||
t.Errorf("a nil Source was passed into the transport expected an error") |
|||
} |
|||
if res != nil { |
|||
t.Errorf("expected a nil response, got %v", res) |
|||
} |
|||
} |
|||
|
|||
func TestTransportTokenSource(t *testing.T) { |
|||
ts := &tokenSource{ |
|||
token: &Token{ |
|||
AccessToken: "abc", |
|||
}, |
|||
} |
|||
tr := &Transport{ |
|||
Source: ts, |
|||
} |
|||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) { |
|||
if r.Header.Get("Authorization") != "Bearer abc" { |
|||
t.Errorf("Transport doesn't set the Authorization header from the fetched token") |
|||
} |
|||
}) |
|||
defer server.Close() |
|||
client := &http.Client{Transport: tr} |
|||
res, err := client.Get(server.URL) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
res.Body.Close() |
|||
} |
|||
|
|||
// Test for case-sensitive token types, per https://github.com/golang/oauth2/issues/113
|
|||
func TestTransportTokenSourceTypes(t *testing.T) { |
|||
const val = "abc" |
|||
tests := []struct { |
|||
key string |
|||
val string |
|||
want string |
|||
}{ |
|||
{key: "bearer", val: val, want: "Bearer abc"}, |
|||
{key: "mac", val: val, want: "MAC abc"}, |
|||
{key: "basic", val: val, want: "Basic abc"}, |
|||
} |
|||
for _, tc := range tests { |
|||
ts := &tokenSource{ |
|||
token: &Token{ |
|||
AccessToken: tc.val, |
|||
TokenType: tc.key, |
|||
}, |
|||
} |
|||
tr := &Transport{ |
|||
Source: ts, |
|||
} |
|||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) { |
|||
if got, want := r.Header.Get("Authorization"), tc.want; got != want { |
|||
t.Errorf("Authorization header (%q) = %q; want %q", val, got, want) |
|||
} |
|||
}) |
|||
defer server.Close() |
|||
client := &http.Client{Transport: tr} |
|||
res, err := client.Get(server.URL) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
res.Body.Close() |
|||
} |
|||
} |
|||
|
|||
func TestTokenValidNoAccessToken(t *testing.T) { |
|||
token := &Token{} |
|||
if token.Valid() { |
|||
t.Errorf("Token should not be valid with no access token") |
|||
} |
|||
} |
|||
|
|||
func TestExpiredWithExpiry(t *testing.T) { |
|||
token := &Token{ |
|||
Expiry: time.Now().Add(-5 * time.Hour), |
|||
} |
|||
if token.Valid() { |
|||
t.Errorf("Token should not be valid if it expired in the past") |
|||
} |
|||
} |
|||
|
|||
func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server { |
|||
return httptest.NewServer(http.HandlerFunc(handler)) |
|||
} |
@ -0,0 +1,16 @@ |
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
|||
// Use of this source code is governed by a BSD-style
|
|||
// license that can be found in the LICENSE file.
|
|||
|
|||
// Package vk provides constants for using OAuth2 to access VK.com.
|
|||
package vk // import "golang.org/x/oauth2/vk"
|
|||
|
|||
import ( |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
// Endpoint is VK's OAuth 2.0 endpoint.
|
|||
var Endpoint = oauth2.Endpoint{ |
|||
AuthURL: "https://oauth.vk.com/authorize", |
|||
TokenURL: "https://oauth.vk.com/access_token", |
|||
} |
@ -0,0 +1,438 @@ |
|||
// Copyright 2014 Google Inc. All Rights Reserved.
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|||
// you may not use this file except in compliance with the License.
|
|||
// You may obtain a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
// See the License for the specific language governing permissions and
|
|||
// limitations under the License.
|
|||
|
|||
// Package metadata provides access to Google Compute Engine (GCE)
|
|||
// metadata and API service accounts.
|
|||
//
|
|||
// This package is a wrapper around the GCE metadata service,
|
|||
// as documented at https://developers.google.com/compute/docs/metadata.
|
|||
package metadata // import "google.golang.org/cloud/compute/metadata"
|
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"fmt" |
|||
"io/ioutil" |
|||
"net" |
|||
"net/http" |
|||
"net/url" |
|||
"os" |
|||
"runtime" |
|||
"strings" |
|||
"sync" |
|||
"time" |
|||
|
|||
"golang.org/x/net/context" |
|||
"golang.org/x/net/context/ctxhttp" |
|||
|
|||
"google.golang.org/cloud/internal" |
|||
) |
|||
|
|||
const ( |
|||
// metadataIP is the documented metadata server IP address.
|
|||
metadataIP = "169.254.169.254" |
|||
|
|||
// metadataHostEnv is the environment variable specifying the
|
|||
// GCE metadata hostname. If empty, the default value of
|
|||
// metadataIP ("169.254.169.254") is used instead.
|
|||
// This is variable name is not defined by any spec, as far as
|
|||
// I know; it was made up for the Go package.
|
|||
metadataHostEnv = "GCE_METADATA_HOST" |
|||
) |
|||
|
|||
type cachedValue struct { |
|||
k string |
|||
trim bool |
|||
mu sync.Mutex |
|||
v string |
|||
} |
|||
|
|||
var ( |
|||
projID = &cachedValue{k: "project/project-id", trim: true} |
|||
projNum = &cachedValue{k: "project/numeric-project-id", trim: true} |
|||
instID = &cachedValue{k: "instance/id", trim: true} |
|||
) |
|||
|
|||
var ( |
|||
metaClient = &http.Client{ |
|||
Transport: &internal.Transport{ |
|||
Base: &http.Transport{ |
|||
Dial: (&net.Dialer{ |
|||
Timeout: 2 * time.Second, |
|||
KeepAlive: 30 * time.Second, |
|||
}).Dial, |
|||
ResponseHeaderTimeout: 2 * time.Second, |
|||
}, |
|||
}, |
|||
} |
|||
subscribeClient = &http.Client{ |
|||
Transport: &internal.Transport{ |
|||
Base: &http.Transport{ |
|||
Dial: (&net.Dialer{ |
|||
Timeout: 2 * time.Second, |
|||
KeepAlive: 30 * time.Second, |
|||
}).Dial, |
|||
}, |
|||
}, |
|||
} |
|||
) |
|||
|
|||
// NotDefinedError is returned when requested metadata is not defined.
|
|||
//
|
|||
// The underlying string is the suffix after "/computeMetadata/v1/".
|
|||
//
|
|||
// This error is not returned if the value is defined to be the empty
|
|||
// string.
|
|||
type NotDefinedError string |
|||
|
|||
func (suffix NotDefinedError) Error() string { |
|||
return fmt.Sprintf("metadata: GCE metadata %q not defined", string(suffix)) |
|||
} |
|||
|
|||
// Get returns a value from the metadata service.
|
|||
// The suffix is appended to "http://${GCE_METADATA_HOST}/computeMetadata/v1/".
|
|||
//
|
|||
// If the GCE_METADATA_HOST environment variable is not defined, a default of
|
|||
// 169.254.169.254 will be used instead.
|
|||
//
|
|||
// If the requested metadata is not defined, the returned error will
|
|||
// be of type NotDefinedError.
|
|||
func Get(suffix string) (string, error) { |
|||
val, _, err := getETag(metaClient, suffix) |
|||
return val, err |
|||
} |
|||
|
|||
// getETag returns a value from the metadata service as well as the associated
|
|||
// ETag using the provided client. This func is otherwise equivalent to Get.
|
|||
func getETag(client *http.Client, suffix string) (value, etag string, err error) { |
|||
// Using a fixed IP makes it very difficult to spoof the metadata service in
|
|||
// a container, which is an important use-case for local testing of cloud
|
|||
// deployments. To enable spoofing of the metadata service, the environment
|
|||
// variable GCE_METADATA_HOST is first inspected to decide where metadata
|
|||
// requests shall go.
|
|||
host := os.Getenv(metadataHostEnv) |
|||
if host == "" { |
|||
// Using 169.254.169.254 instead of "metadata" here because Go
|
|||
// binaries built with the "netgo" tag and without cgo won't
|
|||
// know the search suffix for "metadata" is
|
|||
// ".google.internal", and this IP address is documented as
|
|||
// being stable anyway.
|
|||
host = metadataIP |
|||
} |
|||
url := "http://" + host + "/computeMetadata/v1/" + suffix |
|||
req, _ := http.NewRequest("GET", url, nil) |
|||
req.Header.Set("Metadata-Flavor", "Google") |
|||
res, err := client.Do(req) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
defer res.Body.Close() |
|||
if res.StatusCode == http.StatusNotFound { |
|||
return "", "", NotDefinedError(suffix) |
|||
} |
|||
if res.StatusCode != 200 { |
|||
return "", "", fmt.Errorf("status code %d trying to fetch %s", res.StatusCode, url) |
|||
} |
|||
all, err := ioutil.ReadAll(res.Body) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
return string(all), res.Header.Get("Etag"), nil |
|||
} |
|||
|
|||
func getTrimmed(suffix string) (s string, err error) { |
|||
s, err = Get(suffix) |
|||
s = strings.TrimSpace(s) |
|||
return |
|||
} |
|||
|
|||
func (c *cachedValue) get() (v string, err error) { |
|||
defer c.mu.Unlock() |
|||
c.mu.Lock() |
|||
if c.v != "" { |
|||
return c.v, nil |
|||
} |
|||
if c.trim { |
|||
v, err = getTrimmed(c.k) |
|||
} else { |
|||
v, err = Get(c.k) |
|||
} |
|||
if err == nil { |
|||
c.v = v |
|||
} |
|||
return |
|||
} |
|||
|
|||
var ( |
|||
onGCEOnce sync.Once |
|||
onGCE bool |
|||
) |
|||
|
|||
// OnGCE reports whether this process is running on Google Compute Engine.
|
|||
func OnGCE() bool { |
|||
onGCEOnce.Do(initOnGCE) |
|||
return onGCE |
|||
} |
|||
|
|||
func initOnGCE() { |
|||
onGCE = testOnGCE() |
|||
} |
|||
|
|||
func testOnGCE() bool { |
|||
// The user explicitly said they're on GCE, so trust them.
|
|||
if os.Getenv(metadataHostEnv) != "" { |
|||
return true |
|||
} |
|||
|
|||
ctx, cancel := context.WithCancel(context.Background()) |
|||
defer cancel() |
|||
|
|||
resc := make(chan bool, 2) |
|||
|
|||
// Try two strategies in parallel.
|
|||
// See https://github.com/GoogleCloudPlatform/gcloud-golang/issues/194
|
|||
go func() { |
|||
res, err := ctxhttp.Get(ctx, metaClient, "http://"+metadataIP) |
|||
if err != nil { |
|||
resc <- false |
|||
return |
|||
} |
|||
defer res.Body.Close() |
|||
resc <- res.Header.Get("Metadata-Flavor") == "Google" |
|||
}() |
|||
|
|||
go func() { |
|||
addrs, err := net.LookupHost("metadata.google.internal") |
|||
if err != nil || len(addrs) == 0 { |
|||
resc <- false |
|||
return |
|||
} |
|||
resc <- strsContains(addrs, metadataIP) |
|||
}() |
|||
|
|||
tryHarder := systemInfoSuggestsGCE() |
|||
if tryHarder { |
|||
res := <-resc |
|||
if res { |
|||
// The first strategy succeeded, so let's use it.
|
|||
return true |
|||
} |
|||
// Wait for either the DNS or metadata server probe to
|
|||
// contradict the other one and say we are running on
|
|||
// GCE. Give it a lot of time to do so, since the system
|
|||
// info already suggests we're running on a GCE BIOS.
|
|||
timer := time.NewTimer(5 * time.Second) |
|||
defer timer.Stop() |
|||
select { |
|||
case res = <-resc: |
|||
return res |
|||
case <-timer.C: |
|||
// Too slow. Who knows what this system is.
|
|||
return false |
|||
} |
|||
} |
|||
|
|||
// There's no hint from the system info that we're running on
|
|||
// GCE, so use the first probe's result as truth, whether it's
|
|||
// true or false. The goal here is to optimize for speed for
|
|||
// users who are NOT running on GCE. We can't assume that
|
|||
// either a DNS lookup or an HTTP request to a blackholed IP
|
|||
// address is fast. Worst case this should return when the
|
|||
// metaClient's Transport.ResponseHeaderTimeout or
|
|||
// Transport.Dial.Timeout fires (in two seconds).
|
|||
return <-resc |
|||
} |
|||
|
|||
// systemInfoSuggestsGCE reports whether the local system (without
|
|||
// doing network requests) suggests that we're running on GCE. If this
|
|||
// returns true, testOnGCE tries a bit harder to reach its metadata
|
|||
// server.
|
|||
func systemInfoSuggestsGCE() bool { |
|||
if runtime.GOOS != "linux" { |
|||
// We don't have any non-Linux clues available, at least yet.
|
|||
return false |
|||
} |
|||
slurp, _ := ioutil.ReadFile("/sys/class/dmi/id/product_name") |
|||
name := strings.TrimSpace(string(slurp)) |
|||
return name == "Google" || name == "Google Compute Engine" |
|||
} |
|||
|
|||
// Subscribe subscribes to a value from the metadata service.
|
|||
// The suffix is appended to "http://${GCE_METADATA_HOST}/computeMetadata/v1/".
|
|||
// The suffix may contain query parameters.
|
|||
//
|
|||
// Subscribe calls fn with the latest metadata value indicated by the provided
|
|||
// suffix. If the metadata value is deleted, fn is called with the empty string
|
|||
// and ok false. Subscribe blocks until fn returns a non-nil error or the value
|
|||
// is deleted. Subscribe returns the error value returned from the last call to
|
|||
// fn, which may be nil when ok == false.
|
|||
func Subscribe(suffix string, fn func(v string, ok bool) error) error { |
|||
const failedSubscribeSleep = time.Second * 5 |
|||
|
|||
// First check to see if the metadata value exists at all.
|
|||
val, lastETag, err := getETag(subscribeClient, suffix) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
if err := fn(val, true); err != nil { |
|||
return err |
|||
} |
|||
|
|||
ok := true |
|||
if strings.ContainsRune(suffix, '?') { |
|||
suffix += "&wait_for_change=true&last_etag=" |
|||
} else { |
|||
suffix += "?wait_for_change=true&last_etag=" |
|||
} |
|||
for { |
|||
val, etag, err := getETag(subscribeClient, suffix+url.QueryEscape(lastETag)) |
|||
if err != nil { |
|||
if _, deleted := err.(NotDefinedError); !deleted { |
|||
time.Sleep(failedSubscribeSleep) |
|||
continue // Retry on other errors.
|
|||
} |
|||
ok = false |
|||
} |
|||
lastETag = etag |
|||
|
|||
if err := fn(val, ok); err != nil || !ok { |
|||
return err |
|||
} |
|||
} |
|||
} |
|||
|
|||
// ProjectID returns the current instance's project ID string.
|
|||
func ProjectID() (string, error) { return projID.get() } |
|||
|
|||
// NumericProjectID returns the current instance's numeric project ID.
|
|||
func NumericProjectID() (string, error) { return projNum.get() } |
|||
|
|||
// InternalIP returns the instance's primary internal IP address.
|
|||
func InternalIP() (string, error) { |
|||
return getTrimmed("instance/network-interfaces/0/ip") |
|||
} |
|||
|
|||
// ExternalIP returns the instance's primary external (public) IP address.
|
|||
func ExternalIP() (string, error) { |
|||
return getTrimmed("instance/network-interfaces/0/access-configs/0/external-ip") |
|||
} |
|||
|
|||
// Hostname returns the instance's hostname. This will be of the form
|
|||
// "<instanceID>.c.<projID>.internal".
|
|||
func Hostname() (string, error) { |
|||
return getTrimmed("instance/hostname") |
|||
} |
|||
|
|||
// InstanceTags returns the list of user-defined instance tags,
|
|||
// assigned when initially creating a GCE instance.
|
|||
func InstanceTags() ([]string, error) { |
|||
var s []string |
|||
j, err := Get("instance/tags") |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if err := json.NewDecoder(strings.NewReader(j)).Decode(&s); err != nil { |
|||
return nil, err |
|||
} |
|||
return s, nil |
|||
} |
|||
|
|||
// InstanceID returns the current VM's numeric instance ID.
|
|||
func InstanceID() (string, error) { |
|||
return instID.get() |
|||
} |
|||
|
|||
// InstanceName returns the current VM's instance ID string.
|
|||
func InstanceName() (string, error) { |
|||
host, err := Hostname() |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
return strings.Split(host, ".")[0], nil |
|||
} |
|||
|
|||
// Zone returns the current VM's zone, such as "us-central1-b".
|
|||
func Zone() (string, error) { |
|||
zone, err := getTrimmed("instance/zone") |
|||
// zone is of the form "projects/<projNum>/zones/<zoneName>".
|
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
return zone[strings.LastIndex(zone, "/")+1:], nil |
|||
} |
|||
|
|||
// InstanceAttributes returns the list of user-defined attributes,
|
|||
// assigned when initially creating a GCE VM instance. The value of an
|
|||
// attribute can be obtained with InstanceAttributeValue.
|
|||
func InstanceAttributes() ([]string, error) { return lines("instance/attributes/") } |
|||
|
|||
// ProjectAttributes returns the list of user-defined attributes
|
|||
// applying to the project as a whole, not just this VM. The value of
|
|||
// an attribute can be obtained with ProjectAttributeValue.
|
|||
func ProjectAttributes() ([]string, error) { return lines("project/attributes/") } |
|||
|
|||
func lines(suffix string) ([]string, error) { |
|||
j, err := Get(suffix) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
s := strings.Split(strings.TrimSpace(j), "\n") |
|||
for i := range s { |
|||
s[i] = strings.TrimSpace(s[i]) |
|||
} |
|||
return s, nil |
|||
} |
|||
|
|||
// InstanceAttributeValue returns the value of the provided VM
|
|||
// instance attribute.
|
|||
//
|
|||
// If the requested attribute is not defined, the returned error will
|
|||
// be of type NotDefinedError.
|
|||
//
|
|||
// InstanceAttributeValue may return ("", nil) if the attribute was
|
|||
// defined to be the empty string.
|
|||
func InstanceAttributeValue(attr string) (string, error) { |
|||
return Get("instance/attributes/" + attr) |
|||
} |
|||
|
|||
// ProjectAttributeValue returns the value of the provided
|
|||
// project attribute.
|
|||
//
|
|||
// If the requested attribute is not defined, the returned error will
|
|||
// be of type NotDefinedError.
|
|||
//
|
|||
// ProjectAttributeValue may return ("", nil) if the attribute was
|
|||
// defined to be the empty string.
|
|||
func ProjectAttributeValue(attr string) (string, error) { |
|||
return Get("project/attributes/" + attr) |
|||
} |
|||
|
|||
// Scopes returns the service account scopes for the given account.
|
|||
// The account may be empty or the string "default" to use the instance's
|
|||
// main account.
|
|||
func Scopes(serviceAccount string) ([]string, error) { |
|||
if serviceAccount == "" { |
|||
serviceAccount = "default" |
|||
} |
|||
return lines("instance/service-accounts/" + serviceAccount + "/scopes") |
|||
} |
|||
|
|||
func strsContains(ss []string, s string) bool { |
|||
for _, v := range ss { |
|||
if v == s { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
@ -0,0 +1,48 @@ |
|||
// Copyright 2016 Google Inc. All Rights Reserved.
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|||
// you may not use this file except in compliance with the License.
|
|||
// You may obtain a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
// See the License for the specific language governing permissions and
|
|||
// limitations under the License.
|
|||
|
|||
package metadata |
|||
|
|||
import ( |
|||
"os" |
|||
"sync" |
|||
"testing" |
|||
) |
|||
|
|||
func TestOnGCE_Stress(t *testing.T) { |
|||
if testing.Short() { |
|||
t.Skip("skipping in -short mode") |
|||
} |
|||
var last bool |
|||
for i := 0; i < 100; i++ { |
|||
onGCEOnce = sync.Once{} |
|||
|
|||
now := OnGCE() |
|||
if i > 0 && now != last { |
|||
t.Errorf("%d. changed from %v to %v", i, last, now) |
|||
} |
|||
last = now |
|||
} |
|||
t.Logf("OnGCE() = %v", last) |
|||
} |
|||
|
|||
func TestOnGCE_Force(t *testing.T) { |
|||
onGCEOnce = sync.Once{} |
|||
old := os.Getenv(metadataHostEnv) |
|||
defer os.Setenv(metadataHostEnv, old) |
|||
os.Setenv(metadataHostEnv, "127.0.0.1") |
|||
if !OnGCE() { |
|||
t.Error("OnGCE() = false; want true") |
|||
} |
|||
} |
@ -0,0 +1,257 @@ |
|||
// Copyright 2016 Google Inc. All Rights Reserved.
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|||
// you may not use this file except in compliance with the License.
|
|||
// You may obtain a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
// See the License for the specific language governing permissions and
|
|||
// limitations under the License.
|
|||
|
|||
// Package bundler supports bundling (batching) of items. Bundling amortizes an
|
|||
// action with fixed costs over multiple items. For example, if an API provides
|
|||
// an RPC that accepts a list of items as input, but clients would prefer
|
|||
// adding items one at a time, then a Bundler can accept individual items from
|
|||
// the client and bundle many of them into a single RPC.
|
|||
package bundler |
|||
|
|||
import ( |
|||
"errors" |
|||
"reflect" |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
const ( |
|||
DefaultDelayThreshold = time.Second |
|||
DefaultBundleCountThreshold = 10 |
|||
DefaultBundleByteThreshold = 1e6 // 1M
|
|||
DefaultBufferedByteLimit = 1e9 // 1G
|
|||
) |
|||
|
|||
var ( |
|||
// ErrOverflow indicates that Bundler's stored bytes exceeds its BufferedByteLimit.
|
|||
ErrOverflow = errors.New("bundler reached buffered byte limit") |
|||
|
|||
// ErrOversizedItem indicates that an item's size exceeds the maximum bundle size.
|
|||
ErrOversizedItem = errors.New("item size exceeds bundle byte limit") |
|||
) |
|||
|
|||
// A Bundler collects items added to it into a bundle until the bundle
|
|||
// exceeds a given size, then calls a user-provided function to handle the bundle.
|
|||
type Bundler struct { |
|||
// Starting from the time that the first message is added to a bundle, once
|
|||
// this delay has passed, handle the bundle. The default is DefaultDelayThreshold.
|
|||
DelayThreshold time.Duration |
|||
|
|||
// Once a bundle has this many items, handle the bundle. Since only one
|
|||
// item at a time is added to a bundle, no bundle will exceed this
|
|||
// threshold, so it also serves as a limit. The default is
|
|||
// DefaultBundleCountThreshold.
|
|||
BundleCountThreshold int |
|||
|
|||
// Once the number of bytes in current bundle reaches this threshold, handle
|
|||
// the bundle. The default is DefaultBundleByteThreshold. This triggers handling,
|
|||
// but does not cap the total size of a bundle.
|
|||
BundleByteThreshold int |
|||
|
|||
// The maximum size of a bundle, in bytes. Zero means unlimited.
|
|||
BundleByteLimit int |
|||
|
|||
// The maximum number of bytes that the Bundler will keep in memory before
|
|||
// returning ErrOverflow. The default is DefaultBufferedByteLimit.
|
|||
BufferedByteLimit int |
|||
|
|||
handler func(interface{}) // called to handle a bundle
|
|||
itemSliceZero reflect.Value // nil (zero value) for slice of items
|
|||
donec chan struct{} // closed when the Bundler is closed
|
|||
handlec chan int // sent to when a bundle is ready for handling
|
|||
timer *time.Timer // implements DelayThreshold
|
|||
|
|||
mu sync.Mutex |
|||
bufferedSize int // total bytes buffered
|
|||
closedBundles []bundle // bundles waiting to be handled
|
|||
curBundle bundle // incoming items added to this bundle
|
|||
calledc chan struct{} // closed and re-created after handler is called
|
|||
} |
|||
|
|||
type bundle struct { |
|||
items reflect.Value // slice of item type
|
|||
size int // size in bytes of all items
|
|||
} |
|||
|
|||
// NewBundler creates a new Bundler. When you are finished with a Bundler, call
|
|||
// its Close method.
|
|||
//
|
|||
// itemExample is a value of the type that will be bundled. For example, if you
|
|||
// want to create bundles of *Entry, you could pass &Entry{} for itemExample.
|
|||
//
|
|||
// handler is a function that will be called on each bundle. If itemExample is
|
|||
// of type T, the the argument to handler is of type []T.
|
|||
func NewBundler(itemExample interface{}, handler func(interface{})) *Bundler { |
|||
b := &Bundler{ |
|||
DelayThreshold: DefaultDelayThreshold, |
|||
BundleCountThreshold: DefaultBundleCountThreshold, |
|||
BundleByteThreshold: DefaultBundleByteThreshold, |
|||
BufferedByteLimit: DefaultBufferedByteLimit, |
|||
|
|||
handler: handler, |
|||
itemSliceZero: reflect.Zero(reflect.SliceOf(reflect.TypeOf(itemExample))), |
|||
donec: make(chan struct{}), |
|||
handlec: make(chan int, 1), |
|||
calledc: make(chan struct{}), |
|||
timer: time.NewTimer(1000 * time.Hour), // harmless initial timeout
|
|||
} |
|||
b.curBundle.items = b.itemSliceZero |
|||
go b.background() |
|||
return b |
|||
} |
|||
|
|||
// Add adds item to the current bundle. It marks the bundle for handling and
|
|||
// starts a new one if any of the thresholds or limits are exceeded.
|
|||
//
|
|||
// If the item's size exceeds the maximum bundle size (Bundler.BundleByteLimit), then
|
|||
// the item can never be handled. Add returns ErrOversizedItem in this case.
|
|||
//
|
|||
// If adding the item would exceed the maximum memory allowed (Bundler.BufferedByteLimit),
|
|||
// Add returns ErrOverflow.
|
|||
//
|
|||
// Add never blocks.
|
|||
func (b *Bundler) Add(item interface{}, size int) error { |
|||
// If this item exceeds the maximum size of a bundle,
|
|||
// we can never send it.
|
|||
if b.BundleByteLimit > 0 && size > b.BundleByteLimit { |
|||
return ErrOversizedItem |
|||
} |
|||
b.mu.Lock() |
|||
defer b.mu.Unlock() |
|||
// If adding this item would exceed our allotted memory
|
|||
// footprint, we can't accept it.
|
|||
if b.bufferedSize+size > b.BufferedByteLimit { |
|||
return ErrOverflow |
|||
} |
|||
// If adding this item to the current bundle would cause it to exceed the
|
|||
// maximum bundle size, close the current bundle and start a new one.
|
|||
if b.BundleByteLimit > 0 && b.curBundle.size+size > b.BundleByteLimit { |
|||
b.closeAndHandleBundle() |
|||
} |
|||
// Add the item.
|
|||
b.curBundle.items = reflect.Append(b.curBundle.items, reflect.ValueOf(item)) |
|||
b.curBundle.size += size |
|||
b.bufferedSize += size |
|||
// If this is the first item in the bundle, restart the timer.
|
|||
if b.curBundle.items.Len() == 1 { |
|||
b.timer.Reset(b.DelayThreshold) |
|||
} |
|||
// If the current bundle equals the count threshold, close it.
|
|||
if b.curBundle.items.Len() == b.BundleCountThreshold { |
|||
b.closeAndHandleBundle() |
|||
} |
|||
// If the current bundle equals or exceeds the byte threshold, close it.
|
|||
if b.curBundle.size >= b.BundleByteThreshold { |
|||
b.closeAndHandleBundle() |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// Flush waits until all items in the Bundler have been handled.
|
|||
func (b *Bundler) Flush() { |
|||
b.mu.Lock() |
|||
b.closeBundle() |
|||
// Unconditionally trigger the handling goroutine, to ensure calledc is closed
|
|||
// even if there are no outstanding bundles.
|
|||
select { |
|||
case b.handlec <- 1: |
|||
default: |
|||
} |
|||
calledc := b.calledc // remember locally, because it may change
|
|||
b.mu.Unlock() |
|||
<-calledc |
|||
} |
|||
|
|||
// Close calls Flush, then shuts down the Bundler. Close should always be
|
|||
// called on a Bundler when it is no longer needed. You must wait for all calls
|
|||
// to Add to complete before calling Close. Calling Add concurrently with Close
|
|||
// may result in the added items being ignored.
|
|||
func (b *Bundler) Close() { |
|||
b.Flush() |
|||
b.mu.Lock() |
|||
b.timer.Stop() |
|||
b.mu.Unlock() |
|||
close(b.donec) |
|||
} |
|||
|
|||
func (b *Bundler) closeAndHandleBundle() { |
|||
if b.closeBundle() { |
|||
// We have created a closed bundle.
|
|||
// Send to handlec without blocking.
|
|||
select { |
|||
case b.handlec <- 1: |
|||
default: |
|||
} |
|||
} |
|||
} |
|||
|
|||
// closeBundle finishes the current bundle, adds it to the list of closed
|
|||
// bundles and informs the background goroutine that there are bundles ready
|
|||
// for processing.
|
|||
//
|
|||
// This should always be called with b.mu held.
|
|||
func (b *Bundler) closeBundle() bool { |
|||
if b.curBundle.items.Len() == 0 { |
|||
return false |
|||
} |
|||
b.closedBundles = append(b.closedBundles, b.curBundle) |
|||
b.curBundle.items = b.itemSliceZero |
|||
b.curBundle.size = 0 |
|||
return true |
|||
} |
|||
|
|||
// background runs in a separate goroutine, waiting for events and handling
|
|||
// bundles.
|
|||
func (b *Bundler) background() { |
|||
done := false |
|||
for { |
|||
timedOut := false |
|||
// Wait for something to happen.
|
|||
select { |
|||
case <-b.handlec: |
|||
case <-b.donec: |
|||
done = true |
|||
case <-b.timer.C: |
|||
timedOut = true |
|||
} |
|||
// Handle closed bundles.
|
|||
b.mu.Lock() |
|||
if timedOut { |
|||
b.closeBundle() |
|||
} |
|||
buns := b.closedBundles |
|||
b.closedBundles = nil |
|||
// Closing calledc means we've sent all bundles. We need
|
|||
// a new channel for the next set of bundles, which may start
|
|||
// accumulating as soon as we release the lock.
|
|||
calledc := b.calledc |
|||
b.calledc = make(chan struct{}) |
|||
b.mu.Unlock() |
|||
for i, bun := range buns { |
|||
b.handler(bun.items.Interface()) |
|||
// Drop the bundle's items, reducing our memory footprint.
|
|||
buns[i].items = reflect.Value{} // buns[i] because bun is a copy
|
|||
// Note immediately that we have more space, so Adds that occur
|
|||
// during this loop will have a chance of succeeding.
|
|||
b.mu.Lock() |
|||
b.bufferedSize -= bun.size |
|||
b.mu.Unlock() |
|||
} |
|||
// Signal that we've sent all outstanding bundles.
|
|||
close(calledc) |
|||
if done { |
|||
break |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,209 @@ |
|||
// Copyright 2016 Google Inc. All Rights Reserved.
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|||
// you may not use this file except in compliance with the License.
|
|||
// You may obtain a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
// See the License for the specific language governing permissions and
|
|||
// limitations under the License.
|
|||
|
|||
package bundler |
|||
|
|||
import ( |
|||
"reflect" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
func TestBundlerCount1(t *testing.T) { |
|||
// Unbundled case: one item per bundle.
|
|||
handler := &testHandler{} |
|||
b := NewBundler(int(0), handler.handle) |
|||
b.BundleCountThreshold = 1 |
|||
b.DelayThreshold = time.Second |
|||
|
|||
for i := 0; i < 3; i++ { |
|||
if err := b.Add(i, 1); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
} |
|||
b.Close() |
|||
got := handler.bundles |
|||
want := [][]int{{0}, {1}, {2}} |
|||
if !reflect.DeepEqual(got, want) { |
|||
t.Errorf("bundles: got %v, want %v", got, want) |
|||
} |
|||
// All bundles should have been handled "immediately": much less
|
|||
// than the delay threshold of 1s.
|
|||
tgot := quantizeTimes(handler.times, 100*time.Millisecond) |
|||
twant := []int{0, 0, 0} |
|||
if !reflect.DeepEqual(tgot, twant) { |
|||
t.Errorf("times: got %v, want %v", tgot, twant) |
|||
} |
|||
} |
|||
|
|||
func TestBundlerCount3(t *testing.T) { |
|||
handler := &testHandler{} |
|||
b := NewBundler(int(0), handler.handle) |
|||
b.BundleCountThreshold = 3 |
|||
b.DelayThreshold = 1 * time.Millisecond |
|||
// Add 8 items.
|
|||
// The first two bundles of 3 should both be handled quickly.
|
|||
// The third bundle of 2 should not be handled for about 1 ms.
|
|||
for i := 0; i < 8; i++ { |
|||
if err := b.Add(i, 1); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
} |
|||
time.Sleep(5 * b.DelayThreshold) |
|||
// We should not need to close the bundler.
|
|||
|
|||
bgot := handler.bundles |
|||
bwant := [][]int{{0, 1, 2}, {3, 4, 5}, {6, 7}} |
|||
if !reflect.DeepEqual(bgot, bwant) { |
|||
t.Errorf("bundles: got %v, want %v", bgot, bwant) |
|||
} |
|||
|
|||
tgot := quantizeTimes(handler.times, b.DelayThreshold) |
|||
twant := []int{0, 0, 1} |
|||
if !reflect.DeepEqual(tgot, twant) { |
|||
t.Errorf("times: got %v, want %v", tgot, twant) |
|||
} |
|||
} |
|||
|
|||
func TestBundlerByteThreshold(t *testing.T) { |
|||
handler := &testHandler{} |
|||
b := NewBundler(int(0), handler.handle) |
|||
b.BundleCountThreshold = 10 |
|||
b.BundleByteThreshold = 3 |
|||
add := func(i interface{}, s int) { |
|||
if err := b.Add(i, s); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
} |
|||
|
|||
add(1, 1) |
|||
add(2, 2) |
|||
// Hit byte threshold: bundle = 1, 2
|
|||
add(3, 1) |
|||
add(4, 1) |
|||
add(5, 2) |
|||
// Passed byte threshold, but not limit: bundle = 3, 4, 5
|
|||
add(6, 1) |
|||
b.Close() |
|||
bgot := handler.bundles |
|||
bwant := [][]int{{1, 2}, {3, 4, 5}, {6}} |
|||
if !reflect.DeepEqual(bgot, bwant) { |
|||
t.Errorf("bundles: got %v, want %v", bgot, bwant) |
|||
} |
|||
tgot := quantizeTimes(handler.times, b.DelayThreshold) |
|||
twant := []int{0, 0, 0} |
|||
if !reflect.DeepEqual(tgot, twant) { |
|||
t.Errorf("times: got %v, want %v", tgot, twant) |
|||
} |
|||
} |
|||
|
|||
func TestBundlerLimit(t *testing.T) { |
|||
handler := &testHandler{} |
|||
b := NewBundler(int(0), handler.handle) |
|||
b.BundleCountThreshold = 10 |
|||
b.BundleByteLimit = 3 |
|||
add := func(i interface{}, s int) { |
|||
if err := b.Add(i, s); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
} |
|||
|
|||
add(1, 1) |
|||
add(2, 2) |
|||
// Hit byte limit: bundle = 1, 2
|
|||
add(3, 1) |
|||
add(4, 1) |
|||
add(5, 2) |
|||
// Exceeded byte limit: bundle = 3, 4
|
|||
add(6, 2) |
|||
// Exceeded byte limit: bundle = 5
|
|||
b.Close() |
|||
bgot := handler.bundles |
|||
bwant := [][]int{{1, 2}, {3, 4}, {5}, {6}} |
|||
if !reflect.DeepEqual(bgot, bwant) { |
|||
t.Errorf("bundles: got %v, want %v", bgot, bwant) |
|||
} |
|||
tgot := quantizeTimes(handler.times, b.DelayThreshold) |
|||
twant := []int{0, 0, 0, 0} |
|||
if !reflect.DeepEqual(tgot, twant) { |
|||
t.Errorf("times: got %v, want %v", tgot, twant) |
|||
} |
|||
} |
|||
|
|||
func TestBundlerErrors(t *testing.T) { |
|||
// Use a handler that blocks forever, to force the bundler to run out of
|
|||
// memory.
|
|||
b := NewBundler(int(0), func(interface{}) { select {} }) |
|||
b.BundleByteLimit = 3 |
|||
b.BufferedByteLimit = 10 |
|||
|
|||
if got, want := b.Add(1, 4), ErrOversizedItem; got != want { |
|||
t.Fatalf("got %v, want %v", got, want) |
|||
} |
|||
|
|||
for i := 0; i < 5; i++ { |
|||
if err := b.Add(i, 2); err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
} |
|||
if got, want := b.Add(5, 1), ErrOverflow; got != want { |
|||
t.Fatalf("got %v, want %v", got, want) |
|||
} |
|||
} |
|||
|
|||
type testHandler struct { |
|||
bundles [][]int |
|||
times []time.Time |
|||
} |
|||
|
|||
func (t *testHandler) handle(b interface{}) { |
|||
t.bundles = append(t.bundles, b.([]int)) |
|||
t.times = append(t.times, time.Now()) |
|||
} |
|||
|
|||
// Round times to the nearest q and express them as the number of q
|
|||
// since the first time.
|
|||
// E.g. if q is 100ms, then a time within 50ms of the first time
|
|||
// will be represented as 0, a time 150 to 250ms of the first time
|
|||
// we be represented as 1, etc.
|
|||
func quantizeTimes(times []time.Time, q time.Duration) []int { |
|||
var rs []int |
|||
for _, t := range times { |
|||
d := t.Sub(times[0]) |
|||
r := int((d + q/2) / q) |
|||
rs = append(rs, r) |
|||
} |
|||
return rs |
|||
} |
|||
|
|||
func TestQuantizeTimes(t *testing.T) { |
|||
quantum := 100 * time.Millisecond |
|||
for _, test := range []struct { |
|||
millis []int // times in milliseconds
|
|||
want []int |
|||
}{ |
|||
{[]int{10, 20, 30}, []int{0, 0, 0}}, |
|||
{[]int{0, 49, 50, 90}, []int{0, 0, 1, 1}}, |
|||
{[]int{0, 95, 170, 315}, []int{0, 1, 2, 3}}, |
|||
} { |
|||
var times []time.Time |
|||
for _, ms := range test.millis { |
|||
times = append(times, time.Unix(0, int64(ms*1e6))) |
|||
} |
|||
got := quantizeTimes(times, quantum) |
|||
if !reflect.DeepEqual(got, test.want) { |
|||
t.Errorf("%v: got %v, want %v", test.millis, got, test.want) |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,128 @@ |
|||
// Copyright 2014 Google Inc. All Rights Reserved.
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|||
// you may not use this file except in compliance with the License.
|
|||
// You may obtain a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
// See the License for the specific language governing permissions and
|
|||
// limitations under the License.
|
|||
|
|||
// Package internal provides support for the cloud packages.
|
|||
//
|
|||
// Users should not import this package directly.
|
|||
package internal |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
"sync" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
type contextKey struct{} |
|||
|
|||
func WithContext(parent context.Context, projID string, c *http.Client) context.Context { |
|||
if c == nil { |
|||
panic("nil *http.Client passed to WithContext") |
|||
} |
|||
if projID == "" { |
|||
panic("empty project ID passed to WithContext") |
|||
} |
|||
return context.WithValue(parent, contextKey{}, &cloudContext{ |
|||
ProjectID: projID, |
|||
HTTPClient: c, |
|||
}) |
|||
} |
|||
|
|||
const userAgent = "gcloud-golang/0.1" |
|||
|
|||
type cloudContext struct { |
|||
ProjectID string |
|||
HTTPClient *http.Client |
|||
|
|||
mu sync.Mutex // guards svc
|
|||
svc map[string]interface{} // e.g. "storage" => *rawStorage.Service
|
|||
} |
|||
|
|||
// Service returns the result of the fill function if it's never been
|
|||
// called before for the given name (which is assumed to be an API
|
|||
// service name, like "datastore"). If it has already been cached, the fill
|
|||
// func is not run.
|
|||
// It's safe for concurrent use by multiple goroutines.
|
|||
func Service(ctx context.Context, name string, fill func(*http.Client) interface{}) interface{} { |
|||
return cc(ctx).service(name, fill) |
|||
} |
|||
|
|||
func (c *cloudContext) service(name string, fill func(*http.Client) interface{}) interface{} { |
|||
c.mu.Lock() |
|||
defer c.mu.Unlock() |
|||
|
|||
if c.svc == nil { |
|||
c.svc = make(map[string]interface{}) |
|||
} else if v, ok := c.svc[name]; ok { |
|||
return v |
|||
} |
|||
v := fill(c.HTTPClient) |
|||
c.svc[name] = v |
|||
return v |
|||
} |
|||
|
|||
// Transport is an http.RoundTripper that appends
|
|||
// Google Cloud client's user-agent to the original
|
|||
// request's user-agent header.
|
|||
type Transport struct { |
|||
// Base is the actual http.RoundTripper
|
|||
// requests will use. It must not be nil.
|
|||
Base http.RoundTripper |
|||
} |
|||
|
|||
// RoundTrip appends a user-agent to the existing user-agent
|
|||
// header and delegates the request to the base http.RoundTripper.
|
|||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { |
|||
req = cloneRequest(req) |
|||
ua := req.Header.Get("User-Agent") |
|||
if ua == "" { |
|||
ua = userAgent |
|||
} else { |
|||
ua = fmt.Sprintf("%s %s", ua, userAgent) |
|||
} |
|||
req.Header.Set("User-Agent", ua) |
|||
return t.Base.RoundTrip(req) |
|||
} |
|||
|
|||
// cloneRequest returns a clone of the provided *http.Request.
|
|||
// The clone is a shallow copy of the struct and its Header map.
|
|||
func cloneRequest(r *http.Request) *http.Request { |
|||
// shallow copy of the struct
|
|||
r2 := new(http.Request) |
|||
*r2 = *r |
|||
// deep copy of the Header
|
|||
r2.Header = make(http.Header) |
|||
for k, s := range r.Header { |
|||
r2.Header[k] = s |
|||
} |
|||
return r2 |
|||
} |
|||
|
|||
func ProjID(ctx context.Context) string { |
|||
return cc(ctx).ProjectID |
|||
} |
|||
|
|||
func HTTPClient(ctx context.Context) *http.Client { |
|||
return cc(ctx).HTTPClient |
|||
} |
|||
|
|||
// cc returns the internal *cloudContext (cc) state for a context.Context.
|
|||
// It panics if the user did it wrong.
|
|||
func cc(ctx context.Context) *cloudContext { |
|||
if c, ok := ctx.Value(contextKey{}).(*cloudContext); ok { |
|||
return c |
|||
} |
|||
panic("invalid context.Context type; it should be created with cloud.NewContext") |
|||
} |
@ -0,0 +1,60 @@ |
|||
// Copyright 2014 Google Inc. All Rights Reserved.
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|||
// you may not use this file except in compliance with the License.
|
|||
// You may obtain a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
// See the License for the specific language governing permissions and
|
|||
// limitations under the License.
|
|||
|
|||
// Package testutil contains helper functions for writing tests.
|
|||
package testutil |
|||
|
|||
import ( |
|||
"io/ioutil" |
|||
"log" |
|||
"os" |
|||
|
|||
"golang.org/x/net/context" |
|||
"golang.org/x/oauth2" |
|||
"golang.org/x/oauth2/google" |
|||
) |
|||
|
|||
const ( |
|||
envProjID = "GCLOUD_TESTS_GOLANG_PROJECT_ID" |
|||
envPrivateKey = "GCLOUD_TESTS_GOLANG_KEY" |
|||
) |
|||
|
|||
// ProjID returns the project ID to use in integration tests, or the empty
|
|||
// string if none is configured.
|
|||
func ProjID() string { |
|||
projID := os.Getenv(envProjID) |
|||
if projID == "" { |
|||
return "" |
|||
} |
|||
return projID |
|||
} |
|||
|
|||
// TokenSource returns the OAuth2 token source to use in integration tests,
|
|||
// or nil if none is configured. TokenSource will log.Fatal if the token
|
|||
// source is specified but missing or invalid.
|
|||
func TokenSource(ctx context.Context, scopes ...string) oauth2.TokenSource { |
|||
key := os.Getenv(envPrivateKey) |
|||
if key == "" { |
|||
return nil |
|||
} |
|||
jsonKey, err := ioutil.ReadFile(key) |
|||
if err != nil { |
|||
log.Fatalf("Cannot read the JSON key file, err: %v", err) |
|||
} |
|||
conf, err := google.JWTConfigFromJSON(jsonKey, scopes...) |
|||
if err != nil { |
|||
log.Fatalf("google.JWTConfigFromJSON: %v", err) |
|||
} |
|||
return conf.TokenSource(ctx) |
|||
} |
@ -0,0 +1,186 @@ |
|||
/* |
|||
Copyright 2016 Google Inc. All Rights Reserved. |
|||
|
|||
Licensed under the Apache License, Version 2.0 (the "License"); |
|||
you may not use this file except in compliance with the License. |
|||
You may obtain a copy of the License at |
|||
|
|||
http://www.apache.org/licenses/LICENSE-2.0
|
|||
|
|||
Unless required by applicable law or agreed to in writing, software |
|||
distributed under the License is distributed on an "AS IS" BASIS, |
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
See the License for the specific language governing permissions and |
|||
limitations under the License. |
|||
*/ |
|||
|
|||
package testutil |
|||
|
|||
import ( |
|||
"fmt" |
|||
"reflect" |
|||
) |
|||
|
|||
// TestIteratorNext tests the Next method of a standard client iterator (see
|
|||
// https://github.com/GoogleCloudPlatform/gcloud-golang/wiki/Iterator-Guidelines).
|
|||
//
|
|||
// This function assumes that an iterator has already been created, and that
|
|||
// the underlying sequence to be iterated over already exists. want should be a
|
|||
// slice that contains the elements of this sequence. It must contain at least
|
|||
// one element.
|
|||
//
|
|||
// done is the special error value that signals that the iterator has provided all the items.
|
|||
//
|
|||
// next is a function that returns the result of calling Next on the iterator. It can usually be
|
|||
// defined as
|
|||
// func() (interface{}, error) { return iter.Next() }
|
|||
//
|
|||
// TestIteratorNext checks that the iterator returns all the elements of want
|
|||
// in order, followed by (zero, done). It also confirms that subsequent calls
|
|||
// to next also return (zero, done).
|
|||
//
|
|||
// On success, TestIteratorNext returns ("", true). On failure, it returns a
|
|||
// suitable error message and false.
|
|||
func TestIteratorNext(want interface{}, done error, next func() (interface{}, error)) (string, bool) { |
|||
wVal := reflect.ValueOf(want) |
|||
if wVal.Kind() != reflect.Slice { |
|||
return "'want' must be a slice", false |
|||
} |
|||
for i := 0; i < wVal.Len(); i++ { |
|||
got, err := next() |
|||
if err != nil { |
|||
return fmt.Sprintf("#%d: got %v, expected an item", i, err), false |
|||
} |
|||
w := wVal.Index(i).Interface() |
|||
if !reflect.DeepEqual(got, w) { |
|||
return fmt.Sprintf("#%d: got %+v, want %+v", i, got, w), false |
|||
} |
|||
} |
|||
// We now should see (<zero value of item type>, done), no matter how many
|
|||
// additional calls we make.
|
|||
zero := reflect.Zero(wVal.Type().Elem()).Interface() |
|||
for i := 0; i < 3; i++ { |
|||
got, err := next() |
|||
if err != done { |
|||
return fmt.Sprintf("at end: got error %v, want done", err), false |
|||
} |
|||
// Since err == done, got should be zero.
|
|||
if got != zero { |
|||
return fmt.Sprintf("got %+v with done, want zero %T", got, zero), false |
|||
} |
|||
} |
|||
return "", true |
|||
} |
|||
|
|||
// PagingIterator describes the standard client iterator pattern with paging as best as possible in Go.
|
|||
// See https://github.com/GoogleCloudPlatform/gcloud-golang/wiki/Iterator-Guidelines.
|
|||
type PagingIterator interface { |
|||
SetPageSize(int32) |
|||
SetPageToken(string) |
|||
NextPageToken() string |
|||
// NextPage() ([]T, error)
|
|||
} |
|||
|
|||
// TestIteratorNextPageExact tests the NextPage method of a standard client
|
|||
// iterator with paging (see PagingIterator).
|
|||
//
|
|||
// This function assumes that the underlying sequence to be iterated over
|
|||
// already exists. want should be a slice that contains the elements of this
|
|||
// sequence. It must contain at least three elements, in order to test
|
|||
// non-trivial paging behavior.
|
|||
//
|
|||
// done is the special error value that signals that the iterator has provided all the items.
|
|||
//
|
|||
// defaultPageSize is the page size to use when the user has not called SetPageSize, or calls
|
|||
// it with a value <= 0.
|
|||
//
|
|||
// newIter should return a new iterator each time it is called.
|
|||
//
|
|||
// nextPage should return the result of calling NextPage on the iterator. It can usually be
|
|||
// defined as
|
|||
// func(i testutil.PagingIterator) (interface{}, error) { return i.(*<iteratorType>).NextPage() }
|
|||
//
|
|||
// TestIteratorNextPageExact checks that the iterator returns all the elements
|
|||
// of want in order, divided into pages of the exactly the right size. It
|
|||
// confirms that if the last page is partial, done is returned along with it,
|
|||
// and in any case, done is returned subsequently along with a zero-length
|
|||
// slice.
|
|||
//
|
|||
// On success, TestIteratorNextPageExact returns ("", true). On failure, it returns a
|
|||
// suitable error message and false.
|
|||
func TestIteratorNextPageExact(want interface{}, done error, defaultPageSize int32, newIter func() PagingIterator, nextPage func(PagingIterator) (interface{}, error)) (string, bool) { |
|||
wVal := reflect.ValueOf(want) |
|||
if wVal.Kind() != reflect.Slice { |
|||
return "'want' must be a slice", false |
|||
} |
|||
if wVal.Len() < 3 { |
|||
return "need at least 3 values for 'want' to effectively test paging", false |
|||
} |
|||
const doNotSetPageSize = -999 |
|||
for _, pageSize := range []int32{doNotSetPageSize, -7, 0, 1, 2, int32(wVal.Len()), int32(wVal.Len()) + 10} { |
|||
adjustedPageSize := int(pageSize) |
|||
if pageSize <= 0 { |
|||
adjustedPageSize = int(defaultPageSize) |
|||
} |
|||
// Create the pages we expect to see.
|
|||
var wantPages []interface{} |
|||
for i, j := 0, adjustedPageSize; i < wVal.Len(); i, j = j, j+adjustedPageSize { |
|||
if j > wVal.Len() { |
|||
j = wVal.Len() |
|||
} |
|||
wantPages = append(wantPages, wVal.Slice(i, j).Interface()) |
|||
} |
|||
for _, usePageToken := range []bool{false, true} { |
|||
it := newIter() |
|||
if pageSize != doNotSetPageSize { |
|||
it.SetPageSize(pageSize) |
|||
} |
|||
for i, wantPage := range wantPages { |
|||
gotPage, err := nextPage(it) |
|||
if err != nil && err != done { |
|||
return fmt.Sprintf("usePageToken %v, pageSize %d, #%d: got %v, expected a page", |
|||
usePageToken, pageSize, i, err), false |
|||
} |
|||
if !reflect.DeepEqual(gotPage, wantPage) { |
|||
return fmt.Sprintf("usePageToken %v, pageSize %d, #%d:\ngot %v\nwant %+v", |
|||
usePageToken, pageSize, i, gotPage, wantPage), false |
|||
} |
|||
// If the last page is partial, NextPage must return done.
|
|||
if reflect.ValueOf(gotPage).Len() < adjustedPageSize && err != done { |
|||
return fmt.Sprintf("usePageToken %v, pageSize %d, #%d: expected done on partial page, got %v", |
|||
usePageToken, pageSize, i, err), false |
|||
} |
|||
if usePageToken { |
|||
// Pretend that we are displaying a paginated listing on the web, and the next
|
|||
// page may be served by a different process.
|
|||
// Empty page token implies done, and vice versa.
|
|||
if (it.NextPageToken() == "") != (err == done) { |
|||
return fmt.Sprintf("pageSize %d: next page token = %q and err = %v; expected empty page token iff done", |
|||
pageSize, it.NextPageToken(), err), false |
|||
} |
|||
if err == nil { |
|||
token := it.NextPageToken() |
|||
it = newIter() |
|||
it.SetPageSize(pageSize) |
|||
it.SetPageToken(token) |
|||
} |
|||
} |
|||
} |
|||
// We now should see (<zero-length or nil slice>, done), no matter how many
|
|||
// additional calls we make.
|
|||
for i := 0; i < 3; i++ { |
|||
gotPage, err := nextPage(it) |
|||
if err != done { |
|||
return fmt.Sprintf("usePageToken %v, pageSize %d, at end: got error %v, want done", |
|||
usePageToken, pageSize, err), false |
|||
} |
|||
pVal := reflect.ValueOf(gotPage) |
|||
if pVal.Kind() != reflect.Slice || pVal.Len() != 0 { |
|||
return fmt.Sprintf("usePageToken %v, pageSize %d, at end: got %+v with done, want zero-length slice", |
|||
usePageToken, pageSize, gotPage), false |
|||
} |
|||
} |
|||
} |
|||
} |
|||
return "", true |
|||
} |
@ -0,0 +1,73 @@ |
|||
/* |
|||
Copyright 2016 Google Inc. All Rights Reserved. |
|||
|
|||
Licensed under the Apache License, Version 2.0 (the "License"); |
|||
you may not use this file except in compliance with the License. |
|||
You may obtain a copy of the License at |
|||
|
|||
http://www.apache.org/licenses/LICENSE-2.0
|
|||
|
|||
Unless required by applicable law or agreed to in writing, software |
|||
distributed under the License is distributed on an "AS IS" BASIS, |
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
See the License for the specific language governing permissions and |
|||
limitations under the License. |
|||
*/ |
|||
|
|||
package testutil |
|||
|
|||
import ( |
|||
"net" |
|||
|
|||
grpc "google.golang.org/grpc" |
|||
) |
|||
|
|||
// A Server is an in-process gRPC server, listening on a system-chosen port on
|
|||
// the local loopback interface. Servers are for testing only and are not
|
|||
// intended to be used in production code.
|
|||
//
|
|||
// To create a server, make a new Server, register your handlers, then call
|
|||
// Start:
|
|||
//
|
|||
// srv, err := NewServer()
|
|||
// ...
|
|||
// mypb.RegisterMyServiceServer(srv.Gsrv, &myHandler)
|
|||
// ....
|
|||
// srv.Start()
|
|||
//
|
|||
// Clients should connect to the server with no security:
|
|||
//
|
|||
// conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
|
|||
// ...
|
|||
type Server struct { |
|||
Addr string |
|||
l net.Listener |
|||
Gsrv *grpc.Server |
|||
} |
|||
|
|||
// NewServer creates a new Server. The Server will be listening for gRPC connections
|
|||
// at the address named by the Addr field, without TLS.
|
|||
func NewServer() (*Server, error) { |
|||
l, err := net.Listen("tcp", "127.0.0.1:0") |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
s := &Server{ |
|||
Addr: l.Addr().String(), |
|||
l: l, |
|||
Gsrv: grpc.NewServer(), |
|||
} |
|||
return s, nil |
|||
} |
|||
|
|||
// Start causes the server to start accepting incoming connections.
|
|||
// Call Start after registering handlers.
|
|||
func (s *Server) Start() { |
|||
go s.Gsrv.Serve(s.l) |
|||
} |
|||
|
|||
// Close shuts down the server.
|
|||
func (s *Server) Close() { |
|||
s.Gsrv.Stop() |
|||
s.l.Close() |
|||
} |
@ -0,0 +1,35 @@ |
|||
// Copyright 2016 Google Inc. All Rights Reserved.
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|||
// you may not use this file except in compliance with the License.
|
|||
// You may obtain a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
// See the License for the specific language governing permissions and
|
|||
// limitations under the License.
|
|||
|
|||
package testutil |
|||
|
|||
import ( |
|||
"testing" |
|||
|
|||
grpc "google.golang.org/grpc" |
|||
) |
|||
|
|||
func TestNewServer(t *testing.T) { |
|||
srv, err := NewServer() |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
srv.Start() |
|||
conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure()) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
conn.Close() |
|||
srv.Close() |
|||
} |
@ -0,0 +1,61 @@ |
|||
// Copyright 2015 Google Inc. All Rights Reserved.
|
|||
//
|
|||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|||
// you may not use this file except in compliance with the License.
|
|||
// You may obtain a copy of the License at
|
|||
//
|
|||
// http://www.apache.org/licenses/LICENSE-2.0
|
|||
//
|
|||
// Unless required by applicable law or agreed to in writing, software
|
|||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
// See the License for the specific language governing permissions and
|
|||
// limitations under the License.
|
|||
|
|||
package transport |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
|
|||
"golang.org/x/net/context" |
|||
"google.golang.org/api/option" |
|||
"google.golang.org/api/transport" |
|||
"google.golang.org/cloud" |
|||
"google.golang.org/grpc" |
|||
) |
|||
|
|||
// ErrHTTP is returned when on a non-200 HTTP response.
|
|||
type ErrHTTP struct { |
|||
StatusCode int |
|||
Body []byte |
|||
err error |
|||
} |
|||
|
|||
func (e *ErrHTTP) Error() string { |
|||
if e.err == nil { |
|||
return fmt.Sprintf("error during call, http status code: %v %s", e.StatusCode, e.Body) |
|||
} |
|||
return e.err.Error() |
|||
} |
|||
|
|||
// NewHTTPClient returns an HTTP client for use communicating with a Google cloud
|
|||
// service, configured with the given ClientOptions. It also returns the endpoint
|
|||
// for the service as specified in the options.
|
|||
func NewHTTPClient(ctx context.Context, opt ...cloud.ClientOption) (*http.Client, string, error) { |
|||
o := make([]option.ClientOption, 0, len(opt)) |
|||
for _, opt := range opt { |
|||
o = append(o, opt.Resolve()) |
|||
} |
|||
return transport.NewHTTPClient(ctx, o...) |
|||
} |
|||
|
|||
// DialGRPC returns a GRPC connection for use communicating with a Google cloud
|
|||
// service, configured with the given ClientOptions.
|
|||
func DialGRPC(ctx context.Context, opt ...cloud.ClientOption) (*grpc.ClientConn, error) { |
|||
o := make([]option.ClientOption, 0, len(opt)) |
|||
for _, opt := range opt { |
|||
o = append(o, opt.Resolve()) |
|||
} |
|||
return transport.DialGRPC(ctx, o...) |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue