mirror of https://github.com/matrix-org/go-neb.git
Kegan Dougal
8 years ago
76 changed files with 7942 additions and 0 deletions
-
25vendor/manifest
-
20vendor/src/github.com/cenkalti/backoff/LICENSE
-
118vendor/src/github.com/cenkalti/backoff/README.md
-
59vendor/src/github.com/cenkalti/backoff/backoff.go
-
27vendor/src/github.com/cenkalti/backoff/backoff_test.go
-
51vendor/src/github.com/cenkalti/backoff/example_test.go
-
156vendor/src/github.com/cenkalti/backoff/exponential.go
-
108vendor/src/github.com/cenkalti/backoff/exponential_test.go
-
46vendor/src/github.com/cenkalti/backoff/retry.go
-
34vendor/src/github.com/cenkalti/backoff/retry_test.go
-
79vendor/src/github.com/cenkalti/backoff/ticker.go
-
45vendor/src/github.com/cenkalti/backoff/ticker_test.go
-
37vendor/src/github.com/dghubble/go-twitter/twitter/accounts.go
-
27vendor/src/github.com/dghubble/go-twitter/twitter/accounts_test.go
-
25vendor/src/github.com/dghubble/go-twitter/twitter/backoffs.go
-
37vendor/src/github.com/dghubble/go-twitter/twitter/backoffs_test.go
-
88vendor/src/github.com/dghubble/go-twitter/twitter/demux.go
-
135vendor/src/github.com/dghubble/go-twitter/twitter/demux_test.go
-
130vendor/src/github.com/dghubble/go-twitter/twitter/direct_messages.go
-
110vendor/src/github.com/dghubble/go-twitter/twitter/direct_messages_test.go
-
70vendor/src/github.com/dghubble/go-twitter/twitter/doc.go
-
73vendor/src/github.com/dghubble/go-twitter/twitter/entities.go
-
22vendor/src/github.com/dghubble/go-twitter/twitter/entities_test.go
-
47vendor/src/github.com/dghubble/go-twitter/twitter/errors.go
-
48vendor/src/github.com/dghubble/go-twitter/twitter/errors_test.go
-
73vendor/src/github.com/dghubble/go-twitter/twitter/followers.go
-
69vendor/src/github.com/dghubble/go-twitter/twitter/followers_test.go
-
251vendor/src/github.com/dghubble/go-twitter/twitter/statuses.go
-
224vendor/src/github.com/dghubble/go-twitter/twitter/statuses_test.go
-
110vendor/src/github.com/dghubble/go-twitter/twitter/stream_messages.go
-
56vendor/src/github.com/dghubble/go-twitter/twitter/stream_utils.go
-
64vendor/src/github.com/dghubble/go-twitter/twitter/stream_utils_test.go
-
326vendor/src/github.com/dghubble/go-twitter/twitter/streams.go
-
352vendor/src/github.com/dghubble/go-twitter/twitter/streams_test.go
-
106vendor/src/github.com/dghubble/go-twitter/twitter/timelines.go
-
81vendor/src/github.com/dghubble/go-twitter/twitter/timelines_test.go
-
51vendor/src/github.com/dghubble/go-twitter/twitter/twitter.go
-
93vendor/src/github.com/dghubble/go-twitter/twitter/twitter_test.go
-
122vendor/src/github.com/dghubble/go-twitter/twitter/users.go
-
92vendor/src/github.com/dghubble/go-twitter/twitter/users_test.go
-
30vendor/src/github.com/dghubble/oauth1/CHANGES.md
-
21vendor/src/github.com/dghubble/oauth1/LICENSE
-
125vendor/src/github.com/dghubble/oauth1/README.md
-
265vendor/src/github.com/dghubble/oauth1/auther.go
-
244vendor/src/github.com/dghubble/oauth1/auther_test.go
-
165vendor/src/github.com/dghubble/oauth1/config.go
-
342vendor/src/github.com/dghubble/oauth1/config_test.go
-
24vendor/src/github.com/dghubble/oauth1/context.go
-
21vendor/src/github.com/dghubble/oauth1/context_test.go
-
97vendor/src/github.com/dghubble/oauth1/doc.go
-
13vendor/src/github.com/dghubble/oauth1/dropbox/dropbox.go
-
36vendor/src/github.com/dghubble/oauth1/encode.go
-
27vendor/src/github.com/dghubble/oauth1/encode_test.go
-
12vendor/src/github.com/dghubble/oauth1/endpoint.go
-
48vendor/src/github.com/dghubble/oauth1/examples/README.md
-
68vendor/src/github.com/dghubble/oauth1/examples/tumblr-login.go
-
37vendor/src/github.com/dghubble/oauth1/examples/tumblr-request.go
-
75vendor/src/github.com/dghubble/oauth1/examples/twitter-login.go
-
39vendor/src/github.com/dghubble/oauth1/examples/twitter-request.go
-
202vendor/src/github.com/dghubble/oauth1/reference_test.go
-
62vendor/src/github.com/dghubble/oauth1/signer.go
-
19vendor/src/github.com/dghubble/oauth1/test
-
43vendor/src/github.com/dghubble/oauth1/token.go
-
31vendor/src/github.com/dghubble/oauth1/token_test.go
-
65vendor/src/github.com/dghubble/oauth1/transport.go
-
117vendor/src/github.com/dghubble/oauth1/transport_test.go
-
13vendor/src/github.com/dghubble/oauth1/tumblr/tumblr.go
-
25vendor/src/github.com/dghubble/oauth1/twitter/twitter.go
-
52vendor/src/github.com/dghubble/sling/CHANGES.md
-
21vendor/src/github.com/dghubble/sling/LICENSE
-
273vendor/src/github.com/dghubble/sling/README.md
-
179vendor/src/github.com/dghubble/sling/doc.go
-
19vendor/src/github.com/dghubble/sling/examples/README.md
-
161vendor/src/github.com/dghubble/sling/examples/github.go
-
421vendor/src/github.com/dghubble/sling/sling.go
-
863vendor/src/github.com/dghubble/sling/sling_test.go
@ -0,0 +1,20 @@ |
|||
The MIT License (MIT) |
|||
|
|||
Copyright (c) 2014 Cenk Altı |
|||
|
|||
Permission is hereby granted, free of charge, to any person obtaining a copy of |
|||
this software and associated documentation files (the "Software"), to deal in |
|||
the Software without restriction, including without limitation the rights to |
|||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of |
|||
the Software, and to permit persons to whom the Software is furnished to do so, |
|||
subject to the following conditions: |
|||
|
|||
The above copyright notice and this permission notice shall be included in all |
|||
copies or substantial portions of the Software. |
|||
|
|||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS |
|||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
|||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER |
|||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
|||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
@ -0,0 +1,118 @@ |
|||
# Exponential Backoff [![GoDoc][godoc image]][godoc] [![Build Status][travis image]][travis] [![Coverage Status][coveralls image]][coveralls] |
|||
|
|||
This is a Go port of the exponential backoff algorithm from [Google's HTTP Client Library for Java][google-http-java-client]. |
|||
|
|||
[Exponential backoff][exponential backoff wiki] |
|||
is an algorithm that uses feedback to multiplicatively decrease the rate of some process, |
|||
in order to gradually find an acceptable rate. |
|||
The retries exponentially increase and stop increasing when a certain threshold is met. |
|||
|
|||
## How To |
|||
|
|||
We define two functions, `Retry()` and `RetryNotify()`. |
|||
They receive an `Operation` to execute, a `BackOff` algorithm, |
|||
and an optional `Notify` error handler. |
|||
|
|||
The operation will be executed, and will be retried on failure with delay |
|||
as given by the backoff algorithm. The backoff algorithm can also decide when to stop |
|||
retrying. |
|||
In addition, the notify error handler will be called after each failed attempt, |
|||
except for the last time, whose error should be handled by the caller. |
|||
|
|||
```go |
|||
// An Operation is executing by Retry() or RetryNotify(). |
|||
// The operation will be retried using a backoff policy if it returns an error. |
|||
type Operation func() error |
|||
|
|||
// Notify is a notify-on-error function. It receives an operation error and |
|||
// backoff delay if the operation failed (with an error). |
|||
// |
|||
// NOTE that if the backoff policy stated to stop retrying, |
|||
// the notify function isn't called. |
|||
type Notify func(error, time.Duration) |
|||
|
|||
func Retry(Operation, BackOff) error |
|||
func RetryNotify(Operation, BackOff, Notify) |
|||
``` |
|||
|
|||
## Examples |
|||
|
|||
### Retry |
|||
|
|||
Simple retry helper that uses the default exponential backoff algorithm: |
|||
|
|||
```go |
|||
operation := func() error { |
|||
// An operation that might fail. |
|||
return nil // or return errors.New("some error") |
|||
} |
|||
|
|||
err := Retry(operation, NewExponentialBackOff()) |
|||
if err != nil { |
|||
// Handle error. |
|||
return err |
|||
} |
|||
|
|||
// Operation is successful. |
|||
return nil |
|||
``` |
|||
|
|||
### Ticker |
|||
|
|||
Ticker is for using backoff algorithms with channels. |
|||
|
|||
```go |
|||
operation := func() error { |
|||
// An operation that might fail |
|||
return nil // or return errors.New("some error") |
|||
} |
|||
|
|||
b := NewExponentialBackOff() |
|||
ticker := NewTicker(b) |
|||
|
|||
var err error |
|||
|
|||
// Ticks will continue to arrive when the previous operation is still running, |
|||
// so operations that take a while to fail could run in quick succession. |
|||
for range ticker.C { |
|||
if err = operation(); err != nil { |
|||
log.Println(err, "will retry...") |
|||
continue |
|||
} |
|||
|
|||
ticker.Stop() |
|||
break |
|||
} |
|||
|
|||
if err != nil { |
|||
// Operation has failed. |
|||
return err |
|||
} |
|||
|
|||
// Operation is successful. |
|||
return nil |
|||
``` |
|||
|
|||
## Getting Started |
|||
|
|||
```bash |
|||
# install |
|||
$ go get github.com/cenk/backoff |
|||
|
|||
# test |
|||
$ cd $GOPATH/src/github.com/cenk/backoff |
|||
$ go get -t ./... |
|||
$ go test -v -cover |
|||
``` |
|||
|
|||
[godoc]: https://godoc.org/github.com/cenk/backoff |
|||
[godoc image]: https://godoc.org/github.com/cenk/backoff?status.png |
|||
[travis]: https://travis-ci.org/cenk/backoff |
|||
[travis image]: https://travis-ci.org/cenk/backoff.png?branch=master |
|||
[coveralls]: https://coveralls.io/github/cenk/backoff?branch=master |
|||
[coveralls image]: https://coveralls.io/repos/github/cenk/backoff/badge.svg?branch=master |
|||
|
|||
[google-http-java-client]: https://github.com/google/google-http-java-client |
|||
[exponential backoff wiki]: http://en.wikipedia.org/wiki/Exponential_backoff |
|||
|
|||
[advanced example]: https://godoc.org/github.com/cenk/backoff#example_ |
@ -0,0 +1,59 @@ |
|||
// Package backoff implements backoff algorithms for retrying operations.
|
|||
//
|
|||
// Also has a Retry() helper for retrying operations that may fail.
|
|||
package backoff |
|||
|
|||
import "time" |
|||
|
|||
// BackOff is a backoff policy for retrying an operation.
|
|||
type BackOff interface { |
|||
// NextBackOff returns the duration to wait before retrying the operation,
|
|||
// or backoff.Stop to indicate that no more retries should be made.
|
|||
//
|
|||
// Example usage:
|
|||
//
|
|||
// duration := backoff.NextBackOff();
|
|||
// if (duration == backoff.Stop) {
|
|||
// // Do not retry operation.
|
|||
// } else {
|
|||
// // Sleep for duration and retry operation.
|
|||
// }
|
|||
//
|
|||
NextBackOff() time.Duration |
|||
|
|||
// Reset to initial state.
|
|||
Reset() |
|||
} |
|||
|
|||
// Indicates that no more retries should be made for use in NextBackOff().
|
|||
const Stop time.Duration = -1 |
|||
|
|||
// ZeroBackOff is a fixed backoff policy whose backoff time is always zero,
|
|||
// meaning that the operation is retried immediately without waiting, indefinitely.
|
|||
type ZeroBackOff struct{} |
|||
|
|||
func (b *ZeroBackOff) Reset() {} |
|||
|
|||
func (b *ZeroBackOff) NextBackOff() time.Duration { return 0 } |
|||
|
|||
// StopBackOff is a fixed backoff policy that always returns backoff.Stop for
|
|||
// NextBackOff(), meaning that the operation should never be retried.
|
|||
type StopBackOff struct{} |
|||
|
|||
func (b *StopBackOff) Reset() {} |
|||
|
|||
func (b *StopBackOff) NextBackOff() time.Duration { return Stop } |
|||
|
|||
// ConstantBackOff is a backoff policy that always returns the same backoff delay.
|
|||
// This is in contrast to an exponential backoff policy,
|
|||
// which returns a delay that grows longer as you call NextBackOff() over and over again.
|
|||
type ConstantBackOff struct { |
|||
Interval time.Duration |
|||
} |
|||
|
|||
func (b *ConstantBackOff) Reset() {} |
|||
func (b *ConstantBackOff) NextBackOff() time.Duration { return b.Interval } |
|||
|
|||
func NewConstantBackOff(d time.Duration) *ConstantBackOff { |
|||
return &ConstantBackOff{Interval: d} |
|||
} |
@ -0,0 +1,27 @@ |
|||
package backoff |
|||
|
|||
import ( |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
func TestNextBackOffMillis(t *testing.T) { |
|||
subtestNextBackOff(t, 0, new(ZeroBackOff)) |
|||
subtestNextBackOff(t, Stop, new(StopBackOff)) |
|||
} |
|||
|
|||
func subtestNextBackOff(t *testing.T, expectedValue time.Duration, backOffPolicy BackOff) { |
|||
for i := 0; i < 10; i++ { |
|||
next := backOffPolicy.NextBackOff() |
|||
if next != expectedValue { |
|||
t.Errorf("got: %d expected: %d", next, expectedValue) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestConstantBackOff(t *testing.T) { |
|||
backoff := NewConstantBackOff(time.Second) |
|||
if backoff.NextBackOff() != time.Second { |
|||
t.Error("invalid interval") |
|||
} |
|||
} |
@ -0,0 +1,51 @@ |
|||
package backoff |
|||
|
|||
import "log" |
|||
|
|||
func ExampleRetry() error { |
|||
operation := func() error { |
|||
// An operation that might fail.
|
|||
return nil // or return errors.New("some error")
|
|||
} |
|||
|
|||
err := Retry(operation, NewExponentialBackOff()) |
|||
if err != nil { |
|||
// Handle error.
|
|||
return err |
|||
} |
|||
|
|||
// Operation is successful.
|
|||
return nil |
|||
} |
|||
|
|||
func ExampleTicker() error { |
|||
operation := func() error { |
|||
// An operation that might fail
|
|||
return nil // or return errors.New("some error")
|
|||
} |
|||
|
|||
b := NewExponentialBackOff() |
|||
ticker := NewTicker(b) |
|||
|
|||
var err error |
|||
|
|||
// Ticks will continue to arrive when the previous operation is still running,
|
|||
// so operations that take a while to fail could run in quick succession.
|
|||
for _ = range ticker.C { |
|||
if err = operation(); err != nil { |
|||
log.Println(err, "will retry...") |
|||
continue |
|||
} |
|||
|
|||
ticker.Stop() |
|||
break |
|||
} |
|||
|
|||
if err != nil { |
|||
// Operation has failed.
|
|||
return err |
|||
} |
|||
|
|||
// Operation is successful.
|
|||
return nil |
|||
} |
@ -0,0 +1,156 @@ |
|||
package backoff |
|||
|
|||
import ( |
|||
"math/rand" |
|||
"time" |
|||
) |
|||
|
|||
/* |
|||
ExponentialBackOff is a backoff implementation that increases the backoff |
|||
period for each retry attempt using a randomization function that grows exponentially. |
|||
|
|||
NextBackOff() is calculated using the following formula: |
|||
|
|||
randomized interval = |
|||
RetryInterval * (random value in range [1 - RandomizationFactor, 1 + RandomizationFactor]) |
|||
|
|||
In other words NextBackOff() will range between the randomization factor |
|||
percentage below and above the retry interval. |
|||
|
|||
For example, given the following parameters: |
|||
|
|||
RetryInterval = 2 |
|||
RandomizationFactor = 0.5 |
|||
Multiplier = 2 |
|||
|
|||
the actual backoff period used in the next retry attempt will range between 1 and 3 seconds, |
|||
multiplied by the exponential, that is, between 2 and 6 seconds. |
|||
|
|||
Note: MaxInterval caps the RetryInterval and not the randomized interval. |
|||
|
|||
If the time elapsed since an ExponentialBackOff instance is created goes past the |
|||
MaxElapsedTime, then the method NextBackOff() starts returning backoff.Stop. |
|||
|
|||
The elapsed time can be reset by calling Reset(). |
|||
|
|||
Example: Given the following default arguments, for 10 tries the sequence will be, |
|||
and assuming we go over the MaxElapsedTime on the 10th try: |
|||
|
|||
Request # RetryInterval (seconds) Randomized Interval (seconds) |
|||
|
|||
1 0.5 [0.25, 0.75] |
|||
2 0.75 [0.375, 1.125] |
|||
3 1.125 [0.562, 1.687] |
|||
4 1.687 [0.8435, 2.53] |
|||
5 2.53 [1.265, 3.795] |
|||
6 3.795 [1.897, 5.692] |
|||
7 5.692 [2.846, 8.538] |
|||
8 8.538 [4.269, 12.807] |
|||
9 12.807 [6.403, 19.210] |
|||
10 19.210 backoff.Stop |
|||
|
|||
Note: Implementation is not thread-safe. |
|||
*/ |
|||
type ExponentialBackOff struct { |
|||
InitialInterval time.Duration |
|||
RandomizationFactor float64 |
|||
Multiplier float64 |
|||
MaxInterval time.Duration |
|||
// After MaxElapsedTime the ExponentialBackOff stops.
|
|||
// It never stops if MaxElapsedTime == 0.
|
|||
MaxElapsedTime time.Duration |
|||
Clock Clock |
|||
|
|||
currentInterval time.Duration |
|||
startTime time.Time |
|||
} |
|||
|
|||
// Clock is an interface that returns current time for BackOff.
|
|||
type Clock interface { |
|||
Now() time.Time |
|||
} |
|||
|
|||
// Default values for ExponentialBackOff.
|
|||
const ( |
|||
DefaultInitialInterval = 500 * time.Millisecond |
|||
DefaultRandomizationFactor = 0.5 |
|||
DefaultMultiplier = 1.5 |
|||
DefaultMaxInterval = 60 * time.Second |
|||
DefaultMaxElapsedTime = 15 * time.Minute |
|||
) |
|||
|
|||
// NewExponentialBackOff creates an instance of ExponentialBackOff using default values.
|
|||
func NewExponentialBackOff() *ExponentialBackOff { |
|||
b := &ExponentialBackOff{ |
|||
InitialInterval: DefaultInitialInterval, |
|||
RandomizationFactor: DefaultRandomizationFactor, |
|||
Multiplier: DefaultMultiplier, |
|||
MaxInterval: DefaultMaxInterval, |
|||
MaxElapsedTime: DefaultMaxElapsedTime, |
|||
Clock: SystemClock, |
|||
} |
|||
if b.RandomizationFactor < 0 { |
|||
b.RandomizationFactor = 0 |
|||
} else if b.RandomizationFactor > 1 { |
|||
b.RandomizationFactor = 1 |
|||
} |
|||
b.Reset() |
|||
return b |
|||
} |
|||
|
|||
type systemClock struct{} |
|||
|
|||
func (t systemClock) Now() time.Time { |
|||
return time.Now() |
|||
} |
|||
|
|||
// SystemClock implements Clock interface that uses time.Now().
|
|||
var SystemClock = systemClock{} |
|||
|
|||
// Reset the interval back to the initial retry interval and restarts the timer.
|
|||
func (b *ExponentialBackOff) Reset() { |
|||
b.currentInterval = b.InitialInterval |
|||
b.startTime = b.Clock.Now() |
|||
} |
|||
|
|||
// NextBackOff calculates the next backoff interval using the formula:
|
|||
// Randomized interval = RetryInterval +/- (RandomizationFactor * RetryInterval)
|
|||
func (b *ExponentialBackOff) NextBackOff() time.Duration { |
|||
// Make sure we have not gone over the maximum elapsed time.
|
|||
if b.MaxElapsedTime != 0 && b.GetElapsedTime() > b.MaxElapsedTime { |
|||
return Stop |
|||
} |
|||
defer b.incrementCurrentInterval() |
|||
return getRandomValueFromInterval(b.RandomizationFactor, rand.Float64(), b.currentInterval) |
|||
} |
|||
|
|||
// GetElapsedTime returns the elapsed time since an ExponentialBackOff instance
|
|||
// is created and is reset when Reset() is called.
|
|||
//
|
|||
// The elapsed time is computed using time.Now().UnixNano().
|
|||
func (b *ExponentialBackOff) GetElapsedTime() time.Duration { |
|||
return b.Clock.Now().Sub(b.startTime) |
|||
} |
|||
|
|||
// Increments the current interval by multiplying it with the multiplier.
|
|||
func (b *ExponentialBackOff) incrementCurrentInterval() { |
|||
// Check for overflow, if overflow is detected set the current interval to the max interval.
|
|||
if float64(b.currentInterval) >= float64(b.MaxInterval)/b.Multiplier { |
|||
b.currentInterval = b.MaxInterval |
|||
} else { |
|||
b.currentInterval = time.Duration(float64(b.currentInterval) * b.Multiplier) |
|||
} |
|||
} |
|||
|
|||
// Returns a random value from the following interval:
|
|||
// [randomizationFactor * currentInterval, randomizationFactor * currentInterval].
|
|||
func getRandomValueFromInterval(randomizationFactor, random float64, currentInterval time.Duration) time.Duration { |
|||
var delta = randomizationFactor * float64(currentInterval) |
|||
var minInterval = float64(currentInterval) - delta |
|||
var maxInterval = float64(currentInterval) + delta |
|||
|
|||
// Get a random value from the range [minInterval, maxInterval].
|
|||
// The formula used below has a +1 because if the minInterval is 1 and the maxInterval is 3 then
|
|||
// we want a 33% chance for selecting either 1, 2 or 3.
|
|||
return time.Duration(minInterval + (random * (maxInterval - minInterval + 1))) |
|||
} |
@ -0,0 +1,108 @@ |
|||
package backoff |
|||
|
|||
import ( |
|||
"math" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
func TestBackOff(t *testing.T) { |
|||
var ( |
|||
testInitialInterval = 500 * time.Millisecond |
|||
testRandomizationFactor = 0.1 |
|||
testMultiplier = 2.0 |
|||
testMaxInterval = 5 * time.Second |
|||
testMaxElapsedTime = 15 * time.Minute |
|||
) |
|||
|
|||
exp := NewExponentialBackOff() |
|||
exp.InitialInterval = testInitialInterval |
|||
exp.RandomizationFactor = testRandomizationFactor |
|||
exp.Multiplier = testMultiplier |
|||
exp.MaxInterval = testMaxInterval |
|||
exp.MaxElapsedTime = testMaxElapsedTime |
|||
exp.Reset() |
|||
|
|||
var expectedResults = []time.Duration{500, 1000, 2000, 4000, 5000, 5000, 5000, 5000, 5000, 5000} |
|||
for i, d := range expectedResults { |
|||
expectedResults[i] = d * time.Millisecond |
|||
} |
|||
|
|||
for _, expected := range expectedResults { |
|||
assertEquals(t, expected, exp.currentInterval) |
|||
// Assert that the next backoff falls in the expected range.
|
|||
var minInterval = expected - time.Duration(testRandomizationFactor*float64(expected)) |
|||
var maxInterval = expected + time.Duration(testRandomizationFactor*float64(expected)) |
|||
var actualInterval = exp.NextBackOff() |
|||
if !(minInterval <= actualInterval && actualInterval <= maxInterval) { |
|||
t.Error("error") |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestGetRandomizedInterval(t *testing.T) { |
|||
// 33% chance of being 1.
|
|||
assertEquals(t, 1, getRandomValueFromInterval(0.5, 0, 2)) |
|||
assertEquals(t, 1, getRandomValueFromInterval(0.5, 0.33, 2)) |
|||
// 33% chance of being 2.
|
|||
assertEquals(t, 2, getRandomValueFromInterval(0.5, 0.34, 2)) |
|||
assertEquals(t, 2, getRandomValueFromInterval(0.5, 0.66, 2)) |
|||
// 33% chance of being 3.
|
|||
assertEquals(t, 3, getRandomValueFromInterval(0.5, 0.67, 2)) |
|||
assertEquals(t, 3, getRandomValueFromInterval(0.5, 0.99, 2)) |
|||
} |
|||
|
|||
type TestClock struct { |
|||
i time.Duration |
|||
start time.Time |
|||
} |
|||
|
|||
func (c *TestClock) Now() time.Time { |
|||
t := c.start.Add(c.i) |
|||
c.i += time.Second |
|||
return t |
|||
} |
|||
|
|||
func TestGetElapsedTime(t *testing.T) { |
|||
var exp = NewExponentialBackOff() |
|||
exp.Clock = &TestClock{} |
|||
exp.Reset() |
|||
|
|||
var elapsedTime = exp.GetElapsedTime() |
|||
if elapsedTime != time.Second { |
|||
t.Errorf("elapsedTime=%d", elapsedTime) |
|||
} |
|||
} |
|||
|
|||
func TestMaxElapsedTime(t *testing.T) { |
|||
var exp = NewExponentialBackOff() |
|||
exp.Clock = &TestClock{start: time.Time{}.Add(10000 * time.Second)} |
|||
// Change the currentElapsedTime to be 0 ensuring that the elapsed time will be greater
|
|||
// than the max elapsed time.
|
|||
exp.startTime = time.Time{} |
|||
assertEquals(t, Stop, exp.NextBackOff()) |
|||
} |
|||
|
|||
func TestBackOffOverflow(t *testing.T) { |
|||
var ( |
|||
testInitialInterval time.Duration = math.MaxInt64 / 2 |
|||
testMaxInterval time.Duration = math.MaxInt64 |
|||
testMultiplier = 2.1 |
|||
) |
|||
|
|||
exp := NewExponentialBackOff() |
|||
exp.InitialInterval = testInitialInterval |
|||
exp.Multiplier = testMultiplier |
|||
exp.MaxInterval = testMaxInterval |
|||
exp.Reset() |
|||
|
|||
exp.NextBackOff() |
|||
// Assert that when an overflow is possible the current varerval time.Duration is set to the max varerval time.Duration .
|
|||
assertEquals(t, testMaxInterval, exp.currentInterval) |
|||
} |
|||
|
|||
func assertEquals(t *testing.T, expected, value time.Duration) { |
|||
if expected != value { |
|||
t.Errorf("got: %d, expected: %d", value, expected) |
|||
} |
|||
} |
@ -0,0 +1,46 @@ |
|||
package backoff |
|||
|
|||
import "time" |
|||
|
|||
// An Operation is executing by Retry() or RetryNotify().
|
|||
// The operation will be retried using a backoff policy if it returns an error.
|
|||
type Operation func() error |
|||
|
|||
// Notify is a notify-on-error function. It receives an operation error and
|
|||
// backoff delay if the operation failed (with an error).
|
|||
//
|
|||
// NOTE that if the backoff policy stated to stop retrying,
|
|||
// the notify function isn't called.
|
|||
type Notify func(error, time.Duration) |
|||
|
|||
// Retry the operation o until it does not return error or BackOff stops.
|
|||
// o is guaranteed to be run at least once.
|
|||
// It is the caller's responsibility to reset b after Retry returns.
|
|||
//
|
|||
// Retry sleeps the goroutine for the duration returned by BackOff after a
|
|||
// failed operation returns.
|
|||
func Retry(o Operation, b BackOff) error { return RetryNotify(o, b, nil) } |
|||
|
|||
// RetryNotify calls notify function with the error and wait duration
|
|||
// for each failed attempt before sleep.
|
|||
func RetryNotify(operation Operation, b BackOff, notify Notify) error { |
|||
var err error |
|||
var next time.Duration |
|||
|
|||
b.Reset() |
|||
for { |
|||
if err = operation(); err == nil { |
|||
return nil |
|||
} |
|||
|
|||
if next = b.NextBackOff(); next == Stop { |
|||
return err |
|||
} |
|||
|
|||
if notify != nil { |
|||
notify(err, next) |
|||
} |
|||
|
|||
time.Sleep(next) |
|||
} |
|||
} |
@ -0,0 +1,34 @@ |
|||
package backoff |
|||
|
|||
import ( |
|||
"errors" |
|||
"log" |
|||
"testing" |
|||
) |
|||
|
|||
func TestRetry(t *testing.T) { |
|||
const successOn = 3 |
|||
var i = 0 |
|||
|
|||
// This function is successful on "successOn" calls.
|
|||
f := func() error { |
|||
i++ |
|||
log.Printf("function is called %d. time\n", i) |
|||
|
|||
if i == successOn { |
|||
log.Println("OK") |
|||
return nil |
|||
} |
|||
|
|||
log.Println("error") |
|||
return errors.New("error") |
|||
} |
|||
|
|||
err := Retry(f, NewExponentialBackOff()) |
|||
if err != nil { |
|||
t.Errorf("unexpected error: %s", err.Error()) |
|||
} |
|||
if i != successOn { |
|||
t.Errorf("invalid number of retries: %d", i) |
|||
} |
|||
} |
@ -0,0 +1,79 @@ |
|||
package backoff |
|||
|
|||
import ( |
|||
"runtime" |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
// Ticker holds a channel that delivers `ticks' of a clock at times reported by a BackOff.
|
|||
//
|
|||
// Ticks will continue to arrive when the previous operation is still running,
|
|||
// so operations that take a while to fail could run in quick succession.
|
|||
type Ticker struct { |
|||
C <-chan time.Time |
|||
c chan time.Time |
|||
b BackOff |
|||
stop chan struct{} |
|||
stopOnce sync.Once |
|||
} |
|||
|
|||
// NewTicker returns a new Ticker containing a channel that will send the time at times
|
|||
// specified by the BackOff argument. Ticker is guaranteed to tick at least once.
|
|||
// The channel is closed when Stop method is called or BackOff stops.
|
|||
func NewTicker(b BackOff) *Ticker { |
|||
c := make(chan time.Time) |
|||
t := &Ticker{ |
|||
C: c, |
|||
c: c, |
|||
b: b, |
|||
stop: make(chan struct{}), |
|||
} |
|||
go t.run() |
|||
runtime.SetFinalizer(t, (*Ticker).Stop) |
|||
return t |
|||
} |
|||
|
|||
// Stop turns off a ticker. After Stop, no more ticks will be sent.
|
|||
func (t *Ticker) Stop() { |
|||
t.stopOnce.Do(func() { close(t.stop) }) |
|||
} |
|||
|
|||
func (t *Ticker) run() { |
|||
c := t.c |
|||
defer close(c) |
|||
t.b.Reset() |
|||
|
|||
// Ticker is guaranteed to tick at least once.
|
|||
afterC := t.send(time.Now()) |
|||
|
|||
for { |
|||
if afterC == nil { |
|||
return |
|||
} |
|||
|
|||
select { |
|||
case tick := <-afterC: |
|||
afterC = t.send(tick) |
|||
case <-t.stop: |
|||
t.c = nil // Prevent future ticks from being sent to the channel.
|
|||
return |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (t *Ticker) send(tick time.Time) <-chan time.Time { |
|||
select { |
|||
case t.c <- tick: |
|||
case <-t.stop: |
|||
return nil |
|||
} |
|||
|
|||
next := t.b.NextBackOff() |
|||
if next == Stop { |
|||
t.Stop() |
|||
return nil |
|||
} |
|||
|
|||
return time.After(next) |
|||
} |
@ -0,0 +1,45 @@ |
|||
package backoff |
|||
|
|||
import ( |
|||
"errors" |
|||
"log" |
|||
"testing" |
|||
) |
|||
|
|||
func TestTicker(t *testing.T) { |
|||
const successOn = 3 |
|||
var i = 0 |
|||
|
|||
// This function is successful on "successOn" calls.
|
|||
f := func() error { |
|||
i++ |
|||
log.Printf("function is called %d. time\n", i) |
|||
|
|||
if i == successOn { |
|||
log.Println("OK") |
|||
return nil |
|||
} |
|||
|
|||
log.Println("error") |
|||
return errors.New("error") |
|||
} |
|||
|
|||
b := NewExponentialBackOff() |
|||
ticker := NewTicker(b) |
|||
|
|||
var err error |
|||
for _ = range ticker.C { |
|||
if err = f(); err != nil { |
|||
t.Log(err) |
|||
continue |
|||
} |
|||
|
|||
break |
|||
} |
|||
if err != nil { |
|||
t.Errorf("unexpected error: %s", err.Error()) |
|||
} |
|||
if i != successOn { |
|||
t.Errorf("invalid number of retries: %d", i) |
|||
} |
|||
} |
@ -0,0 +1,37 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"net/http" |
|||
|
|||
"github.com/dghubble/sling" |
|||
) |
|||
|
|||
// AccountService provides a method for account credential verification.
|
|||
type AccountService struct { |
|||
sling *sling.Sling |
|||
} |
|||
|
|||
// newAccountService returns a new AccountService.
|
|||
func newAccountService(sling *sling.Sling) *AccountService { |
|||
return &AccountService{ |
|||
sling: sling.Path("account/"), |
|||
} |
|||
} |
|||
|
|||
// AccountVerifyParams are the params for AccountService.VerifyCredentials.
|
|||
type AccountVerifyParams struct { |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` |
|||
SkipStatus *bool `url:"skip_status,omitempty"` |
|||
IncludeEmail *bool `url:"include_email,omitempty"` |
|||
} |
|||
|
|||
// VerifyCredentials returns the authorized user if credentials are valid and
|
|||
// returns an error otherwise.
|
|||
// Requires a user auth context.
|
|||
// https://dev.twitter.com/rest/reference/get/account/verify_credentials
|
|||
func (s *AccountService) VerifyCredentials(params *AccountVerifyParams) (*User, *http.Response, error) { |
|||
user := new(User) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("verify_credentials.json").QueryStruct(params).Receive(user, apiError) |
|||
return user, resp, relevantError(err, *apiError) |
|||
} |
@ -0,0 +1,27 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestAccountService_VerifyCredentials(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/account/verify_credentials.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"include_entities": "false", "include_email": "true"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"name": "Dalton Hubble", "id": 623265148}`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
user, _, err := client.Accounts.VerifyCredentials(&AccountVerifyParams{IncludeEntities: Bool(false), IncludeEmail: Bool(true)}) |
|||
expected := &User{Name: "Dalton Hubble", ID: 623265148} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, user) |
|||
} |
@ -0,0 +1,25 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"time" |
|||
|
|||
"github.com/cenkalti/backoff" |
|||
) |
|||
|
|||
func newExponentialBackOff() *backoff.ExponentialBackOff { |
|||
b := backoff.NewExponentialBackOff() |
|||
b.InitialInterval = 5 * time.Second |
|||
b.Multiplier = 2.0 |
|||
b.MaxInterval = 320 * time.Second |
|||
b.Reset() |
|||
return b |
|||
} |
|||
|
|||
func newAggressiveExponentialBackOff() *backoff.ExponentialBackOff { |
|||
b := backoff.NewExponentialBackOff() |
|||
b.InitialInterval = 1 * time.Minute |
|||
b.Multiplier = 2.0 |
|||
b.MaxInterval = 16 * time.Minute |
|||
b.Reset() |
|||
return b |
|||
} |
@ -0,0 +1,37 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestNewExponentialBackOff(t *testing.T) { |
|||
b := newExponentialBackOff() |
|||
assert.Equal(t, 5*time.Second, b.InitialInterval) |
|||
assert.Equal(t, 2.0, b.Multiplier) |
|||
assert.Equal(t, 320*time.Second, b.MaxInterval) |
|||
} |
|||
|
|||
func TestNewAggressiveExponentialBackOff(t *testing.T) { |
|||
b := newAggressiveExponentialBackOff() |
|||
assert.Equal(t, 1*time.Minute, b.InitialInterval) |
|||
assert.Equal(t, 2.0, b.Multiplier) |
|||
assert.Equal(t, 16*time.Minute, b.MaxInterval) |
|||
} |
|||
|
|||
// BackoffRecorder is an implementation of backoff.BackOff that records
|
|||
// calls to NextBackOff and Reset for later inspection in tests.
|
|||
type BackOffRecorder struct { |
|||
Count int |
|||
} |
|||
|
|||
func (b *BackOffRecorder) NextBackOff() time.Duration { |
|||
b.Count++ |
|||
return 1 * time.Nanosecond |
|||
} |
|||
|
|||
func (b *BackOffRecorder) Reset() { |
|||
b.Count = 0 |
|||
} |
@ -0,0 +1,88 @@ |
|||
package twitter |
|||
|
|||
// A Demux receives interface{} messages individually or from a channel and
|
|||
// sends those messages to one or more outputs determined by the
|
|||
// implementation.
|
|||
type Demux interface { |
|||
Handle(message interface{}) |
|||
HandleChan(messages <-chan interface{}) |
|||
} |
|||
|
|||
// SwitchDemux receives messages and uses a type switch to send each typed
|
|||
// message to a handler function.
|
|||
type SwitchDemux struct { |
|||
All func(message interface{}) |
|||
Tweet func(tweet *Tweet) |
|||
DM func(dm *DirectMessage) |
|||
StatusDeletion func(deletion *StatusDeletion) |
|||
LocationDeletion func(LocationDeletion *LocationDeletion) |
|||
StreamLimit func(limit *StreamLimit) |
|||
StatusWithheld func(statusWithheld *StatusWithheld) |
|||
UserWithheld func(userWithheld *UserWithheld) |
|||
StreamDisconnect func(disconnect *StreamDisconnect) |
|||
Warning func(warning *StallWarning) |
|||
FriendsList func(friendsList *FriendsList) |
|||
Event func(event *Event) |
|||
Other func(message interface{}) |
|||
} |
|||
|
|||
// NewSwitchDemux returns a new SwitchMux which has NoOp handler functions.
|
|||
func NewSwitchDemux() SwitchDemux { |
|||
return SwitchDemux{ |
|||
All: func(message interface{}) {}, |
|||
Tweet: func(tweet *Tweet) {}, |
|||
DM: func(dm *DirectMessage) {}, |
|||
StatusDeletion: func(deletion *StatusDeletion) {}, |
|||
LocationDeletion: func(LocationDeletion *LocationDeletion) {}, |
|||
StreamLimit: func(limit *StreamLimit) {}, |
|||
StatusWithheld: func(statusWithheld *StatusWithheld) {}, |
|||
UserWithheld: func(userWithheld *UserWithheld) {}, |
|||
StreamDisconnect: func(disconnect *StreamDisconnect) {}, |
|||
Warning: func(warning *StallWarning) {}, |
|||
FriendsList: func(friendsList *FriendsList) {}, |
|||
Event: func(event *Event) {}, |
|||
Other: func(message interface{}) {}, |
|||
} |
|||
} |
|||
|
|||
// Handle determines the type of a message and calls the corresponding receiver
|
|||
// function with the typed message. All messages are passed to the All func.
|
|||
// Messages with unmatched types are passed to the Other func.
|
|||
func (d SwitchDemux) Handle(message interface{}) { |
|||
d.All(message) |
|||
switch msg := message.(type) { |
|||
case *Tweet: |
|||
d.Tweet(msg) |
|||
case *DirectMessage: |
|||
d.DM(msg) |
|||
case *StatusDeletion: |
|||
d.StatusDeletion(msg) |
|||
case *LocationDeletion: |
|||
d.LocationDeletion(msg) |
|||
case *StreamLimit: |
|||
d.StreamLimit(msg) |
|||
case *StatusWithheld: |
|||
d.StatusWithheld(msg) |
|||
case *UserWithheld: |
|||
d.UserWithheld(msg) |
|||
case *StreamDisconnect: |
|||
d.StreamDisconnect(msg) |
|||
case *StallWarning: |
|||
d.Warning(msg) |
|||
case *FriendsList: |
|||
d.FriendsList(msg) |
|||
case *Event: |
|||
d.Event(msg) |
|||
default: |
|||
d.Other(msg) |
|||
} |
|||
} |
|||
|
|||
// HandleChan receives messages and calls the corresponding receiver function
|
|||
// with the typed message. All messages are passed to the All func. Messages
|
|||
// with unmatched type are passed to the Other func.
|
|||
func (d SwitchDemux) HandleChan(messages <-chan interface{}) { |
|||
for message := range messages { |
|||
d.Handle(message) |
|||
} |
|||
} |
@ -0,0 +1,135 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestDemux_Handle(t *testing.T) { |
|||
messages, expectedCounts := exampleMessages() |
|||
counts := &counter{} |
|||
demux := newCounterDemux(counts) |
|||
for _, message := range messages { |
|||
demux.Handle(message) |
|||
} |
|||
assert.Equal(t, expectedCounts, counts) |
|||
} |
|||
|
|||
func TestDemux_HandleChan(t *testing.T) { |
|||
messages, expectedCounts := exampleMessages() |
|||
counts := &counter{} |
|||
demux := newCounterDemux(counts) |
|||
ch := make(chan interface{}) |
|||
// stream messages into channel
|
|||
go func() { |
|||
for _, msg := range messages { |
|||
ch <- msg |
|||
} |
|||
close(ch) |
|||
}() |
|||
// handle channel messages until exhausted
|
|||
demux.HandleChan(ch) |
|||
assert.Equal(t, expectedCounts, counts) |
|||
} |
|||
|
|||
// counter counts stream messages by type for testing.
|
|||
type counter struct { |
|||
all int |
|||
tweet int |
|||
dm int |
|||
statusDeletion int |
|||
locationDeletion int |
|||
streamLimit int |
|||
statusWithheld int |
|||
userWithheld int |
|||
streamDisconnect int |
|||
stallWarning int |
|||
friendsList int |
|||
event int |
|||
other int |
|||
} |
|||
|
|||
// newCounterDemux returns a Demux which counts message types.
|
|||
func newCounterDemux(counter *counter) Demux { |
|||
demux := NewSwitchDemux() |
|||
demux.All = func(interface{}) { |
|||
counter.all++ |
|||
} |
|||
demux.Tweet = func(*Tweet) { |
|||
counter.tweet++ |
|||
} |
|||
demux.DM = func(*DirectMessage) { |
|||
counter.dm++ |
|||
} |
|||
demux.StatusDeletion = func(*StatusDeletion) { |
|||
counter.statusDeletion++ |
|||
} |
|||
demux.LocationDeletion = func(*LocationDeletion) { |
|||
counter.locationDeletion++ |
|||
} |
|||
demux.StreamLimit = func(*StreamLimit) { |
|||
counter.streamLimit++ |
|||
} |
|||
demux.StatusWithheld = func(*StatusWithheld) { |
|||
counter.statusWithheld++ |
|||
} |
|||
demux.UserWithheld = func(*UserWithheld) { |
|||
counter.userWithheld++ |
|||
} |
|||
demux.StreamDisconnect = func(*StreamDisconnect) { |
|||
counter.streamDisconnect++ |
|||
} |
|||
demux.Warning = func(*StallWarning) { |
|||
counter.stallWarning++ |
|||
} |
|||
demux.FriendsList = func(*FriendsList) { |
|||
counter.friendsList++ |
|||
} |
|||
demux.Event = func(*Event) { |
|||
counter.event++ |
|||
} |
|||
demux.Other = func(interface{}) { |
|||
counter.other++ |
|||
} |
|||
return demux |
|||
} |
|||
|
|||
// examples messages returns a test stream of messages and the expected
|
|||
// counts of each message type.
|
|||
func exampleMessages() (messages []interface{}, expectedCounts *counter) { |
|||
var ( |
|||
tweet = &Tweet{} |
|||
dm = &DirectMessage{} |
|||
statusDeletion = &StatusDeletion{} |
|||
locationDeletion = &LocationDeletion{} |
|||
streamLimit = &StreamLimit{} |
|||
statusWithheld = &StatusWithheld{} |
|||
userWithheld = &UserWithheld{} |
|||
streamDisconnect = &StreamDisconnect{} |
|||
stallWarning = &StallWarning{} |
|||
friendsList = &FriendsList{} |
|||
event = &Event{} |
|||
otherA = func() {} |
|||
otherB = struct{}{} |
|||
) |
|||
messages = []interface{}{tweet, dm, statusDeletion, locationDeletion, |
|||
streamLimit, statusWithheld, userWithheld, streamDisconnect, |
|||
stallWarning, friendsList, event, otherA, otherB} |
|||
expectedCounts = &counter{ |
|||
all: len(messages), |
|||
tweet: 1, |
|||
dm: 1, |
|||
statusDeletion: 1, |
|||
locationDeletion: 1, |
|||
streamLimit: 1, |
|||
statusWithheld: 1, |
|||
userWithheld: 1, |
|||
streamDisconnect: 1, |
|||
stallWarning: 1, |
|||
friendsList: 1, |
|||
event: 1, |
|||
other: 2, |
|||
} |
|||
return messages, expectedCounts |
|||
} |
@ -0,0 +1,130 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"net/http" |
|||
|
|||
"github.com/dghubble/sling" |
|||
) |
|||
|
|||
// DirectMessage is a direct message to a single recipient.
|
|||
type DirectMessage struct { |
|||
CreatedAt string `json:"created_at"` |
|||
Entities *Entities `json:"entities"` |
|||
ID int64 `json:"id"` |
|||
IDStr string `json:"id_str"` |
|||
Recipient *User `json:"recipient"` |
|||
RecipientID int64 `json:"recipient_id"` |
|||
RecipientScreenName string `json:"recipient_screen_name"` |
|||
Sender *User `json:"sender"` |
|||
SenderID int64 `json:"sender_id"` |
|||
SenderScreenName string `json:"sender_screen_name"` |
|||
Text string `json:"text"` |
|||
} |
|||
|
|||
// DirectMessageService provides methods for accessing Twitter direct message
|
|||
// API endpoints.
|
|||
type DirectMessageService struct { |
|||
baseSling *sling.Sling |
|||
sling *sling.Sling |
|||
} |
|||
|
|||
// newDirectMessageService returns a new DirectMessageService.
|
|||
func newDirectMessageService(sling *sling.Sling) *DirectMessageService { |
|||
return &DirectMessageService{ |
|||
baseSling: sling.New(), |
|||
sling: sling.Path("direct_messages/"), |
|||
} |
|||
} |
|||
|
|||
// directMessageShowParams are the parameters for DirectMessageService.Show
|
|||
type directMessageShowParams struct { |
|||
ID int64 `url:"id,omitempty"` |
|||
} |
|||
|
|||
// Show returns the requested Direct Message.
|
|||
// Requires a user auth context with DM scope.
|
|||
// https://dev.twitter.com/rest/reference/get/direct_messages/show
|
|||
func (s *DirectMessageService) Show(id int64) (*DirectMessage, *http.Response, error) { |
|||
params := &directMessageShowParams{ID: id} |
|||
dm := new(DirectMessage) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("show.json").QueryStruct(params).Receive(dm, apiError) |
|||
return dm, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// DirectMessageGetParams are the parameters for DirectMessageService.Get
|
|||
type DirectMessageGetParams struct { |
|||
SinceID int64 `url:"since_id,omitempty"` |
|||
MaxID int64 `url:"max_id,omitempty"` |
|||
Count int `url:"count,omitempty"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` |
|||
SkipStatus *bool `url:"skip_status,omitempty"` |
|||
} |
|||
|
|||
// Get returns recent Direct Messages received by the authenticated user.
|
|||
// Requires a user auth context with DM scope.
|
|||
// https://dev.twitter.com/rest/reference/get/direct_messages
|
|||
func (s *DirectMessageService) Get(params *DirectMessageGetParams) ([]DirectMessage, *http.Response, error) { |
|||
dms := new([]DirectMessage) |
|||
apiError := new(APIError) |
|||
resp, err := s.baseSling.New().Get("direct_messages.json").QueryStruct(params).Receive(dms, apiError) |
|||
return *dms, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// DirectMessageSentParams are the parameters for DirectMessageService.Sent
|
|||
type DirectMessageSentParams struct { |
|||
SinceID int64 `url:"since_id,omitempty"` |
|||
MaxID int64 `url:"max_id,omitempty"` |
|||
Count int `url:"count,omitempty"` |
|||
Page int `url:"page,omitempty"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` |
|||
} |
|||
|
|||
// Sent returns recent Direct Messages sent by the authenticated user.
|
|||
// Requires a user auth context with DM scope.
|
|||
// https://dev.twitter.com/rest/reference/get/direct_messages/sent
|
|||
func (s *DirectMessageService) Sent(params *DirectMessageSentParams) ([]DirectMessage, *http.Response, error) { |
|||
dms := new([]DirectMessage) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("sent.json").QueryStruct(params).Receive(dms, apiError) |
|||
return *dms, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// DirectMessageNewParams are the parameters for DirectMessageService.New
|
|||
type DirectMessageNewParams struct { |
|||
UserID int64 `url:"user_id,omitempty"` |
|||
ScreenName string `url:"screen_name,omitempty"` |
|||
Text string `url:"text"` |
|||
} |
|||
|
|||
// New sends a new Direct Message to a specified user as the authenticated
|
|||
// user.
|
|||
// Requires a user auth context with DM scope.
|
|||
// https://dev.twitter.com/rest/reference/post/direct_messages/new
|
|||
func (s *DirectMessageService) New(params *DirectMessageNewParams) (*DirectMessage, *http.Response, error) { |
|||
dm := new(DirectMessage) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Post("new.json").BodyForm(params).Receive(dm, apiError) |
|||
return dm, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// DirectMessageDestroyParams are the parameters for DirectMessageService.Destroy
|
|||
type DirectMessageDestroyParams struct { |
|||
ID int64 `url:"id,omitempty"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` |
|||
} |
|||
|
|||
// Destroy deletes the Direct Message with the given id and returns it if
|
|||
// successful.
|
|||
// Requires a user auth context with DM scope.
|
|||
// https://dev.twitter.com/rest/reference/post/direct_messages/destroy
|
|||
func (s *DirectMessageService) Destroy(id int64, params *DirectMessageDestroyParams) (*DirectMessage, *http.Response, error) { |
|||
if params == nil { |
|||
params = &DirectMessageDestroyParams{} |
|||
} |
|||
params.ID = id |
|||
dm := new(DirectMessage) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Post("destroy.json").BodyForm(params).Receive(dm, apiError) |
|||
return dm, resp, relevantError(err, *apiError) |
|||
} |
@ -0,0 +1,110 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
var ( |
|||
testDM = DirectMessage{ |
|||
ID: 240136858829479936, |
|||
Recipient: &User{ScreenName: "theSeanCook"}, |
|||
Sender: &User{ScreenName: "s0c1alm3dia"}, |
|||
Text: "hello world", |
|||
} |
|||
testDMIDStr = "240136858829479936" |
|||
testDMJSON = `{"id": 240136858829479936,"recipient": {"screen_name": "theSeanCook"},"sender": {"screen_name": "s0c1alm3dia"},"text": "hello world"}` |
|||
) |
|||
|
|||
func TestDirectMessageService_Show(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/direct_messages/show.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"id": testDMIDStr}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, testDMJSON) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
dms, _, err := client.DirectMessages.Show(testDM.ID) |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, &testDM, dms) |
|||
} |
|||
|
|||
func TestDirectMessageService_Get(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/direct_messages.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"since_id": "589147592367431680", "count": "1"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `[`+testDMJSON+`]`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &DirectMessageGetParams{SinceID: 589147592367431680, Count: 1} |
|||
dms, _, err := client.DirectMessages.Get(params) |
|||
expected := []DirectMessage{testDM} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, dms) |
|||
} |
|||
|
|||
func TestDirectMessageService_Sent(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/direct_messages/sent.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"since_id": "589147592367431680", "count": "1"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `[`+testDMJSON+`]`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &DirectMessageSentParams{SinceID: 589147592367431680, Count: 1} |
|||
dms, _, err := client.DirectMessages.Sent(params) |
|||
expected := []DirectMessage{testDM} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, dms) |
|||
} |
|||
|
|||
func TestDirectMessageService_New(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/direct_messages/new.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "POST", r) |
|||
assertPostForm(t, map[string]string{"screen_name": "theseancook", "text": "hello world"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, testDMJSON) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &DirectMessageNewParams{ScreenName: "theseancook", Text: "hello world"} |
|||
dm, _, err := client.DirectMessages.New(params) |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, &testDM, dm) |
|||
} |
|||
|
|||
func TestDirectMessageService_Destroy(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/direct_messages/destroy.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "POST", r) |
|||
assertPostForm(t, map[string]string{"id": testDMIDStr}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, testDMJSON) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
dm, _, err := client.DirectMessages.Destroy(testDM.ID, nil) |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, &testDM, dm) |
|||
} |
@ -0,0 +1,70 @@ |
|||
/* |
|||
Package twitter provides a Client for the Twitter API. |
|||
|
|||
|
|||
The twitter package provides a Client for accessing the Twitter API. Here are |
|||
some example requests. |
|||
|
|||
// Twitter client
|
|||
client := twitter.NewClient(httpClient) |
|||
// Home Timeline
|
|||
tweets, resp, err := client.Timelines.HomeTimeline(&HomeTimelineParams{}) |
|||
// Send a Tweet
|
|||
tweet, resp, err := client.Statuses.Update("just setting up my twttr", nil) |
|||
// Status Show
|
|||
tweet, resp, err := client.Statuses.Show(585613041028431872, nil) |
|||
// User Show
|
|||
params := &twitter.UserShowParams{ScreenName: "dghubble"} |
|||
user, resp, err := client.Users.Show(params) |
|||
// Followers
|
|||
followers, resp, err := client.Followers.List(&FollowerListParams{}) |
|||
|
|||
Required parameters are passed as positional arguments. Optional parameters |
|||
are passed in a typed params struct (or pass nil). |
|||
|
|||
Authentication |
|||
|
|||
By design, the Twitter Client accepts any http.Client so user auth (OAuth1) or |
|||
application auth (OAuth2) requests can be made by using the appropriate |
|||
authenticated client. Use the https://github.com/dghubble/oauth1 and
|
|||
https://github.com/golang/oauth2 packages to obtain an http.Client which
|
|||
transparently authorizes requests. |
|||
|
|||
For example, make requests as a consumer application on behalf of a user who |
|||
has granted access, with OAuth1. |
|||
|
|||
// OAuth1
|
|||
import ( |
|||
"github.com/dghubble/go-twitter/twitter" |
|||
"github.com/dghubble/oauth1" |
|||
) |
|||
|
|||
config := oauth1.NewConfig("consumerKey", "consumerSecret") |
|||
token := oauth1.NewToken("accessToken", "accessSecret") |
|||
// http.Client will automatically authorize Requests
|
|||
httpClient := config.Client(oauth1.NoContext, token) |
|||
|
|||
// twitter client
|
|||
client := twitter.NewClient(httpClient) |
|||
|
|||
If no user auth context is needed, make requests as your application with |
|||
application auth. |
|||
|
|||
// OAuth2
|
|||
import ( |
|||
"github.com/dghubble/go-twitter/twitter" |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
config := &oauth2.Config{} |
|||
token := &oauth2.Token{AccessToken: accessToken} |
|||
// http.Client will automatically authorize Requests
|
|||
httpClient := config.Client(oauth2.NoContext, token) |
|||
|
|||
// twitter client
|
|||
client := twitter.NewClient(httpClient) |
|||
|
|||
To implement Login with Twitter, see https://github.com/dghubble/gologin.
|
|||
|
|||
*/ |
|||
package twitter |
@ -0,0 +1,73 @@ |
|||
package twitter |
|||
|
|||
// Entities represent metadata and context info parsed from Twitter components.
|
|||
// https://dev.twitter.com/overview/api/entities
|
|||
// TODO: symbols
|
|||
type Entities struct { |
|||
Hashtags []HashtagEntity `json:"hashtags"` |
|||
Media []MediaEntity `json:"media"` |
|||
Urls []URLEntity `json:"urls"` |
|||
UserMentions []MentionEntity `json:"user_mentions"` |
|||
} |
|||
|
|||
// HashtagEntity represents a hashtag which has been parsed from text.
|
|||
type HashtagEntity struct { |
|||
Indices Indices `json:"indices"` |
|||
Text string `json:"text"` |
|||
} |
|||
|
|||
// URLEntity represents a URL which has been parsed from text.
|
|||
type URLEntity struct { |
|||
Indices Indices `json:"indices"` |
|||
DisplayURL string `json:"display_url"` |
|||
ExpandedURL string `json:"expanded_url"` |
|||
URL string `json:"url"` |
|||
} |
|||
|
|||
// MediaEntity represents media elements associated with a Tweet.
|
|||
// TODO: add Sizes
|
|||
type MediaEntity struct { |
|||
URLEntity |
|||
ID int64 `json:"id"` |
|||
IDStr string `json:"id_str"` |
|||
MediaURL string `json:"media_url"` |
|||
MediaURLHttps string `json:"media_url_https"` |
|||
SourceStatusID int64 `json:"source_status_id"` |
|||
SourceStatusIDStr string `json:"source_status_id_str"` |
|||
Type string `json:"type"` |
|||
} |
|||
|
|||
// MentionEntity represents Twitter user mentions parsed from text.
|
|||
type MentionEntity struct { |
|||
Indices Indices `json:"indices"` |
|||
ID int64 `json:"id"` |
|||
IDStr string `json:"id_str"` |
|||
Name string `json:"name"` |
|||
ScreenName string `json:"screen_name"` |
|||
} |
|||
|
|||
// UserEntities contain Entities parsed from User url and description fields.
|
|||
// https://dev.twitter.com/overview/api/entities-in-twitter-objects#users
|
|||
type UserEntities struct { |
|||
URL Entities `json:"url"` |
|||
Description Entities `json:"description"` |
|||
} |
|||
|
|||
// ExtendedEntity contains media information.
|
|||
// https://dev.twitter.com/overview/api/entities-in-twitter-objects#extended_entities
|
|||
type ExtendedEntity struct { |
|||
Media []MediaEntity `json:"media"` |
|||
} |
|||
|
|||
// Indices represent the start and end offsets within text.
|
|||
type Indices [2]int |
|||
|
|||
// Start returns the index at which an entity starts, inclusive.
|
|||
func (i Indices) Start() int { |
|||
return i[0] |
|||
} |
|||
|
|||
// End returns the index at which an entity ends, exclusive.
|
|||
func (i Indices) End() int { |
|||
return i[1] |
|||
} |
@ -0,0 +1,22 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestIndices(t *testing.T) { |
|||
cases := []struct { |
|||
pair Indices |
|||
expectedStart int |
|||
expectedEnd int |
|||
}{ |
|||
{Indices{}, 0, 0}, |
|||
{Indices{25, 47}, 25, 47}, |
|||
} |
|||
for _, c := range cases { |
|||
assert.Equal(t, c.expectedStart, c.pair.Start()) |
|||
assert.Equal(t, c.expectedEnd, c.pair.End()) |
|||
} |
|||
} |
@ -0,0 +1,47 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"fmt" |
|||
) |
|||
|
|||
// APIError represents a Twitter API Error response
|
|||
// https://dev.twitter.com/overview/api/response-codes
|
|||
type APIError struct { |
|||
Errors []ErrorDetail `json:"errors"` |
|||
} |
|||
|
|||
// ErrorDetail represents an individual item in an APIError.
|
|||
type ErrorDetail struct { |
|||
Message string `json:"message"` |
|||
Code int `json:"code"` |
|||
} |
|||
|
|||
func (e APIError) Error() string { |
|||
if len(e.Errors) > 0 { |
|||
err := e.Errors[0] |
|||
return fmt.Sprintf("twitter: %d %v", err.Code, err.Message) |
|||
} |
|||
return "" |
|||
} |
|||
|
|||
// Empty returns true if empty. Otherwise, at least 1 error message/code is
|
|||
// present and false is returned.
|
|||
func (e APIError) Empty() bool { |
|||
if len(e.Errors) == 0 { |
|||
return true |
|||
} |
|||
return false |
|||
} |
|||
|
|||
// relevantError returns any non-nil http-related error (creating the request,
|
|||
// getting the response, decoding) if any. If the decoded apiError is non-zero
|
|||
// the apiError is returned. Otherwise, no errors occurred, returns nil.
|
|||
func relevantError(httpError error, apiError APIError) error { |
|||
if httpError != nil { |
|||
return httpError |
|||
} |
|||
if apiError.Empty() { |
|||
return nil |
|||
} |
|||
return apiError |
|||
} |
@ -0,0 +1,48 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"fmt" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
var errAPI = APIError{ |
|||
Errors: []ErrorDetail{ |
|||
ErrorDetail{Message: "Status is a duplicate", Code: 187}, |
|||
}, |
|||
} |
|||
var errHTTP = fmt.Errorf("unknown host") |
|||
|
|||
func TestAPIError_Error(t *testing.T) { |
|||
err := APIError{} |
|||
if assert.Error(t, err) { |
|||
assert.Equal(t, "", err.Error()) |
|||
} |
|||
if assert.Error(t, errAPI) { |
|||
assert.Equal(t, "twitter: 187 Status is a duplicate", errAPI.Error()) |
|||
} |
|||
} |
|||
|
|||
func TestAPIError_Empty(t *testing.T) { |
|||
err := APIError{} |
|||
assert.True(t, err.Empty()) |
|||
assert.False(t, errAPI.Empty()) |
|||
} |
|||
|
|||
func TestRelevantError(t *testing.T) { |
|||
cases := []struct { |
|||
httpError error |
|||
apiError APIError |
|||
expected error |
|||
}{ |
|||
{nil, APIError{}, nil}, |
|||
{nil, errAPI, errAPI}, |
|||
{errHTTP, APIError{}, errHTTP}, |
|||
{errHTTP, errAPI, errHTTP}, |
|||
} |
|||
for _, c := range cases { |
|||
err := relevantError(c.httpError, c.apiError) |
|||
assert.Equal(t, c.expected, err) |
|||
} |
|||
} |
@ -0,0 +1,73 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"net/http" |
|||
|
|||
"github.com/dghubble/sling" |
|||
) |
|||
|
|||
// FollowerIDs is a cursored collection of follower ids.
|
|||
type FollowerIDs struct { |
|||
IDs []int64 `json:"ids"` |
|||
NextCursor int64 `json:"next_cursor"` |
|||
NextCursorStr string `json:"next_cursor_str"` |
|||
PreviousCursor int64 `json:"previous_cursor"` |
|||
PreviousCursorStr string `json:"previous_cursor_str"` |
|||
} |
|||
|
|||
// Followers is a cursored collection of followers.
|
|||
type Followers struct { |
|||
Users []User `json:"users"` |
|||
NextCursor int64 `json:"next_cursor"` |
|||
NextCursorStr string `json:"next_cursor_str"` |
|||
PreviousCursor int64 `json:"previous_cursor"` |
|||
PreviousCursorStr string `json:"previous_cursor_str"` |
|||
} |
|||
|
|||
// FollowerService provides methods for accessing Twitter followers endpoints.
|
|||
type FollowerService struct { |
|||
sling *sling.Sling |
|||
} |
|||
|
|||
// newFollowerService returns a new FollowerService.
|
|||
func newFollowerService(sling *sling.Sling) *FollowerService { |
|||
return &FollowerService{ |
|||
sling: sling.Path("followers/"), |
|||
} |
|||
} |
|||
|
|||
// FollowerIDParams are the parameters for FollowerService.Ids
|
|||
type FollowerIDParams struct { |
|||
UserID int64 `url:"user_id,omitempty"` |
|||
ScreenName string `url:"screen_name,omitempty"` |
|||
Cursor int64 `url:"cursor,omitempty"` |
|||
Count int `url:"count,omitempty"` |
|||
} |
|||
|
|||
// IDs returns a cursored collection of user ids following the specified user.
|
|||
// https://dev.twitter.com/rest/reference/get/followers/ids
|
|||
func (s *FollowerService) IDs(params *FollowerIDParams) (*FollowerIDs, *http.Response, error) { |
|||
ids := new(FollowerIDs) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("ids.json").QueryStruct(params).Receive(ids, apiError) |
|||
return ids, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// FollowerListParams are the parameters for FollowerService.List
|
|||
type FollowerListParams struct { |
|||
UserID int64 `url:"user_id,omitempty"` |
|||
ScreenName string `url:"screen_name,omitempty"` |
|||
Cursor int `url:"cursor,omitempty"` |
|||
Count int `url:"count,omitempty"` |
|||
SkipStatus *bool `url:"skip_status,omitempty"` |
|||
IncludeUserEntities *bool `url:"include_user_entities,omitempty"` |
|||
} |
|||
|
|||
// List returns a cursored collection of Users following the specified user.
|
|||
// https://dev.twitter.com/rest/reference/get/followers/list
|
|||
func (s *FollowerService) List(params *FollowerListParams) (*Followers, *http.Response, error) { |
|||
followers := new(Followers) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("list.json").QueryStruct(params).Receive(followers, apiError) |
|||
return followers, resp, relevantError(err, *apiError) |
|||
} |
@ -0,0 +1,69 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestFollowerService_Ids(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/followers/ids.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"user_id": "623265148", "count": "5", "cursor": "1516933260114270762"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"ids":[178082406,3318241001,1318020818,191714329,376703838],"next_cursor":1516837838944119498,"next_cursor_str":"1516837838944119498","previous_cursor":-1516924983503961435,"previous_cursor_str":"-1516924983503961435"}`) |
|||
}) |
|||
expected := &FollowerIDs{ |
|||
IDs: []int64{178082406, 3318241001, 1318020818, 191714329, 376703838}, |
|||
NextCursor: 1516837838944119498, |
|||
NextCursorStr: "1516837838944119498", |
|||
PreviousCursor: -1516924983503961435, |
|||
PreviousCursorStr: "-1516924983503961435", |
|||
} |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &FollowerIDParams{ |
|||
UserID: 623265148, |
|||
Count: 5, |
|||
Cursor: 1516933260114270762, |
|||
} |
|||
followerIDs, _, err := client.Followers.IDs(params) |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, followerIDs) |
|||
} |
|||
|
|||
func TestFollowerService_List(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/followers/list.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"screen_name": "dghubble", "count": "5", "cursor": "1516933260114270762", "skip_status": "true", "include_user_entities": "false"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"users": [{"id": 123}], "next_cursor":1516837838944119498,"next_cursor_str":"1516837838944119498","previous_cursor":-1516924983503961435,"previous_cursor_str":"-1516924983503961435"}`) |
|||
}) |
|||
expected := &Followers{ |
|||
Users: []User{User{ID: 123}}, |
|||
NextCursor: 1516837838944119498, |
|||
NextCursorStr: "1516837838944119498", |
|||
PreviousCursor: -1516924983503961435, |
|||
PreviousCursorStr: "-1516924983503961435", |
|||
} |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &FollowerListParams{ |
|||
ScreenName: "dghubble", |
|||
Count: 5, |
|||
Cursor: 1516933260114270762, |
|||
SkipStatus: Bool(true), |
|||
IncludeUserEntities: Bool(false), |
|||
} |
|||
followers, _, err := client.Followers.List(params) |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, followers) |
|||
} |
@ -0,0 +1,251 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
|
|||
"github.com/dghubble/sling" |
|||
) |
|||
|
|||
// Tweet represents a Twitter Tweet, previously called a status.
|
|||
// https://dev.twitter.com/overview/api/tweets
|
|||
// Unused or deprecated fields not provided: Geo, Annotations
|
|||
type Tweet struct { |
|||
Contributors []Contributor `json:"contributors"` |
|||
Coordinates *Coordinates `json:"coordinates"` |
|||
CreatedAt string `json:"created_at"` |
|||
CurrentUserRetweet *TweetIdentifier `json:"current_user_retweet"` |
|||
Entities *Entities `json:"entities"` |
|||
FavoriteCount int `json:"favorite_count"` |
|||
Favorited bool `json:"favorited"` |
|||
FilterLevel string `json:"filter_level"` |
|||
ID int64 `json:"id"` |
|||
IDStr string `json:"id_str"` |
|||
InReplyToScreenName string `json:"in_reply_to_screen_name"` |
|||
InReplyToStatusID int64 `json:"in_reply_to_status_id"` |
|||
InReplyToStatusIDStr string `json:"in_reply_to_status_id_str"` |
|||
InReplyToUserID int64 `json:"in_reply_to_user_id"` |
|||
InReplyToUserIDStr string `json:"in_reply_to_user_id_str"` |
|||
Lang string `json:"lang"` |
|||
PossiblySensitive bool `json:"possibly_sensitive"` |
|||
RetweetCount int `json:"retweet_count"` |
|||
Retweeted bool `json:"retweeted"` |
|||
RetweetedStatus *Tweet `json:"retweeted_status"` |
|||
Source string `json:"source"` |
|||
Scopes map[string]interface{} `json:"scopes"` |
|||
Text string `json:"text"` |
|||
Place *Place `json:"place"` |
|||
Truncated bool `json:"truncated"` |
|||
User *User `json:"user"` |
|||
WithheldCopyright bool `json:"withheld_copyright"` |
|||
WithheldInCountries []string `json:"withheld_in_countries"` |
|||
WithheldScope string `json:"withheld_scope"` |
|||
ExtendedEntities *ExtendedEntity `json:"extended_entities"` |
|||
QuotedStatusID int64 `json:"quoted_status_id"` |
|||
QuotedStatusIDStr string `json:"quoted_status_id_str"` |
|||
QuotedStatus *Tweet `json:"quoted_status"` |
|||
} |
|||
|
|||
// Place represents a Twitter Place / Location
|
|||
// https://dev.twitter.com/overview/api/places
|
|||
type Place struct { |
|||
Attributes map[string]string `json:"attributes"` |
|||
BoundingBox *BoundingBox `json:"bounding_box"` |
|||
Country string `json:"country"` |
|||
CountryCode string `json:"country_code"` |
|||
FullName string `json:"full_name"` |
|||
Geometry *BoundingBox `json:"geometry"` |
|||
ID string `json:"id"` |
|||
Name string `json:"name"` |
|||
PlaceType string `json:"place_type"` |
|||
Polylines []string `json:"polylines"` |
|||
URL string `json:"url"` |
|||
} |
|||
|
|||
// BoundingBox represents the bounding coordinates (longitude, latitutde)
|
|||
// defining the bounds of a box containing a Place entity.
|
|||
type BoundingBox struct { |
|||
Coordinates [][][2]float64 `json:"coordinates"` |
|||
Type string `json:"type"` |
|||
} |
|||
|
|||
// Contributor represents a brief summary of a User identifiers.
|
|||
type Contributor struct { |
|||
ID int64 `json:"id"` |
|||
IDStr string `json:"id_str"` |
|||
ScreenName string `json:"screen_name"` |
|||
} |
|||
|
|||
// Coordinates are pairs of longitude and latitude locations.
|
|||
type Coordinates struct { |
|||
Coordinates [2]float64 `json:"coordinates"` |
|||
Type string `json:"type"` |
|||
} |
|||
|
|||
// TweetIdentifier represents the id by which a Tweet can be identified.
|
|||
type TweetIdentifier struct { |
|||
ID int64 `json:"id"` |
|||
IDStr string `json:"id_str"` |
|||
} |
|||
|
|||
// StatusService provides methods for accessing Twitter status API endpoints.
|
|||
type StatusService struct { |
|||
sling *sling.Sling |
|||
} |
|||
|
|||
// newStatusService returns a new StatusService.
|
|||
func newStatusService(sling *sling.Sling) *StatusService { |
|||
return &StatusService{ |
|||
sling: sling.Path("statuses/"), |
|||
} |
|||
} |
|||
|
|||
// StatusShowParams are the parameters for StatusService.Show
|
|||
type StatusShowParams struct { |
|||
ID int64 `url:"id,omitempty"` |
|||
TrimUser *bool `url:"trim_user,omitempty"` |
|||
IncludeMyRetweet *bool `url:"include_my_retweet,omitempty"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` |
|||
} |
|||
|
|||
// Show returns the requested Tweet.
|
|||
// https://dev.twitter.com/rest/reference/get/statuses/show/%3Aid
|
|||
func (s *StatusService) Show(id int64, params *StatusShowParams) (*Tweet, *http.Response, error) { |
|||
if params == nil { |
|||
params = &StatusShowParams{} |
|||
} |
|||
params.ID = id |
|||
tweet := new(Tweet) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("show.json").QueryStruct(params).Receive(tweet, apiError) |
|||
return tweet, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// StatusLookupParams are the parameters for StatusService.Lookup
|
|||
type StatusLookupParams struct { |
|||
ID []int64 `url:"id,omitempty,comma"` |
|||
TrimUser *bool `url:"trim_user,omitempty"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` |
|||
Map *bool `url:"map,omitempty"` |
|||
} |
|||
|
|||
// Lookup returns the requested Tweets as a slice. Combines ids from the
|
|||
// required ids argument and from params.Id.
|
|||
// https://dev.twitter.com/rest/reference/get/statuses/lookup
|
|||
func (s *StatusService) Lookup(ids []int64, params *StatusLookupParams) ([]Tweet, *http.Response, error) { |
|||
if params == nil { |
|||
params = &StatusLookupParams{} |
|||
} |
|||
params.ID = append(params.ID, ids...) |
|||
tweets := new([]Tweet) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("lookup.json").QueryStruct(params).Receive(tweets, apiError) |
|||
return *tweets, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// StatusUpdateParams are the parameters for StatusService.Update
|
|||
type StatusUpdateParams struct { |
|||
Status string `url:"status,omitempty"` |
|||
InReplyToStatusID int64 `url:"in_reply_to_status_id,omitempty"` |
|||
PossiblySensitive *bool `url:"possibly_sensitive,omitempty"` |
|||
Lat *float64 `url:"lat,omitempty"` |
|||
Long *float64 `url:"long,omitempty"` |
|||
PlaceID string `url:"place_id,omitempty"` |
|||
DisplayCoordinates *bool `url:"display_coordinates,omitempty"` |
|||
TrimUser *bool `url:"trim_user,omitempty"` |
|||
MediaIds []int64 `url:"media_ids,omitempty,comma"` |
|||
} |
|||
|
|||
// Update updates the user's status, also known as Tweeting.
|
|||
// Requires a user auth context.
|
|||
// https://dev.twitter.com/rest/reference/post/statuses/update
|
|||
func (s *StatusService) Update(status string, params *StatusUpdateParams) (*Tweet, *http.Response, error) { |
|||
if params == nil { |
|||
params = &StatusUpdateParams{} |
|||
} |
|||
params.Status = status |
|||
tweet := new(Tweet) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Post("update.json").BodyForm(params).Receive(tweet, apiError) |
|||
return tweet, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// StatusRetweetParams are the parameters for StatusService.Retweet
|
|||
type StatusRetweetParams struct { |
|||
ID int64 `url:"id,omitempty"` |
|||
TrimUser *bool `url:"trim_user,omitempty"` |
|||
} |
|||
|
|||
// Retweet retweets the Tweet with the given id and returns the original Tweet
|
|||
// with embedded retweet details.
|
|||
// Requires a user auth context.
|
|||
// https://dev.twitter.com/rest/reference/post/statuses/retweet/%3Aid
|
|||
func (s *StatusService) Retweet(id int64, params *StatusRetweetParams) (*Tweet, *http.Response, error) { |
|||
if params == nil { |
|||
params = &StatusRetweetParams{} |
|||
} |
|||
params.ID = id |
|||
tweet := new(Tweet) |
|||
apiError := new(APIError) |
|||
path := fmt.Sprintf("retweet/%d.json", params.ID) |
|||
resp, err := s.sling.New().Post(path).BodyForm(params).Receive(tweet, apiError) |
|||
return tweet, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// StatusDestroyParams are the parameters for StatusService.Destroy
|
|||
type StatusDestroyParams struct { |
|||
ID int64 `url:"id,omitempty"` |
|||
TrimUser *bool `url:"trim_user,omitempty"` |
|||
} |
|||
|
|||
// Destroy deletes the Tweet with the given id and returns it if successful.
|
|||
// Requires a user auth context.
|
|||
// https://dev.twitter.com/rest/reference/post/statuses/destroy/%3Aid
|
|||
func (s *StatusService) Destroy(id int64, params *StatusDestroyParams) (*Tweet, *http.Response, error) { |
|||
if params == nil { |
|||
params = &StatusDestroyParams{} |
|||
} |
|||
params.ID = id |
|||
tweet := new(Tweet) |
|||
apiError := new(APIError) |
|||
path := fmt.Sprintf("destroy/%d.json", params.ID) |
|||
resp, err := s.sling.New().Post(path).BodyForm(params).Receive(tweet, apiError) |
|||
return tweet, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// OEmbedTweet represents a Tweet in oEmbed format.
|
|||
type OEmbedTweet struct { |
|||
URL string `json:"url"` |
|||
ProviderURL string `json:"provider_url"` |
|||
ProviderName string `json:"provider_name"` |
|||
AuthorName string `json:"author_name"` |
|||
Version string `json:"version"` |
|||
AuthorURL string `json:"author_url"` |
|||
Type string `json:"type"` |
|||
HTML string `json:"html"` |
|||
Height int64 `json:"height"` |
|||
Width int64 `json:"width"` |
|||
CacheAge string `json:"cache_age"` |
|||
} |
|||
|
|||
// StatusOEmbedParams are the parameters for StatusService.OEmbed
|
|||
type StatusOEmbedParams struct { |
|||
ID int64 `url:"id,omitempty"` |
|||
URL string `url:"url,omitempty"` |
|||
Align string `url:"align,omitempty"` |
|||
MaxWidth int64 `url:"maxwidth,omitempty"` |
|||
HideMedia *bool `url:"hide_media,omitempty"` |
|||
HideThread *bool `url:"hide_media,omitempty"` |
|||
OmitScript *bool `url:"hide_media,omitempty"` |
|||
WidgetType string `url:"widget_type,omitempty"` |
|||
HideTweet *bool `url:"hide_tweet,omitempty"` |
|||
} |
|||
|
|||
// OEmbed returns the requested Tweet in oEmbed format.
|
|||
// https://dev.twitter.com/rest/reference/get/statuses/oembed
|
|||
func (s *StatusService) OEmbed(params *StatusOEmbedParams) (*OEmbedTweet, *http.Response, error) { |
|||
oEmbedTweet := new(OEmbedTweet) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("oembed.json").QueryStruct(params).Receive(oEmbedTweet, apiError) |
|||
return oEmbedTweet, resp, relevantError(err, *apiError) |
|||
} |
@ -0,0 +1,224 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
"strings" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestStatusService_Show(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/show.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"id": "589488862814076930", "include_entities": "false"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"user": {"screen_name": "dghubble"}, "text": ".@audreyr use a DONTREADME file if you really want people to read it :P"}`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &StatusShowParams{ID: 5441, IncludeEntities: Bool(false)} |
|||
tweet, _, err := client.Statuses.Show(589488862814076930, params) |
|||
expected := &Tweet{User: &User{ScreenName: "dghubble"}, Text: ".@audreyr use a DONTREADME file if you really want people to read it :P"} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, tweet) |
|||
} |
|||
|
|||
func TestStatusService_ShowHandlesNilParams(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/show.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertQuery(t, map[string]string{"id": "589488862814076930"}, r) |
|||
}) |
|||
client := NewClient(httpClient) |
|||
client.Statuses.Show(589488862814076930, nil) |
|||
} |
|||
|
|||
func TestStatusService_Lookup(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/lookup.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"id": "20,573893817000140800", "trim_user": "true"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `[{"id": 20, "text": "just setting up my twttr"}, {"id": 573893817000140800, "text": "Don't get lost #PaxEast2015"}]`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &StatusLookupParams{ID: []int64{20}, TrimUser: Bool(true)} |
|||
tweets, _, err := client.Statuses.Lookup([]int64{573893817000140800}, params) |
|||
expected := []Tweet{Tweet{ID: 20, Text: "just setting up my twttr"}, Tweet{ID: 573893817000140800, Text: "Don't get lost #PaxEast2015"}} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, tweets) |
|||
} |
|||
|
|||
func TestStatusService_LookupHandlesNilParams(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
mux.HandleFunc("/1.1/statuses/lookup.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertQuery(t, map[string]string{"id": "20,573893817000140800"}, r) |
|||
}) |
|||
client := NewClient(httpClient) |
|||
client.Statuses.Lookup([]int64{20, 573893817000140800}, nil) |
|||
} |
|||
|
|||
func TestStatusService_Update(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/update.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "POST", r) |
|||
assertQuery(t, map[string]string{}, r) |
|||
assertPostForm(t, map[string]string{"status": "very informative tweet", "media_ids": "123456789,987654321", "lat": "37.826706", "long": "-122.42219"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"id": 581980947630845953, "text": "very informative tweet"}`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &StatusUpdateParams{MediaIds: []int64{123456789, 987654321}, Lat: Float(37.826706), Long: Float(-122.422190)} |
|||
tweet, _, err := client.Statuses.Update("very informative tweet", params) |
|||
expected := &Tweet{ID: 581980947630845953, Text: "very informative tweet"} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, tweet) |
|||
} |
|||
|
|||
func TestStatusService_UpdateHandlesNilParams(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
mux.HandleFunc("/1.1/statuses/update.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertPostForm(t, map[string]string{"status": "very informative tweet"}, r) |
|||
}) |
|||
client := NewClient(httpClient) |
|||
client.Statuses.Update("very informative tweet", nil) |
|||
} |
|||
|
|||
func TestStatusService_APIError(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
mux.HandleFunc("/1.1/statuses/update.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertPostForm(t, map[string]string{"status": "very informative tweet"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.WriteHeader(403) |
|||
fmt.Fprintf(w, `{"errors": [{"message": "Status is a duplicate", "code": 187}]}`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
_, _, err := client.Statuses.Update("very informative tweet", nil) |
|||
expected := APIError{ |
|||
Errors: []ErrorDetail{ |
|||
ErrorDetail{Message: "Status is a duplicate", Code: 187}, |
|||
}, |
|||
} |
|||
if assert.Error(t, err) { |
|||
assert.Equal(t, expected, err) |
|||
} |
|||
} |
|||
|
|||
func TestStatusService_HTTPError(t *testing.T) { |
|||
httpClient, _, server := testServer() |
|||
server.Close() |
|||
client := NewClient(httpClient) |
|||
_, _, err := client.Statuses.Update("very informative tweet", nil) |
|||
if err == nil || !strings.Contains(err.Error(), "connection refused") { |
|||
t.Errorf("Statuses.Update error expected connection refused, got: \n %+v", err) |
|||
} |
|||
} |
|||
|
|||
func TestStatusService_Retweet(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/retweet/20.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "POST", r) |
|||
assertQuery(t, map[string]string{}, r) |
|||
assertPostForm(t, map[string]string{"id": "20", "trim_user": "true"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"id": 581980947630202020, "text": "RT @jack: just setting up my twttr", "retweeted_status": {"id": 20, "text": "just setting up my twttr"}}`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &StatusRetweetParams{TrimUser: Bool(true)} |
|||
tweet, _, err := client.Statuses.Retweet(20, params) |
|||
expected := &Tweet{ID: 581980947630202020, Text: "RT @jack: just setting up my twttr", RetweetedStatus: &Tweet{ID: 20, Text: "just setting up my twttr"}} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, tweet) |
|||
} |
|||
|
|||
func TestStatusService_RetweetHandlesNilParams(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/retweet/20.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertPostForm(t, map[string]string{"id": "20"}, r) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
client.Statuses.Retweet(20, nil) |
|||
} |
|||
|
|||
func TestStatusService_Destroy(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/destroy/40.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "POST", r) |
|||
assertQuery(t, map[string]string{}, r) |
|||
assertPostForm(t, map[string]string{"id": "40", "trim_user": "true"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"id": 40, "text": "wishing I had another sammich"}`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &StatusDestroyParams{TrimUser: Bool(true)} |
|||
tweet, _, err := client.Statuses.Destroy(40, params) |
|||
// feed Biz Stone a sammich, he deletes sammich Tweet
|
|||
expected := &Tweet{ID: 40, Text: "wishing I had another sammich"} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, tweet) |
|||
} |
|||
|
|||
func TestStatusService_DestroyHandlesNilParams(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/destroy/40.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertPostForm(t, map[string]string{"id": "40"}, r) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
client.Statuses.Destroy(40, nil) |
|||
} |
|||
|
|||
func TestStatusService_OEmbed(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/oembed.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"id": "691076766878691329", "maxwidth": "400", "hide_media": "true"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
// abbreviated oEmbed response
|
|||
fmt.Fprintf(w, `{"url": "https://twitter.com/dghubble/statuses/691076766878691329", "width": 400, "html": "<blockquote></blockquote>"}`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
params := &StatusOEmbedParams{ |
|||
ID: 691076766878691329, |
|||
MaxWidth: 400, |
|||
HideMedia: Bool(true), |
|||
} |
|||
oembed, _, err := client.Statuses.OEmbed(params) |
|||
expected := &OEmbedTweet{ |
|||
URL: "https://twitter.com/dghubble/statuses/691076766878691329", |
|||
Width: 400, |
|||
HTML: "<blockquote></blockquote>", |
|||
} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, oembed) |
|||
} |
@ -0,0 +1,110 @@ |
|||
package twitter |
|||
|
|||
// StatusDeletion indicates that a given Tweet has been deleted.
|
|||
// https://dev.twitter.com/streaming/overview/messages-types#status_deletion_notices_delete
|
|||
type StatusDeletion struct { |
|||
ID int64 `json:"id"` |
|||
IDStr string `json:"id_str"` |
|||
UserID int64 `json:"user_id"` |
|||
UserIDStr string `json:"user_id_str"` |
|||
} |
|||
|
|||
type statusDeletionNotice struct { |
|||
Delete struct { |
|||
StatusDeletion *StatusDeletion `json:"status"` |
|||
} `json:"delete"` |
|||
} |
|||
|
|||
// LocationDeletion indicates geolocation data must be stripped from a range
|
|||
// of Tweets.
|
|||
// https://dev.twitter.com/streaming/overview/messages-types#Location_deletion_notices_scrub_geo
|
|||
type LocationDeletion struct { |
|||
UserID int64 `json:"user_id"` |
|||
UserIDStr string `json:"user_id_str"` |
|||
UpToStatusID int64 `json:"up_to_status_id"` |
|||
UpToStatusIDStr string `json:"up_to_status_id_str"` |
|||
} |
|||
|
|||
type locationDeletionNotice struct { |
|||
ScrubGeo *LocationDeletion `json:"scrub_geo"` |
|||
} |
|||
|
|||
// StreamLimit indicates a stream matched more statuses than its rate limit
|
|||
// allowed. The track number is the number of undelivered matches.
|
|||
// https://dev.twitter.com/streaming/overview/messages-types#limit_notices
|
|||
type StreamLimit struct { |
|||
Track int64 `json:"track"` |
|||
} |
|||
|
|||
type streamLimitNotice struct { |
|||
Limit *StreamLimit `json:"limit"` |
|||
} |
|||
|
|||
// StatusWithheld indicates a Tweet with the given ID, belonging to UserId,
|
|||
// has been withheld in certain countries.
|
|||
// https://dev.twitter.com/streaming/overview/messages-types#withheld_content_notices
|
|||
type StatusWithheld struct { |
|||
ID int64 `json:"id"` |
|||
UserID int64 `json:"user_id"` |
|||
WithheldInCountries []string `json:"withheld_in_countries"` |
|||
} |
|||
|
|||
type statusWithheldNotice struct { |
|||
StatusWithheld *StatusWithheld `json:"status_withheld"` |
|||
} |
|||
|
|||
// UserWithheld indicates a User with the given ID has been withheld in
|
|||
// certain countries.
|
|||
// https://dev.twitter.com/streaming/overview/messages-types#withheld_content_notices
|
|||
type UserWithheld struct { |
|||
ID int64 `json:"id"` |
|||
WithheldInCountries []string `json:"withheld_in_countries"` |
|||
} |
|||
type userWithheldNotice struct { |
|||
UserWithheld *UserWithheld `json:"user_withheld"` |
|||
} |
|||
|
|||
// StreamDisconnect indicates the stream has been shutdown for some reason.
|
|||
// https://dev.twitter.com/streaming/overview/messages-types#disconnect_messages
|
|||
type StreamDisconnect struct { |
|||
Code int64 `json:"code"` |
|||
StreamName string `json:"stream_name"` |
|||
Reason string `json:"reason"` |
|||
} |
|||
|
|||
type streamDisconnectNotice struct { |
|||
StreamDisconnect *StreamDisconnect `json:"disconnect"` |
|||
} |
|||
|
|||
// StallWarning indicates the client is falling behind in the stream.
|
|||
// https://dev.twitter.com/streaming/overview/messages-types#stall_warnings
|
|||
type StallWarning struct { |
|||
Code string `json:"code"` |
|||
Message string `json:"message"` |
|||
PercentFull int `json:"percent_full"` |
|||
} |
|||
|
|||
type stallWarningNotice struct { |
|||
StallWarning *StallWarning `json:"warning"` |
|||
} |
|||
|
|||
// FriendsList is a list of some of a user's friends.
|
|||
// https://dev.twitter.com/streaming/overview/messages-types#friends_list_friends
|
|||
type FriendsList struct { |
|||
Friends []int64 `json:"friends"` |
|||
} |
|||
|
|||
type directMessageNotice struct { |
|||
DirectMessage *DirectMessage `json:"direct_message"` |
|||
} |
|||
|
|||
// Event is a non-Tweet notification message (e.g. like, retweet, follow).
|
|||
// https://dev.twitter.com/streaming/overview/messages-types#Events_event
|
|||
type Event struct { |
|||
Event string `json:"event"` |
|||
CreatedAt string `json:"created_at"` |
|||
Target *User `json:"target"` |
|||
Source *User `json:"source"` |
|||
// TODO: add List or deprecate it
|
|||
TargetObject *Tweet `json:"target_object"` |
|||
} |
@ -0,0 +1,56 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"strings" |
|||
"time" |
|||
) |
|||
|
|||
// stopped returns true if the done channel receives, false otherwise.
|
|||
func stopped(done <-chan struct{}) bool { |
|||
select { |
|||
case <-done: |
|||
return true |
|||
default: |
|||
return false |
|||
} |
|||
} |
|||
|
|||
// sleepOrDone pauses the current goroutine until the done channel receives
|
|||
// or until at least the duration d has elapsed, whichever comes first. This
|
|||
// is similar to time.Sleep(d), except it can be interrupted.
|
|||
func sleepOrDone(d time.Duration, done <-chan struct{}) { |
|||
select { |
|||
case <-time.After(d): |
|||
return |
|||
case <-done: |
|||
return |
|||
} |
|||
} |
|||
|
|||
// scanLines is a split function for a Scanner that returns each line of text
|
|||
// stripped of the end-of-line marker "\r\n" used by Twitter Streaming APIs.
|
|||
// This differs from the bufio.ScanLines split function which considers the
|
|||
// '\r' optional.
|
|||
// https://dev.twitter.com/streaming/overview/processing
|
|||
func scanLines(data []byte, atEOF bool) (advance int, token []byte, err error) { |
|||
if atEOF && len(data) == 0 { |
|||
return 0, nil, nil |
|||
} |
|||
if i := strings.Index(string(data), "\r\n"); i >= 0 { |
|||
// We have a full '\r\n' terminated line.
|
|||
return i + 2, data[0:i], nil |
|||
} |
|||
// If we're at EOF, we have a final, non-terminated line. Return it.
|
|||
if atEOF { |
|||
return len(data), dropCR(data), nil |
|||
} |
|||
// Request more data.
|
|||
return 0, nil, nil |
|||
} |
|||
|
|||
func dropCR(data []byte) []byte { |
|||
if len(data) > 0 && data[len(data)-1] == '\n' { |
|||
return data[0 : len(data)-1] |
|||
} |
|||
return data |
|||
} |
@ -0,0 +1,64 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestStopped(t *testing.T) { |
|||
done := make(chan struct{}) |
|||
assert.False(t, stopped(done)) |
|||
close(done) |
|||
assert.True(t, stopped(done)) |
|||
} |
|||
|
|||
func TestSleepOrDone_Sleep(t *testing.T) { |
|||
wait := time.Nanosecond * 20 |
|||
done := make(chan struct{}) |
|||
completed := make(chan struct{}) |
|||
go func() { |
|||
sleepOrDone(wait, done) |
|||
close(completed) |
|||
}() |
|||
// wait for goroutine SleepOrDone to sleep
|
|||
assertDone(t, completed, defaultTestTimeout) |
|||
} |
|||
|
|||
func TestSleepOrDone_Done(t *testing.T) { |
|||
wait := time.Second * 5 |
|||
done := make(chan struct{}) |
|||
completed := make(chan struct{}) |
|||
go func() { |
|||
sleepOrDone(wait, done) |
|||
close(completed) |
|||
}() |
|||
// close done, interrupting SleepOrDone
|
|||
close(done) |
|||
// assert that SleepOrDone exited, closing completed
|
|||
assertDone(t, completed, defaultTestTimeout) |
|||
} |
|||
|
|||
func TestScanLines(t *testing.T) { |
|||
cases := []struct { |
|||
input []byte |
|||
atEOF bool |
|||
advance int |
|||
token []byte |
|||
}{ |
|||
{[]byte("Line 1\r\n"), false, 8, []byte("Line 1")}, |
|||
{[]byte("Line 1\n"), false, 0, nil}, |
|||
{[]byte("Line 1"), false, 0, nil}, |
|||
{[]byte(""), false, 0, nil}, |
|||
{[]byte("Line 1\r\n"), true, 8, []byte("Line 1")}, |
|||
{[]byte("Line 1\n"), true, 7, []byte("Line 1")}, |
|||
{[]byte("Line 1"), true, 6, []byte("Line 1")}, |
|||
{[]byte(""), true, 0, nil}, |
|||
} |
|||
for _, c := range cases { |
|||
advance, token, _ := scanLines(c.input, c.atEOF) |
|||
assert.Equal(t, c.advance, advance) |
|||
assert.Equal(t, c.token, token) |
|||
} |
|||
} |
@ -0,0 +1,326 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"bufio" |
|||
"encoding/json" |
|||
"io" |
|||
"net/http" |
|||
"sync" |
|||
"time" |
|||
|
|||
"github.com/cenkalti/backoff" |
|||
"github.com/dghubble/sling" |
|||
) |
|||
|
|||
const ( |
|||
userAgent = "go-twitter v0.1" |
|||
publicStream = "https://stream.twitter.com/1.1/" |
|||
userStream = "https://userstream.twitter.com/1.1/" |
|||
siteStream = "https://sitestream.twitter.com/1.1/" |
|||
) |
|||
|
|||
// StreamService provides methods for accessing the Twitter Streaming API.
|
|||
type StreamService struct { |
|||
client *http.Client |
|||
public *sling.Sling |
|||
user *sling.Sling |
|||
site *sling.Sling |
|||
} |
|||
|
|||
// newStreamService returns a new StreamService.
|
|||
func newStreamService(client *http.Client, sling *sling.Sling) *StreamService { |
|||
sling.Set("User-Agent", userAgent) |
|||
return &StreamService{ |
|||
client: client, |
|||
public: sling.New().Base(publicStream).Path("statuses/"), |
|||
user: sling.New().Base(userStream), |
|||
site: sling.New().Base(siteStream), |
|||
} |
|||
} |
|||
|
|||
// StreamFilterParams are parameters for StreamService.Filter.
|
|||
type StreamFilterParams struct { |
|||
FilterLevel string `url:"filter_level,omitempty"` |
|||
Follow []string `url:"follow,omitempty,comma"` |
|||
Language []string `url:"language,omitempty,comma"` |
|||
Locations []string `url:"locations,omitempty,comma"` |
|||
StallWarnings *bool `url:"stall_warnings,omitempty"` |
|||
Track []string `url:"track,omitempty,comma"` |
|||
} |
|||
|
|||
// Filter returns messages that match one or more filter predicates.
|
|||
// https://dev.twitter.com/streaming/reference/post/statuses/filter
|
|||
func (srv *StreamService) Filter(params *StreamFilterParams) (*Stream, error) { |
|||
req, err := srv.public.New().Post("filter.json").QueryStruct(params).Request() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return newStream(srv.client, req), nil |
|||
} |
|||
|
|||
// StreamSampleParams are the parameters for StreamService.Sample.
|
|||
type StreamSampleParams struct { |
|||
StallWarnings *bool `url:"stall_warnings,omitempty"` |
|||
} |
|||
|
|||
// Sample returns a small sample of public stream messages.
|
|||
// https://dev.twitter.com/streaming/reference/get/statuses/sample
|
|||
func (srv *StreamService) Sample(params *StreamSampleParams) (*Stream, error) { |
|||
req, err := srv.public.New().Get("sample.json").QueryStruct(params).Request() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return newStream(srv.client, req), nil |
|||
} |
|||
|
|||
// StreamUserParams are the parameters for StreamService.User.
|
|||
type StreamUserParams struct { |
|||
FilterLevel string `url:"filter_level,omitempty"` |
|||
Language []string `url:"language,omitempty,comma"` |
|||
Locations []string `url:"locations,omitempty,comma"` |
|||
Replies string `url:"replies,omitempty"` |
|||
StallWarnings *bool `url:"stall_warnings,omitempty"` |
|||
Track []string `url:"track,omitempty,comma"` |
|||
With string `url:"with,omitempty"` |
|||
} |
|||
|
|||
// User returns a stream of messages specific to the authenticated User.
|
|||
// https://dev.twitter.com/streaming/reference/get/user
|
|||
func (srv *StreamService) User(params *StreamUserParams) (*Stream, error) { |
|||
req, err := srv.user.New().Get("user.json").QueryStruct(params).Request() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return newStream(srv.client, req), nil |
|||
} |
|||
|
|||
// StreamSiteParams are the parameters for StreamService.Site.
|
|||
type StreamSiteParams struct { |
|||
FilterLevel string `url:"filter_level,omitempty"` |
|||
Follow []string `url:"follow,omitempty,comma"` |
|||
Language []string `url:"language,omitempty,comma"` |
|||
Replies string `url:"replies,omitempty"` |
|||
StallWarnings *bool `url:"stall_warnings,omitempty"` |
|||
With string `url:"with,omitempty"` |
|||
} |
|||
|
|||
// Site returns messages for a set of users.
|
|||
// Requires special permission to access.
|
|||
// https://dev.twitter.com/streaming/reference/get/site
|
|||
func (srv *StreamService) Site(params *StreamSiteParams) (*Stream, error) { |
|||
req, err := srv.site.New().Get("site.json").QueryStruct(params).Request() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return newStream(srv.client, req), nil |
|||
} |
|||
|
|||
// StreamFirehoseParams are the parameters for StreamService.Firehose.
|
|||
type StreamFirehoseParams struct { |
|||
Count int `url:"count,omitempty"` |
|||
FilterLevel string `url:"filter_level,omitempty"` |
|||
Language []string `url:"language,omitempty,comma"` |
|||
StallWarnings *bool `url:"stall_warnings,omitempty"` |
|||
} |
|||
|
|||
// Firehose returns all public messages and statuses.
|
|||
// Requires special permission to access.
|
|||
// https://dev.twitter.com/streaming/reference/get/statuses/firehose
|
|||
func (srv *StreamService) Firehose(params *StreamFirehoseParams) (*Stream, error) { |
|||
req, err := srv.public.New().Get("firehose.json").QueryStruct(params).Request() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return newStream(srv.client, req), nil |
|||
} |
|||
|
|||
// Stream maintains a connection to the Twitter Streaming API, receives
|
|||
// messages from the streaming response, and sends them on the Messages
|
|||
// channel from a goroutine. The stream goroutine stops itself if an EOF is
|
|||
// reached or retry errors occur, also closing the Messages channel.
|
|||
//
|
|||
// The client must Stop() the stream when finished receiving, which will
|
|||
// wait until the stream is properly stopped.
|
|||
type Stream struct { |
|||
client *http.Client |
|||
Messages chan interface{} |
|||
done chan struct{} |
|||
group *sync.WaitGroup |
|||
body io.Closer |
|||
} |
|||
|
|||
// newStream creates a Stream and starts a goroutine to retry connecting and
|
|||
// receive from a stream response. The goroutine may stop due to retry errors
|
|||
// or be stopped by calling Stop() on the stream.
|
|||
func newStream(client *http.Client, req *http.Request) *Stream { |
|||
s := &Stream{ |
|||
client: client, |
|||
Messages: make(chan interface{}), |
|||
done: make(chan struct{}), |
|||
group: &sync.WaitGroup{}, |
|||
} |
|||
s.group.Add(1) |
|||
go s.retry(req, newExponentialBackOff(), newAggressiveExponentialBackOff()) |
|||
return s |
|||
} |
|||
|
|||
// Stop signals retry and receiver to stop, closes the Messages channel, and
|
|||
// blocks until done.
|
|||
func (s *Stream) Stop() { |
|||
close(s.done) |
|||
// Scanner does not have a Stop() or take a done channel, so for low volume
|
|||
// streams Scan() blocks until the next keep-alive. Close the resp.Body to
|
|||
// escape and stop the stream in a timely fashion.
|
|||
if s.body != nil { |
|||
s.body.Close() |
|||
} |
|||
// block until the retry goroutine stops
|
|||
s.group.Wait() |
|||
} |
|||
|
|||
// retry retries making the given http.Request and receiving the response
|
|||
// according to the Twitter backoff policies. Callers should invoke in a
|
|||
// goroutine since backoffs sleep between retries.
|
|||
// https://dev.twitter.com/streaming/overview/connecting
|
|||
func (s *Stream) retry(req *http.Request, expBackOff backoff.BackOff, aggExpBackOff backoff.BackOff) { |
|||
// close Messages channel and decrement the wait group counter
|
|||
defer close(s.Messages) |
|||
defer s.group.Done() |
|||
|
|||
var wait time.Duration |
|||
for !stopped(s.done) { |
|||
resp, err := s.client.Do(req) |
|||
if err != nil { |
|||
// stop retrying for HTTP protocol errors
|
|||
s.Messages <- err |
|||
return |
|||
} |
|||
// when err is nil, resp contains a non-nil Body which must be closed
|
|||
defer resp.Body.Close() |
|||
s.body = resp.Body |
|||
switch resp.StatusCode { |
|||
case 200: |
|||
// receive stream response Body, handles closing
|
|||
s.receive(resp.Body) |
|||
expBackOff.Reset() |
|||
aggExpBackOff.Reset() |
|||
case 503: |
|||
// exponential backoff
|
|||
wait = expBackOff.NextBackOff() |
|||
case 420, 429: |
|||
// aggressive exponential backoff
|
|||
wait = aggExpBackOff.NextBackOff() |
|||
default: |
|||
// stop retrying for other response codes
|
|||
resp.Body.Close() |
|||
return |
|||
} |
|||
// close response before each retry
|
|||
resp.Body.Close() |
|||
if wait == backoff.Stop { |
|||
return |
|||
} |
|||
sleepOrDone(wait, s.done) |
|||
} |
|||
} |
|||
|
|||
// receive scans a stream response body, JSON decodes tokens to messages, and
|
|||
// sends messages to the Messages channel. Receiving continues until an EOF,
|
|||
// scan error, or the done channel is closed.
|
|||
func (s *Stream) receive(body io.ReadCloser) { |
|||
defer body.Close() |
|||
// A bufio.Scanner steps through 'tokens' of data on each Scan() using a
|
|||
// SplitFunc. SplitFunc tokenizes input bytes to return the number of bytes
|
|||
// to advance, the token slice of bytes, and any errors.
|
|||
scanner := bufio.NewScanner(body) |
|||
// default ScanLines SplitFunc is incorrect for Twitter Streams, set custom
|
|||
scanner.Split(scanLines) |
|||
for !stopped(s.done) && scanner.Scan() { |
|||
token := scanner.Bytes() |
|||
if len(token) == 0 { |
|||
// empty keep-alive
|
|||
continue |
|||
} |
|||
select { |
|||
// send messages, data, or errors
|
|||
case s.Messages <- getMessage(token): |
|||
continue |
|||
// allow client to Stop(), even if not receiving
|
|||
case <-s.done: |
|||
return |
|||
} |
|||
} |
|||
} |
|||
|
|||
// getMessage unmarshals the token and returns a message struct, if the type
|
|||
// can be determined. Otherwise, returns the token unmarshalled into a data
|
|||
// map[string]interface{} or the unmarshal error.
|
|||
func getMessage(token []byte) interface{} { |
|||
var data map[string]interface{} |
|||
// unmarshal JSON encoded token into a map for
|
|||
err := json.Unmarshal(token, &data) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
return decodeMessage(token, data) |
|||
} |
|||
|
|||
// decodeMessage determines the message type from known data keys, allocates
|
|||
// at most one message struct, and JSON decodes the token into the message.
|
|||
// Returns the message struct or the data map if the message type could not be
|
|||
// determined.
|
|||
func decodeMessage(token []byte, data map[string]interface{}) interface{} { |
|||
if hasPath(data, "retweet_count") { |
|||
tweet := new(Tweet) |
|||
json.Unmarshal(token, tweet) |
|||
return tweet |
|||
} else if hasPath(data, "direct_message") { |
|||
notice := new(directMessageNotice) |
|||
json.Unmarshal(token, notice) |
|||
return notice.DirectMessage |
|||
} else if hasPath(data, "delete") { |
|||
notice := new(statusDeletionNotice) |
|||
json.Unmarshal(token, notice) |
|||
return notice.Delete.StatusDeletion |
|||
} else if hasPath(data, "scrub_geo") { |
|||
notice := new(locationDeletionNotice) |
|||
json.Unmarshal(token, notice) |
|||
return notice.ScrubGeo |
|||
} else if hasPath(data, "limit") { |
|||
notice := new(streamLimitNotice) |
|||
json.Unmarshal(token, notice) |
|||
return notice.Limit |
|||
} else if hasPath(data, "status_withheld") { |
|||
notice := new(statusWithheldNotice) |
|||
json.Unmarshal(token, notice) |
|||
return notice.StatusWithheld |
|||
} else if hasPath(data, "user_withheld") { |
|||
notice := new(userWithheldNotice) |
|||
json.Unmarshal(token, notice) |
|||
return notice.UserWithheld |
|||
} else if hasPath(data, "disconnect") { |
|||
notice := new(streamDisconnectNotice) |
|||
json.Unmarshal(token, notice) |
|||
return notice.StreamDisconnect |
|||
} else if hasPath(data, "warning") { |
|||
notice := new(stallWarningNotice) |
|||
json.Unmarshal(token, notice) |
|||
return notice.StallWarning |
|||
} else if hasPath(data, "friends") { |
|||
friendsList := new(FriendsList) |
|||
json.Unmarshal(token, friendsList) |
|||
return friendsList |
|||
} else if hasPath(data, "event") { |
|||
event := new(Event) |
|||
json.Unmarshal(token, event) |
|||
return event |
|||
} |
|||
// message type unknown, return the data map[string]interface{}
|
|||
return data |
|||
} |
|||
|
|||
// hasPath returns true if the map contains the given key, false otherwise.
|
|||
func hasPath(data map[string]interface{}, key string) bool { |
|||
_, ok := data[key] |
|||
return ok |
|||
} |
@ -0,0 +1,352 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
"sync" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestStream_MessageJSONError(t *testing.T) { |
|||
badJSON := []byte(`{`) |
|||
msg := getMessage(badJSON) |
|||
assert.EqualError(t, msg.(error), "unexpected end of JSON input") |
|||
} |
|||
|
|||
func TestStream_GetMessageTweet(t *testing.T) { |
|||
msgJSON := []byte(`{"id": 20, "text": "just setting up my twttr", "retweet_count": "68535"}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &Tweet{}, msg) |
|||
} |
|||
|
|||
func TestStream_GetMessageDirectMessage(t *testing.T) { |
|||
msgJSON := []byte(`{"direct_message": {"id": 666024290140217347}}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &DirectMessage{}, msg) |
|||
} |
|||
|
|||
func TestStream_GetMessageDelete(t *testing.T) { |
|||
msgJSON := []byte(`{"delete": { "id": 20}}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &StatusDeletion{}, msg) |
|||
} |
|||
|
|||
func TestStream_GetMessageLocationDeletion(t *testing.T) { |
|||
msgJSON := []byte(`{"scrub_geo": { "up_to_status_id": 20}}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &LocationDeletion{}, msg) |
|||
} |
|||
|
|||
func TestStream_GetMessageStreamLimit(t *testing.T) { |
|||
msgJSON := []byte(`{"limit": { "track": 10 }}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &StreamLimit{}, msg) |
|||
} |
|||
|
|||
func TestStream_StatusWithheld(t *testing.T) { |
|||
msgJSON := []byte(`{"status_withheld": { "id": 20, "user_id": 12, "withheld_in_countries":["USA", "China"] }}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &StatusWithheld{}, msg) |
|||
} |
|||
|
|||
func TestStream_UserWithheld(t *testing.T) { |
|||
msgJSON := []byte(`{"user_withheld": { "id": 12, "withheld_in_countries":["USA", "China"] }}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &UserWithheld{}, msg) |
|||
} |
|||
|
|||
func TestStream_StreamDisconnect(t *testing.T) { |
|||
msgJSON := []byte(`{"disconnect": { "code": "420", "stream_name": "streaming stuff", "reason": "too many connections" }}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &StreamDisconnect{}, msg) |
|||
} |
|||
|
|||
func TestStream_StallWarning(t *testing.T) { |
|||
msgJSON := []byte(`{"warning": { "code": "420", "percent_full": 90, "message": "a lot of messages" }}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &StallWarning{}, msg) |
|||
} |
|||
|
|||
func TestStream_FriendsList(t *testing.T) { |
|||
msgJSON := []byte(`{"friends": [666024290140217347, 666024290140217349, 666024290140217342]}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &FriendsList{}, msg) |
|||
} |
|||
|
|||
func TestStream_Event(t *testing.T) { |
|||
msgJSON := []byte(`{"event": "block", "target": {"name": "XKCD Comic", "favourites_count": 2}, "source": {"name": "XKCD Comic2", "favourites_count": 3}, "created_at": "Sat Sep 4 16:10:54 +0000 2010"}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, &Event{}, msg) |
|||
} |
|||
|
|||
func TestStream_Unknown(t *testing.T) { |
|||
msgJSON := []byte(`{"unknown_data": {"new_twitter_type":"unexpected"}}`) |
|||
msg := getMessage(msgJSON) |
|||
assert.IsType(t, map[string]interface{}{}, msg) |
|||
} |
|||
|
|||
func TestStream_Filter(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
reqCount := 0 |
|||
mux.HandleFunc("/1.1/statuses/filter.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "POST", r) |
|||
assertQuery(t, map[string]string{"track": "gophercon,golang"}, r) |
|||
switch reqCount { |
|||
case 0: |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Header().Set("Transfer-Encoding", "chunked") |
|||
fmt.Fprintf(w, |
|||
`{"text": "Gophercon talks!"}`+"\r\n"+ |
|||
`{"text": "Gophercon super talks!"}`+"\r\n", |
|||
) |
|||
default: |
|||
// Only allow first request
|
|||
http.Error(w, "Stream API not available!", 130) |
|||
} |
|||
reqCount++ |
|||
}) |
|||
|
|||
counts := &counter{} |
|||
demux := newCounterDemux(counts) |
|||
client := NewClient(httpClient) |
|||
streamFilterParams := &StreamFilterParams{ |
|||
Track: []string{"gophercon", "golang"}, |
|||
} |
|||
stream, err := client.Streams.Filter(streamFilterParams) |
|||
// assert that the expected messages are received
|
|||
assert.NoError(t, err) |
|||
defer stream.Stop() |
|||
for message := range stream.Messages { |
|||
demux.Handle(message) |
|||
} |
|||
expectedCounts := &counter{all: 2, other: 2} |
|||
assert.Equal(t, expectedCounts, counts) |
|||
} |
|||
|
|||
func TestStream_Sample(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
reqCount := 0 |
|||
mux.HandleFunc("/1.1/statuses/sample.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"stall_warnings": "true"}, r) |
|||
switch reqCount { |
|||
case 0: |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Header().Set("Transfer-Encoding", "chunked") |
|||
fmt.Fprintf(w, |
|||
`{"text": "Gophercon talks!"}`+"\r\n"+ |
|||
`{"text": "Gophercon super talks!"}`+"\r\n", |
|||
) |
|||
default: |
|||
// Only allow first request
|
|||
http.Error(w, "Stream API not available!", 130) |
|||
} |
|||
reqCount++ |
|||
}) |
|||
|
|||
counts := &counter{} |
|||
demux := newCounterDemux(counts) |
|||
client := NewClient(httpClient) |
|||
streamSampleParams := &StreamSampleParams{ |
|||
StallWarnings: Bool(true), |
|||
} |
|||
stream, err := client.Streams.Sample(streamSampleParams) |
|||
// assert that the expected messages are received
|
|||
assert.NoError(t, err) |
|||
defer stream.Stop() |
|||
for message := range stream.Messages { |
|||
demux.Handle(message) |
|||
} |
|||
expectedCounts := &counter{all: 2, other: 2} |
|||
assert.Equal(t, expectedCounts, counts) |
|||
} |
|||
|
|||
func TestStream_User(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
reqCount := 0 |
|||
mux.HandleFunc("/1.1/user.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"stall_warnings": "true", "with": "followings"}, r) |
|||
switch reqCount { |
|||
case 0: |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Header().Set("Transfer-Encoding", "chunked") |
|||
fmt.Fprintf(w, `{"friends": [666024290140217347, 666024290140217349, 666024290140217342]}`+"\r\n"+"\r\n") |
|||
default: |
|||
// Only allow first request
|
|||
http.Error(w, "Stream API not available!", 130) |
|||
} |
|||
reqCount++ |
|||
}) |
|||
|
|||
counts := &counter{} |
|||
demux := newCounterDemux(counts) |
|||
client := NewClient(httpClient) |
|||
streamUserParams := &StreamUserParams{ |
|||
StallWarnings: Bool(true), |
|||
With: "followings", |
|||
} |
|||
stream, err := client.Streams.User(streamUserParams) |
|||
// assert that the expected messages are received
|
|||
assert.NoError(t, err) |
|||
defer stream.Stop() |
|||
for message := range stream.Messages { |
|||
demux.Handle(message) |
|||
} |
|||
expectedCounts := &counter{all: 1, friendsList: 1} |
|||
assert.Equal(t, expectedCounts, counts) |
|||
} |
|||
|
|||
func TestStream_Site(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
reqCount := 0 |
|||
mux.HandleFunc("/1.1/site.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"follow": "666024290140217347,666024290140217349"}, r) |
|||
switch reqCount { |
|||
case 0: |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Header().Set("Transfer-Encoding", "chunked") |
|||
fmt.Fprintf(w, |
|||
`{"text": "Gophercon talks!"}`+"\r\n"+ |
|||
`{"text": "Gophercon super talks!"}`+"\r\n", |
|||
) |
|||
default: |
|||
// Only allow first request
|
|||
http.Error(w, "Stream API not available!", 130) |
|||
} |
|||
reqCount++ |
|||
}) |
|||
|
|||
counts := &counter{} |
|||
demux := newCounterDemux(counts) |
|||
client := NewClient(httpClient) |
|||
streamSiteParams := &StreamSiteParams{ |
|||
Follow: []string{"666024290140217347", "666024290140217349"}, |
|||
} |
|||
stream, err := client.Streams.Site(streamSiteParams) |
|||
// assert that the expected messages are received
|
|||
assert.NoError(t, err) |
|||
defer stream.Stop() |
|||
for message := range stream.Messages { |
|||
demux.Handle(message) |
|||
} |
|||
expectedCounts := &counter{all: 2, other: 2} |
|||
assert.Equal(t, expectedCounts, counts) |
|||
} |
|||
|
|||
func TestStream_PublicFirehose(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
reqCount := 0 |
|||
mux.HandleFunc("/1.1/statuses/firehose.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"count": "100"}, r) |
|||
switch reqCount { |
|||
case 0: |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.Header().Set("Transfer-Encoding", "chunked") |
|||
fmt.Fprintf(w, |
|||
`{"text": "Gophercon talks!"}`+"\r\n"+ |
|||
`{"text": "Gophercon super talks!"}`+"\r\n", |
|||
) |
|||
default: |
|||
// Only allow first request
|
|||
http.Error(w, "Stream API not available!", 130) |
|||
} |
|||
reqCount++ |
|||
}) |
|||
|
|||
counts := &counter{} |
|||
demux := newCounterDemux(counts) |
|||
client := NewClient(httpClient) |
|||
streamFirehoseParams := &StreamFirehoseParams{ |
|||
Count: 100, |
|||
} |
|||
stream, err := client.Streams.Firehose(streamFirehoseParams) |
|||
// assert that the expected messages are received
|
|||
assert.NoError(t, err) |
|||
defer stream.Stop() |
|||
for message := range stream.Messages { |
|||
demux.Handle(message) |
|||
} |
|||
expectedCounts := &counter{all: 2, other: 2} |
|||
assert.Equal(t, expectedCounts, counts) |
|||
} |
|||
|
|||
func TestStreamRetry_ExponentialBackoff(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
reqCount := 0 |
|||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { |
|||
switch reqCount { |
|||
case 0: |
|||
http.Error(w, "Service Unavailable", 503) |
|||
default: |
|||
// Only allow first request
|
|||
http.Error(w, "Stream API not available!", 130) |
|||
} |
|||
reqCount++ |
|||
}) |
|||
stream := &Stream{ |
|||
client: httpClient, |
|||
Messages: make(chan interface{}), |
|||
done: make(chan struct{}), |
|||
group: &sync.WaitGroup{}, |
|||
} |
|||
stream.group.Add(1) |
|||
req, _ := http.NewRequest("GET", "http://example.com/", nil) |
|||
expBackoff := &BackOffRecorder{} |
|||
// receive messages and throw them away
|
|||
go NewSwitchDemux().HandleChan(stream.Messages) |
|||
stream.retry(req, expBackoff, nil) |
|||
defer stream.Stop() |
|||
// assert exponential backoff in response to 503
|
|||
assert.Equal(t, 1, expBackoff.Count) |
|||
} |
|||
|
|||
func TestStreamRetry_AggressiveBackoff(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
reqCount := 0 |
|||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { |
|||
switch reqCount { |
|||
case 0: |
|||
http.Error(w, "Enhance Your Calm", 420) |
|||
case 1: |
|||
http.Error(w, "Too Many Requests", 429) |
|||
default: |
|||
// Only allow first request
|
|||
http.Error(w, "Stream API not available!", 130) |
|||
} |
|||
reqCount++ |
|||
}) |
|||
stream := &Stream{ |
|||
client: httpClient, |
|||
Messages: make(chan interface{}), |
|||
done: make(chan struct{}), |
|||
group: &sync.WaitGroup{}, |
|||
} |
|||
stream.group.Add(1) |
|||
req, _ := http.NewRequest("GET", "http://example.com/", nil) |
|||
aggExpBackoff := &BackOffRecorder{} |
|||
// receive messages and throw them away
|
|||
go NewSwitchDemux().HandleChan(stream.Messages) |
|||
stream.retry(req, nil, aggExpBackoff) |
|||
defer stream.Stop() |
|||
// assert aggressive exponential backoff in response to 420 and 429
|
|||
assert.Equal(t, 2, aggExpBackoff.Count) |
|||
} |
@ -0,0 +1,106 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"net/http" |
|||
|
|||
"github.com/dghubble/sling" |
|||
) |
|||
|
|||
// TimelineService provides methods for accessing Twitter status timeline
|
|||
// API endpoints.
|
|||
type TimelineService struct { |
|||
sling *sling.Sling |
|||
} |
|||
|
|||
// newTimelineService returns a new TimelineService.
|
|||
func newTimelineService(sling *sling.Sling) *TimelineService { |
|||
return &TimelineService{ |
|||
sling: sling.Path("statuses/"), |
|||
} |
|||
} |
|||
|
|||
// UserTimelineParams are the parameters for TimelineService.UserTimeline.
|
|||
type UserTimelineParams struct { |
|||
UserID int64 `url:"user_id,omitempty"` |
|||
ScreenName string `url:"screen_name,omitempty"` |
|||
Count int `url:"count,omitempty"` |
|||
SinceID int64 `url:"since_id,omitempty"` |
|||
MaxID int64 `url:"max_id,omitempty"` |
|||
TrimUser *bool `url:"trim_user,omitempty"` |
|||
ExcludeReplies *bool `url:"exclude_replies,omitempty"` |
|||
ContributorDetails *bool `url:"contributor_details,omitempty"` |
|||
IncludeRetweets *bool `url:"include_rts,omitempty"` |
|||
} |
|||
|
|||
// UserTimeline returns recent Tweets from the specified user.
|
|||
// https://dev.twitter.com/rest/reference/get/statuses/user_timeline
|
|||
func (s *TimelineService) UserTimeline(params *UserTimelineParams) ([]Tweet, *http.Response, error) { |
|||
tweets := new([]Tweet) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("user_timeline.json").QueryStruct(params).Receive(tweets, apiError) |
|||
return *tweets, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// HomeTimelineParams are the parameters for TimelineService.HomeTimeline.
|
|||
type HomeTimelineParams struct { |
|||
Count int `url:"count,omitempty"` |
|||
SinceID int64 `url:"since_id,omitempty"` |
|||
MaxID int64 `url:"max_id,omitempty"` |
|||
TrimUser *bool `url:"trim_user,omitempty"` |
|||
ExcludeReplies *bool `url:"exclude_replies,omitempty"` |
|||
ContributorDetails *bool `url:"contributor_details,omitempty"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` |
|||
} |
|||
|
|||
// HomeTimeline returns recent Tweets and retweets from the user and those
|
|||
// users they follow.
|
|||
// Requires a user auth context.
|
|||
// https://dev.twitter.com/rest/reference/get/statuses/home_timeline
|
|||
func (s *TimelineService) HomeTimeline(params *HomeTimelineParams) ([]Tweet, *http.Response, error) { |
|||
tweets := new([]Tweet) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("home_timeline.json").QueryStruct(params).Receive(tweets, apiError) |
|||
return *tweets, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// MentionTimelineParams are the parameters for TimelineService.MentionTimeline.
|
|||
type MentionTimelineParams struct { |
|||
Count int `url:"count,omitempty"` |
|||
SinceID int64 `url:"since_id,omitempty"` |
|||
MaxID int64 `url:"max_id,omitempty"` |
|||
TrimUser *bool `url:"trim_user,omitempty"` |
|||
ContributorDetails *bool `url:"contributor_details,omitempty"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` |
|||
} |
|||
|
|||
// MentionTimeline returns recent Tweet mentions of the authenticated user.
|
|||
// Requires a user auth context.
|
|||
// https://dev.twitter.com/rest/reference/get/statuses/mentions_timeline
|
|||
func (s *TimelineService) MentionTimeline(params *MentionTimelineParams) ([]Tweet, *http.Response, error) { |
|||
tweets := new([]Tweet) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("mentions_timeline.json").QueryStruct(params).Receive(tweets, apiError) |
|||
return *tweets, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// RetweetsOfMeTimelineParams are the parameters for
|
|||
// TimelineService.RetweetsOfMeTimeline.
|
|||
type RetweetsOfMeTimelineParams struct { |
|||
Count int `url:"count,omitempty"` |
|||
SinceID int64 `url:"since_id,omitempty"` |
|||
MaxID int64 `url:"max_id,omitempty"` |
|||
TrimUser *bool `url:"trim_user,omitempty"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` |
|||
IncludeUserEntities *bool `url:"include_user_entities"` |
|||
} |
|||
|
|||
// RetweetsOfMeTimeline returns the most recent Tweets by the authenticated
|
|||
// user that have been retweeted by others.
|
|||
// Requires a user auth context.
|
|||
// https://dev.twitter.com/rest/reference/get/statuses/retweets_of_me
|
|||
func (s *TimelineService) RetweetsOfMeTimeline(params *RetweetsOfMeTimelineParams) ([]Tweet, *http.Response, error) { |
|||
tweets := new([]Tweet) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("retweets_of_me.json").QueryStruct(params).Receive(tweets, apiError) |
|||
return *tweets, resp, relevantError(err, *apiError) |
|||
} |
@ -0,0 +1,81 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestTimelineService_UserTimeline(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/user_timeline.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"user_id": "113419064", "trim_user": "true", "include_rts": "false"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `[{"text": "Gophercon talks!"}, {"text": "Why gophers are so adorable"}]`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
tweets, _, err := client.Timelines.UserTimeline(&UserTimelineParams{UserID: 113419064, TrimUser: Bool(true), IncludeRetweets: Bool(false)}) |
|||
expected := []Tweet{Tweet{Text: "Gophercon talks!"}, Tweet{Text: "Why gophers are so adorable"}} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, tweets) |
|||
} |
|||
|
|||
func TestTimelineService_HomeTimeline(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/home_timeline.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"since_id": "589147592367431680", "exclude_replies": "false"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `[{"text": "Live on #Periscope"}, {"text": "Clickbait journalism"}, {"text": "Useful announcement"}]`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
tweets, _, err := client.Timelines.HomeTimeline(&HomeTimelineParams{SinceID: 589147592367431680, ExcludeReplies: Bool(false)}) |
|||
expected := []Tweet{Tweet{Text: "Live on #Periscope"}, Tweet{Text: "Clickbait journalism"}, Tweet{Text: "Useful announcement"}} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, tweets) |
|||
} |
|||
|
|||
func TestTimelineService_MentionTimeline(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/mentions_timeline.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"count": "20", "include_entities": "false"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `[{"text": "@dghubble can I get verified?"}, {"text": "@dghubble why are gophers so great?"}]`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
tweets, _, err := client.Timelines.MentionTimeline(&MentionTimelineParams{Count: 20, IncludeEntities: Bool(false)}) |
|||
expected := []Tweet{Tweet{Text: "@dghubble can I get verified?"}, Tweet{Text: "@dghubble why are gophers so great?"}} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, tweets) |
|||
} |
|||
|
|||
func TestTimelineService_RetweetsOfMeTimeline(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/statuses/retweets_of_me.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"trim_user": "false", "include_user_entities": "false"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `[{"text": "RT Twitter UK edition"}, {"text": "RT Triply-replicated Gophers"}]`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
tweets, _, err := client.Timelines.RetweetsOfMeTimeline(&RetweetsOfMeTimelineParams{TrimUser: Bool(false), IncludeUserEntities: Bool(false)}) |
|||
expected := []Tweet{Tweet{Text: "RT Twitter UK edition"}, Tweet{Text: "RT Triply-replicated Gophers"}} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, tweets) |
|||
} |
@ -0,0 +1,51 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"net/http" |
|||
|
|||
"github.com/dghubble/sling" |
|||
) |
|||
|
|||
const twitterAPI = "https://api.twitter.com/1.1/" |
|||
|
|||
// Client is a Twitter client for making Twitter API requests.
|
|||
type Client struct { |
|||
sling *sling.Sling |
|||
// Twitter API Services
|
|||
Accounts *AccountService |
|||
Statuses *StatusService |
|||
Timelines *TimelineService |
|||
Users *UserService |
|||
Followers *FollowerService |
|||
DirectMessages *DirectMessageService |
|||
Streams *StreamService |
|||
} |
|||
|
|||
// NewClient returns a new Client.
|
|||
func NewClient(httpClient *http.Client) *Client { |
|||
base := sling.New().Client(httpClient).Base(twitterAPI) |
|||
return &Client{ |
|||
sling: base, |
|||
Accounts: newAccountService(base.New()), |
|||
Statuses: newStatusService(base.New()), |
|||
Timelines: newTimelineService(base.New()), |
|||
Users: newUserService(base.New()), |
|||
Followers: newFollowerService(base.New()), |
|||
DirectMessages: newDirectMessageService(base.New()), |
|||
Streams: newStreamService(httpClient, base.New()), |
|||
} |
|||
} |
|||
|
|||
// Bool returns a new pointer to the given bool value.
|
|||
func Bool(v bool) *bool { |
|||
ptr := new(bool) |
|||
*ptr = v |
|||
return ptr |
|||
} |
|||
|
|||
// Float returns a new pointer to the given float64 value.
|
|||
func Float(v float64) *float64 { |
|||
ptr := new(float64) |
|||
*ptr = v |
|||
return ptr |
|||
} |
@ -0,0 +1,93 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"net/url" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
var defaultTestTimeout = time.Second * 1 |
|||
|
|||
// testServer returns an http Client, ServeMux, and Server. The client proxies
|
|||
// requests to the server and handlers can be registered on the mux to handle
|
|||
// requests. The caller must close the test server.
|
|||
func testServer() (*http.Client, *http.ServeMux, *httptest.Server) { |
|||
mux := http.NewServeMux() |
|||
server := httptest.NewServer(mux) |
|||
transport := &RewriteTransport{&http.Transport{ |
|||
Proxy: func(req *http.Request) (*url.URL, error) { |
|||
return url.Parse(server.URL) |
|||
}, |
|||
}} |
|||
client := &http.Client{Transport: transport} |
|||
return client, mux, server |
|||
} |
|||
|
|||
// RewriteTransport rewrites https requests to http to avoid TLS cert issues
|
|||
// during testing.
|
|||
type RewriteTransport struct { |
|||
Transport http.RoundTripper |
|||
} |
|||
|
|||
// RoundTrip rewrites the request scheme to http and calls through to the
|
|||
// composed RoundTripper or if it is nil, to the http.DefaultTransport.
|
|||
func (t *RewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { |
|||
req.URL.Scheme = "http" |
|||
if t.Transport == nil { |
|||
return http.DefaultTransport.RoundTrip(req) |
|||
} |
|||
return t.Transport.RoundTrip(req) |
|||
} |
|||
|
|||
func assertMethod(t *testing.T, expectedMethod string, req *http.Request) { |
|||
assert.Equal(t, expectedMethod, req.Method) |
|||
} |
|||
|
|||
// assertQuery tests that the Request has the expected url query key/val pairs
|
|||
func assertQuery(t *testing.T, expected map[string]string, req *http.Request) { |
|||
queryValues := req.URL.Query() |
|||
expectedValues := url.Values{} |
|||
for key, value := range expected { |
|||
expectedValues.Add(key, value) |
|||
} |
|||
assert.Equal(t, expectedValues, queryValues) |
|||
} |
|||
|
|||
// assertPostForm tests that the Request has the expected key values pairs url
|
|||
// encoded in its Body
|
|||
func assertPostForm(t *testing.T, expected map[string]string, req *http.Request) { |
|||
req.ParseForm() // parses request Body to put url.Values in r.Form/r.PostForm
|
|||
expectedValues := url.Values{} |
|||
for key, value := range expected { |
|||
expectedValues.Add(key, value) |
|||
} |
|||
assert.Equal(t, expectedValues, req.Form) |
|||
} |
|||
|
|||
// assertDone asserts that the empty struct channel is closed before the given
|
|||
// timeout elapses.
|
|||
func assertDone(t *testing.T, ch <-chan struct{}, timeout time.Duration) { |
|||
select { |
|||
case <-ch: |
|||
_, more := <-ch |
|||
assert.False(t, more) |
|||
case <-time.After(timeout): |
|||
t.Errorf("expected channel to be closed within timeout %v", timeout) |
|||
} |
|||
} |
|||
|
|||
// assertClosed asserts that the channel is closed before the given timeout
|
|||
// elapses.
|
|||
func assertClosed(t *testing.T, ch <-chan interface{}, timeout time.Duration) { |
|||
select { |
|||
case <-ch: |
|||
_, more := <-ch |
|||
assert.False(t, more) |
|||
case <-time.After(timeout): |
|||
t.Errorf("expected channel to be closed within timeout %v", timeout) |
|||
} |
|||
} |
@ -0,0 +1,122 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"net/http" |
|||
|
|||
"github.com/dghubble/sling" |
|||
) |
|||
|
|||
// User represents a Twitter User.
|
|||
// https://dev.twitter.com/overview/api/users
|
|||
type User struct { |
|||
ContributorsEnabled bool `json:"contributors_enabled"` |
|||
CreatedAt string `json:"created_at"` |
|||
DefaultProfile bool `json:"default_profile"` |
|||
DefaultProfileImage bool `json:"default_profile_image"` |
|||
Description string `json:"description"` |
|||
Email string `json:"email"` |
|||
Entities *UserEntities `json:"entities"` |
|||
FavouritesCount int `json:"favourites_count"` |
|||
FollowRequestSent bool `json:"follow_request_sent"` |
|||
Following bool `json:"following"` |
|||
FollowersCount int `json:"followers_count"` |
|||
FriendsCount int `json:"friends_count"` |
|||
GeoEnabled bool `json:"geo_enabled"` |
|||
ID int64 `json:"id"` |
|||
IDStr string `json:"id_str"` |
|||
IsTranslator bool `json:"id_translator"` |
|||
Lang string `json:"lang"` |
|||
ListedCount int `json:"listed_count"` |
|||
Location string `json:"location"` |
|||
Name string `json:"name"` |
|||
Notifications bool `json:"notifications"` |
|||
ProfileBackgroundColor string `json:"profile_background_color"` |
|||
ProfileBackgroundImageURL string `json:"profile_background_image_url"` |
|||
ProfileBackgroundImageURLHttps string `json:"profile_background_image_url_https"` |
|||
ProfileBackgroundTile bool `json:"profile_background_tile"` |
|||
ProfileBannerURL string `json:"profile_banner_url"` |
|||
ProfileImageURL string `json:"profile_image_url"` |
|||
ProfileImageURLHttps string `json:"profile_image_url_https"` |
|||
ProfileLinkColor string `json:"profile_link_color"` |
|||
ProfileSidebarBorderColor string `json:"profile_sidebar_border_color"` |
|||
ProfileSidebarFillColor string `json:"profile_sidebar_fill_color"` |
|||
ProfileTextColor string `json:"profile_text_color"` |
|||
ProfileUseBackgroundImage bool `json:"profile_use_background_image"` |
|||
Protected bool `json:"protected"` |
|||
ScreenName string `json:"screen_name"` |
|||
ShowAllInlineMedia bool `json:"show_all_inline_media"` |
|||
Status *Tweet `json:"status"` |
|||
StatusesCount int `json:"statuses_count"` |
|||
Timezone string `json:"time_zone"` |
|||
URL string `json:"url"` |
|||
UtcOffset int `json:"utc_offset"` |
|||
Verified bool `json:"verified"` |
|||
WithheldInCountries string `json:"withheld_in_countries"` |
|||
WithholdScope string `json:"withheld_scope"` |
|||
} |
|||
|
|||
// UserService provides methods for accessing Twitter user API endpoints.
|
|||
type UserService struct { |
|||
sling *sling.Sling |
|||
} |
|||
|
|||
// newUserService returns a new UserService.
|
|||
func newUserService(sling *sling.Sling) *UserService { |
|||
return &UserService{ |
|||
sling: sling.Path("users/"), |
|||
} |
|||
} |
|||
|
|||
// UserShowParams are the parameters for UserService.Show.
|
|||
type UserShowParams struct { |
|||
UserID int64 `url:"user_id,omitempty"` |
|||
ScreenName string `url:"screen_name,omitempty"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` // whether 'status' should include entities
|
|||
} |
|||
|
|||
// Show returns the requested User.
|
|||
// https://dev.twitter.com/rest/reference/get/users/show
|
|||
func (s *UserService) Show(params *UserShowParams) (*User, *http.Response, error) { |
|||
user := new(User) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("show.json").QueryStruct(params).Receive(user, apiError) |
|||
return user, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// UserLookupParams are the parameters for UserService.Lookup.
|
|||
type UserLookupParams struct { |
|||
UserID []int64 `url:"user_id,omitempty,comma"` |
|||
ScreenName []string `url:"screen_name,omitempty,comma"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` // whether 'status' should include entities
|
|||
} |
|||
|
|||
// Lookup returns the requested Users as a slice.
|
|||
// https://dev.twitter.com/rest/reference/get/users/lookup
|
|||
func (s *UserService) Lookup(params *UserLookupParams) ([]User, *http.Response, error) { |
|||
users := new([]User) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("lookup.json").QueryStruct(params).Receive(users, apiError) |
|||
return *users, resp, relevantError(err, *apiError) |
|||
} |
|||
|
|||
// UserSearchParams are the parameters for UserService.Search.
|
|||
type UserSearchParams struct { |
|||
Query string `url:"q,omitempty"` |
|||
Page int `url:"page,omitempty"` // 1-based page number
|
|||
Count int `url:"count,omitempty"` |
|||
IncludeEntities *bool `url:"include_entities,omitempty"` // whether 'status' should include entities
|
|||
} |
|||
|
|||
// Search queries public user accounts.
|
|||
// Requires a user auth context.
|
|||
// https://dev.twitter.com/rest/reference/get/users/search
|
|||
func (s *UserService) Search(query string, params *UserSearchParams) ([]User, *http.Response, error) { |
|||
if params == nil { |
|||
params = &UserSearchParams{} |
|||
} |
|||
params.Query = query |
|||
users := new([]User) |
|||
apiError := new(APIError) |
|||
resp, err := s.sling.New().Get("search.json").QueryStruct(params).Receive(users, apiError) |
|||
return *users, resp, relevantError(err, *apiError) |
|||
} |
@ -0,0 +1,92 @@ |
|||
package twitter |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestUserService_Show(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/users/show.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"screen_name": "xkcdComic"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"name": "XKCD Comic", "favourites_count": 2}`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
user, _, err := client.Users.Show(&UserShowParams{ScreenName: "xkcdComic"}) |
|||
expected := &User{Name: "XKCD Comic", FavouritesCount: 2} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, user) |
|||
} |
|||
|
|||
func TestUserService_LookupWithIds(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/users/lookup.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"user_id": "113419064,623265148"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `[{"screen_name": "golang"}, {"screen_name": "dghubble"}]`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
users, _, err := client.Users.Lookup(&UserLookupParams{UserID: []int64{113419064, 623265148}}) |
|||
expected := []User{User{ScreenName: "golang"}, User{ScreenName: "dghubble"}} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, users) |
|||
} |
|||
|
|||
func TestUserService_LookupWithScreenNames(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/users/lookup.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"screen_name": "foo,bar"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `[{"name": "Foo"}, {"name": "Bar"}]`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
users, _, err := client.Users.Lookup(&UserLookupParams{ScreenName: []string{"foo", "bar"}}) |
|||
expected := []User{User{Name: "Foo"}, User{Name: "Bar"}} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, users) |
|||
} |
|||
|
|||
func TestUserService_Search(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/users/search.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "GET", r) |
|||
assertQuery(t, map[string]string{"count": "11", "q": "news"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `[{"name": "BBC"}, {"name": "BBC Breaking News"}]`) |
|||
}) |
|||
|
|||
client := NewClient(httpClient) |
|||
users, _, err := client.Users.Search("news", &UserSearchParams{Query: "override me", Count: 11}) |
|||
expected := []User{User{Name: "BBC"}, User{Name: "BBC Breaking News"}} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, users) |
|||
} |
|||
|
|||
func TestUserService_SearchHandlesNilParams(t *testing.T) { |
|||
httpClient, mux, server := testServer() |
|||
defer server.Close() |
|||
|
|||
mux.HandleFunc("/1.1/users/search.json", func(w http.ResponseWriter, r *http.Request) { |
|||
assertQuery(t, map[string]string{"q": "news"}, r) |
|||
}) |
|||
client := NewClient(httpClient) |
|||
client.Users.Search("news", nil) |
|||
} |
@ -0,0 +1,30 @@ |
|||
# OAuth1 Changelog |
|||
|
|||
## v0.4.0 (2016-04-20) |
|||
|
|||
* Add a Signer field to the Config to allow custom Signer implementations. |
|||
* Use the HMACSigner by default. This provides the same signing behavior as in previous versions (HMAC-SHA1). |
|||
* Add an RSASigner for "RSA-SHA1" OAuth1 Providers. |
|||
* Add missing Authorization Header quotes around OAuth parameter values. Many providers allowed these quotes to be missing. |
|||
* Change `Signer` to be a signer interface. |
|||
* Remove the old Signer methods `SetAccessTokenAuthHeader`, `SetRequestAuthHeader`, and `SetRequestTokenAuthHeader`. |
|||
|
|||
## v0.3.0 (2015-09-13) |
|||
|
|||
* Added `NoContext` which may be used in most cases. |
|||
* Allowed Transport Base http.RoundTripper to be set through a ctx. |
|||
* Changed `NewClient` to require a context.Context. |
|||
* Changed `Config.Client` to require a context.Context. |
|||
|
|||
## v.0.2.0 (2015-08-30) |
|||
|
|||
* Improved OAuth 1 spec compliance and test coverage. |
|||
* Added `func StaticTokenSource(*Token) TokenSource` |
|||
* Added `ParseAuthorizationCallback` function. Removed `Config.HandleAuthorizationCallback` method. |
|||
* Changed `Config` method signatures to allow an interface to be defined for the OAuth1 authorization flow. Gives users of this package (and downstream packages) the freedom to use other implementations if they wish. |
|||
* Removed `RequestToken` in favor of passing token and secret value strings. |
|||
* Removed `ReuseTokenSource` struct, it was effectively a static source. Replaced by `StaticTokenSource`. |
|||
|
|||
## v0.1.0 (2015-04-26) |
|||
|
|||
* Initial OAuth1 support for obtaining authorization and making authorized requests. |
@ -0,0 +1,21 @@ |
|||
The MIT License (MIT) |
|||
|
|||
Copyright (c) 2015 Dalton Hubble |
|||
|
|||
Permission is hereby granted, free of charge, to any person obtaining a copy |
|||
of this software and associated documentation files (the "Software"), to deal |
|||
in the Software without restriction, including without limitation the rights |
|||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|||
copies of the Software, and to permit persons to whom the Software is |
|||
furnished to do so, subject to the following conditions: |
|||
|
|||
The above copyright notice and this permission notice shall be included in |
|||
all copies or substantial portions of the Software. |
|||
|
|||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
|||
THE SOFTWARE. |
@ -0,0 +1,125 @@ |
|||
|
|||
# OAuth1 [![Build Status](https://travis-ci.org/dghubble/oauth1.png)](https://travis-ci.org/dghubble/oauth1) [![GoDoc](http://godoc.org/github.com/dghubble/oauth1?status.png)](http://godoc.org/github.com/dghubble/oauth1) |
|||
<img align="right" src="https://storage.googleapis.com/dghubble/oauth1.png"> |
|||
|
|||
OAauth1 is a Go implementation of the [OAuth 1 spec](https://tools.ietf.org/html/rfc5849). |
|||
|
|||
It allows end-users to authorize a client (consumer) to access protected resources on his/her behalf and to make signed and authorized requests. |
|||
|
|||
Package `oauth1` takes design cues from [golang.org/x/oauth2](https://godoc.org/golang.org/x/oauth2), to provide an analogous API and an `http.Client` with a Transport which signs/authorizes requests. |
|||
|
|||
## Install |
|||
|
|||
go get github.com/dghubble/oauth1 |
|||
|
|||
## Docs |
|||
|
|||
Read [GoDoc](https://godoc.org/github.com/dghubble/oauth1) |
|||
|
|||
## Usage |
|||
|
|||
Package `oauth1` implements the OAuth1 authorization flow and provides an `http.Client` which can sign and authorize OAuth1 requests. |
|||
|
|||
To implement "Login with X", use the [gologin](https://github.com/dghubble/gologin) packages which provide login handlers for OAuth1 and OAuth2 providers. |
|||
|
|||
To call the Twitter, Digits, or Tumblr OAuth1 APIs, use the higher level Go API clients. |
|||
|
|||
* [Twitter](https://github.com/dghubble/go-twitter) |
|||
* [Digits](https://github.com/dghubble/go-digits) |
|||
* [Tumblr](https://github.com/benfb/go-tumblr) |
|||
|
|||
### Authorization Flow |
|||
|
|||
Perform the OAuth 1 authorization flow to ask a user to grant an application access to his/her resources via an access token. |
|||
|
|||
```go |
|||
import ( |
|||
"github.com/dghubble/oauth1" |
|||
"github.com/dghubble/oauth1/twitter"" |
|||
) |
|||
... |
|||
|
|||
config := oauth1.Config{ |
|||
ConsumerKey: "consumerKey", |
|||
ConsumerSecret: "consumerSecret", |
|||
CallbackURL: "http://mysite.com/oauth/twitter/callback", |
|||
Endpoint: twitter.AuthorizeEndpoint, |
|||
} |
|||
``` |
|||
|
|||
1. When a user performs an action (e.g. "Login with X" button calls "/login" route) get an OAuth1 request token (temporary credentials). |
|||
|
|||
```go |
|||
requestToken, requestSecret, err = config.RequestToken() |
|||
// handle err |
|||
``` |
|||
|
|||
2. Obtain authorization from the user by redirecting them to the OAuth1 provider's authorization URL to grant the application access. |
|||
|
|||
```go |
|||
authorizationURL, err := config.AuthorizationURL(requestToken) |
|||
// handle err |
|||
http.Redirect(w, req, authorizationURL.String(), htt.StatusFound) |
|||
``` |
|||
|
|||
Receive the callback from the OAuth1 provider in a handler. |
|||
|
|||
```go |
|||
requestToken, verifier, err := oauth1.ParseAuthorizationCallback(req) |
|||
// handle err |
|||
``` |
|||
|
|||
3. Acquire the access token (token credentials) which can later be used to make requests on behalf of the user. |
|||
|
|||
```go |
|||
accessToken, accessSecret, err := config.AccessToken(requestToken, requestSecret, verifier) |
|||
// handle error |
|||
token := NewToken(accessToken, accessSecret) |
|||
``` |
|||
|
|||
Check the [examples](examples) to see this authorization flow in action from the command line, with Twitter PIN-based login and Tumblr login. |
|||
|
|||
### Authorized Requests |
|||
|
|||
Use an access `Token` to make authorized requests on behalf of a user. |
|||
|
|||
```go |
|||
import ( |
|||
"github.com/dghubble/oauth1" |
|||
) |
|||
|
|||
func main() { |
|||
config := oauth1.NewConfig("consumerKey", "consumerSecret") |
|||
token := oauth1.NewToken("token", "tokenSecret") |
|||
|
|||
// httpClient will automatically authorize http.Request's |
|||
httpClient := config.Client(oauth1.NoContext, token) |
|||
|
|||
// example Twitter API request |
|||
path := "https://api.twitter.com/1.1/statuses/home_timeline.json?count=2" |
|||
resp, _ := httpClient.Get(path) |
|||
defer resp.Body.Close() |
|||
body, _ := ioutil.ReadAll(resp.Body) |
|||
fmt.Printf("Raw Response Body:\n%v\n", string(body)) |
|||
} |
|||
``` |
|||
|
|||
Check the [examples](examples) to see Twitter and Tumblr requests in action. |
|||
|
|||
### Concepts |
|||
|
|||
An `Endpoint` groups an OAuth provider's token and authorization URL endpoints.Endpoints for common providers are provided in subpackages. |
|||
|
|||
A `Config` stores a consumer application's consumer key and secret, the registered callback URL, and the `Endpoint` to which the consumer is registered. It provides OAuth1 authorization flow methods. |
|||
|
|||
An OAuth1 `Token` is an access token which can be used to make signed requests on behalf of a user. See [Authorized Requests](#Authorized Requests) for details. |
|||
|
|||
If you've used the [golang.org/x/oauth2](https://godoc.org/golang.org/x/oauth2) package for OAuth2 before, this organization should be familiar. |
|||
|
|||
## Contributing |
|||
|
|||
See the [Contributing Guide](https://gist.github.com/dghubble/be682c123727f70bcfe7). |
|||
|
|||
## License |
|||
|
|||
[MIT License](LICENSE) |
@ -0,0 +1,265 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto/rand" |
|||
"encoding/base64" |
|||
"fmt" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/url" |
|||
"sort" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
) |
|||
|
|||
const ( |
|||
authorizationHeaderParam = "Authorization" |
|||
authorizationPrefix = "OAuth " // trailing space is intentional
|
|||
oauthConsumerKeyParam = "oauth_consumer_key" |
|||
oauthNonceParam = "oauth_nonce" |
|||
oauthSignatureParam = "oauth_signature" |
|||
oauthSignatureMethodParam = "oauth_signature_method" |
|||
oauthTimestampParam = "oauth_timestamp" |
|||
oauthTokenParam = "oauth_token" |
|||
oauthVersionParam = "oauth_version" |
|||
oauthCallbackParam = "oauth_callback" |
|||
oauthVerifierParam = "oauth_verifier" |
|||
defaultOauthVersion = "1.0" |
|||
contentType = "Content-Type" |
|||
formContentType = "application/x-www-form-urlencoded" |
|||
) |
|||
|
|||
// clock provides a interface for current time providers. A Clock can be used
|
|||
// in place of calling time.Now() directly.
|
|||
type clock interface { |
|||
Now() time.Time |
|||
} |
|||
|
|||
// A noncer provides random nonce strings.
|
|||
type noncer interface { |
|||
Nonce() string |
|||
} |
|||
|
|||
// auther adds an "OAuth" Authorization header field to requests.
|
|||
type auther struct { |
|||
config *Config |
|||
clock clock |
|||
noncer noncer |
|||
} |
|||
|
|||
func newAuther(config *Config) *auther { |
|||
return &auther{ |
|||
config: config, |
|||
} |
|||
} |
|||
|
|||
// setRequestTokenAuthHeader adds the OAuth1 header for the request token
|
|||
// request (temporary credential) according to RFC 5849 2.1.
|
|||
func (a *auther) setRequestTokenAuthHeader(req *http.Request) error { |
|||
oauthParams := a.commonOAuthParams() |
|||
oauthParams[oauthCallbackParam] = a.config.CallbackURL |
|||
params, err := collectParameters(req, oauthParams) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
signatureBase := signatureBase(req, params) |
|||
signature, err := a.signer().Sign("", signatureBase) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
oauthParams[oauthSignatureParam] = signature |
|||
req.Header.Set(authorizationHeaderParam, authHeaderValue(oauthParams)) |
|||
return nil |
|||
} |
|||
|
|||
// setAccessTokenAuthHeader sets the OAuth1 header for the access token request
|
|||
// (token credential) according to RFC 5849 2.3.
|
|||
func (a *auther) setAccessTokenAuthHeader(req *http.Request, requestToken, requestSecret, verifier string) error { |
|||
oauthParams := a.commonOAuthParams() |
|||
oauthParams[oauthTokenParam] = requestToken |
|||
oauthParams[oauthVerifierParam] = verifier |
|||
params, err := collectParameters(req, oauthParams) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
signatureBase := signatureBase(req, params) |
|||
signature, err := a.signer().Sign(requestSecret, signatureBase) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
oauthParams[oauthSignatureParam] = signature |
|||
req.Header.Set(authorizationHeaderParam, authHeaderValue(oauthParams)) |
|||
return nil |
|||
} |
|||
|
|||
// setRequestAuthHeader sets the OAuth1 header for making authenticated
|
|||
// requests with an AccessToken (token credential) according to RFC 5849 3.1.
|
|||
func (a *auther) setRequestAuthHeader(req *http.Request, accessToken *Token) error { |
|||
oauthParams := a.commonOAuthParams() |
|||
oauthParams[oauthTokenParam] = accessToken.Token |
|||
params, err := collectParameters(req, oauthParams) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
signatureBase := signatureBase(req, params) |
|||
signature, err := a.signer().Sign(accessToken.TokenSecret, signatureBase) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
oauthParams[oauthSignatureParam] = signature |
|||
req.Header.Set(authorizationHeaderParam, authHeaderValue(oauthParams)) |
|||
return nil |
|||
} |
|||
|
|||
// commonOAuthParams returns a map of the common OAuth1 protocol parameters,
|
|||
// excluding the oauth_signature parameter.
|
|||
func (a *auther) commonOAuthParams() map[string]string { |
|||
return map[string]string{ |
|||
oauthConsumerKeyParam: a.config.ConsumerKey, |
|||
oauthSignatureMethodParam: a.signer().Name(), |
|||
oauthTimestampParam: strconv.FormatInt(a.epoch(), 10), |
|||
oauthNonceParam: a.nonce(), |
|||
oauthVersionParam: defaultOauthVersion, |
|||
} |
|||
} |
|||
|
|||
// Returns a base64 encoded random 32 byte string.
|
|||
func (a *auther) nonce() string { |
|||
if a.noncer != nil { |
|||
return a.noncer.Nonce() |
|||
} |
|||
b := make([]byte, 32) |
|||
rand.Read(b) |
|||
return base64.StdEncoding.EncodeToString(b) |
|||
} |
|||
|
|||
// Returns the Unix epoch seconds.
|
|||
func (a *auther) epoch() int64 { |
|||
if a.clock != nil { |
|||
return a.clock.Now().Unix() |
|||
} |
|||
return time.Now().Unix() |
|||
} |
|||
|
|||
// Returns the Config's Signer or the default Signer.
|
|||
func (a *auther) signer() Signer { |
|||
if a.config.Signer != nil { |
|||
return a.config.Signer |
|||
} |
|||
return &HMACSigner{ConsumerSecret: a.config.ConsumerSecret} |
|||
} |
|||
|
|||
// authHeaderValue formats OAuth parameters according to RFC 5849 3.5.1. OAuth
|
|||
// params are percent encoded, sorted by key (for testability), and joined by
|
|||
// "=" into pairs. Pairs are joined with a ", " comma separator into a header
|
|||
// string.
|
|||
// The given OAuth params should include the "oauth_signature" key.
|
|||
func authHeaderValue(oauthParams map[string]string) string { |
|||
pairs := sortParameters(encodeParameters(oauthParams), `%s="%s"`) |
|||
return authorizationPrefix + strings.Join(pairs, ", ") |
|||
} |
|||
|
|||
// encodeParameters percent encodes parameter keys and values according to
|
|||
// RFC5849 3.6 and RFC3986 2.1 and returns a new map.
|
|||
func encodeParameters(params map[string]string) map[string]string { |
|||
encoded := map[string]string{} |
|||
for key, value := range params { |
|||
encoded[PercentEncode(key)] = PercentEncode(value) |
|||
} |
|||
return encoded |
|||
} |
|||
|
|||
// sortParameters sorts parameters by key and returns a slice of key/value
|
|||
// pairs formatted with the given format string (e.g. "%s=%s").
|
|||
func sortParameters(params map[string]string, format string) []string { |
|||
// sort by key
|
|||
keys := make([]string, len(params)) |
|||
i := 0 |
|||
for key := range params { |
|||
keys[i] = key |
|||
i++ |
|||
} |
|||
sort.Strings(keys) |
|||
// parameter join
|
|||
pairs := make([]string, len(params)) |
|||
for i, key := range keys { |
|||
pairs[i] = fmt.Sprintf(format, key, params[key]) |
|||
} |
|||
return pairs |
|||
} |
|||
|
|||
// collectParameters collects request parameters from the request query, OAuth
|
|||
// parameters (which should exclude oauth_signature), and the request body
|
|||
// provided the body is single part, form encoded, and the form content type
|
|||
// header is set. The returned map of collected parameter keys and values
|
|||
// follow RFC 5849 3.4.1.3, except duplicate parameters are not supported.
|
|||
func collectParameters(req *http.Request, oauthParams map[string]string) (map[string]string, error) { |
|||
// add oauth, query, and body parameters into params
|
|||
params := map[string]string{} |
|||
for key, value := range req.URL.Query() { |
|||
// most backends do not accept duplicate query keys
|
|||
params[key] = value[0] |
|||
} |
|||
if req.Body != nil && req.Header.Get(contentType) == formContentType { |
|||
// reads data to a []byte, draining req.Body
|
|||
b, err := ioutil.ReadAll(req.Body) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
values, err := url.ParseQuery(string(b)) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
for key, value := range values { |
|||
// not supporting params with duplicate keys
|
|||
params[key] = value[0] |
|||
} |
|||
// reinitialize Body with ReadCloser over the []byte
|
|||
req.Body = ioutil.NopCloser(bytes.NewReader(b)) |
|||
} |
|||
for key, value := range oauthParams { |
|||
params[key] = value |
|||
} |
|||
return params, nil |
|||
} |
|||
|
|||
// signatureBase combines the uppercase request method, percent encoded base
|
|||
// string URI, and normalizes the request parameters int a parameter string.
|
|||
// Returns the OAuth1 signature base string according to RFC5849 3.4.1.
|
|||
func signatureBase(req *http.Request, params map[string]string) string { |
|||
method := strings.ToUpper(req.Method) |
|||
baseURL := baseURI(req) |
|||
parameterString := normalizedParameterString(params) |
|||
// signature base string constructed accoding to 3.4.1.1
|
|||
baseParts := []string{method, PercentEncode(baseURL), PercentEncode(parameterString)} |
|||
return strings.Join(baseParts, "&") |
|||
} |
|||
|
|||
// baseURI returns the base string URI of a request according to RFC 5849
|
|||
// 3.4.1.2. The scheme and host are lowercased, the port is dropped if it
|
|||
// is 80 or 443, and the path minus query parameters is included.
|
|||
func baseURI(req *http.Request) string { |
|||
scheme := strings.ToLower(req.URL.Scheme) |
|||
host := strings.ToLower(req.URL.Host) |
|||
if hostPort := strings.Split(host, ":"); len(hostPort) == 2 && (hostPort[1] == "80" || hostPort[1] == "443") { |
|||
host = hostPort[0] |
|||
} |
|||
// TODO: use req.URL.EscapedPath() once Go 1.5 is more generally adopted
|
|||
// For now, hacky workaround accomplishes the same internal escaping mode
|
|||
// escape(u.Path, encodePath) for proper compliance with the OAuth1 spec.
|
|||
path := req.URL.Path |
|||
if path != "" { |
|||
path = strings.Split(req.URL.RequestURI(), "?")[0] |
|||
} |
|||
return fmt.Sprintf("%v://%v%v", scheme, host, path) |
|||
} |
|||
|
|||
// parameterString normalizes collected OAuth parameters (which should exclude
|
|||
// oauth_signature) into a parameter string as defined in RFC 5894 3.4.1.3.2.
|
|||
// The parameters are encoded, sorted by key, keys and values joined with "&",
|
|||
// and pairs joined with "=" (e.g. foo=bar&q=gopher).
|
|||
func normalizedParameterString(params map[string]string) string { |
|||
return strings.Join(sortParameters(encodeParameters(params), "%s=%s"), "&") |
|||
} |
@ -0,0 +1,244 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"net/http" |
|||
"net/url" |
|||
"strings" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestCommonOAuthParams(t *testing.T) { |
|||
config := &Config{ConsumerKey: "some_consumer_key"} |
|||
auther := &auther{config, &fixedClock{time.Unix(50037133, 0)}, &fixedNoncer{"some_nonce"}} |
|||
expectedParams := map[string]string{ |
|||
"oauth_consumer_key": "some_consumer_key", |
|||
"oauth_signature_method": "HMAC-SHA1", |
|||
"oauth_timestamp": "50037133", |
|||
"oauth_nonce": "some_nonce", |
|||
"oauth_version": "1.0", |
|||
} |
|||
assert.Equal(t, expectedParams, auther.commonOAuthParams()) |
|||
} |
|||
|
|||
func TestNonce(t *testing.T) { |
|||
auther := &auther{} |
|||
nonce := auther.nonce() |
|||
// assert that 32 bytes (256 bites) become 44 bytes since a base64 byte
|
|||
// zeros the 2 high bits. 3 bytes convert to 4 base64 bytes, 40 base64 bytes
|
|||
// represent the first 30 of 32 bytes, = padding adds another 4 byte group.
|
|||
// base64 bytes = 4 * floor(bytes/3) + 4
|
|||
assert.Equal(t, 44, len([]byte(nonce))) |
|||
} |
|||
|
|||
func TestEpoch(t *testing.T) { |
|||
a := &auther{} |
|||
// assert that a real time is used by default
|
|||
assert.InEpsilon(t, time.Now().Unix(), a.epoch(), 1) |
|||
// assert that the fixed clock can be used for testing
|
|||
a = &auther{clock: &fixedClock{time.Unix(50037133, 0)}} |
|||
assert.Equal(t, int64(50037133), a.epoch()) |
|||
} |
|||
|
|||
func TestSigner_Default(t *testing.T) { |
|||
config := &Config{ConsumerSecret: "consumer_secret"} |
|||
a := newAuther(config) |
|||
// echo -n "hello world" | openssl dgst -sha1 -hmac "consumer_secret&token_secret" -binary | base64
|
|||
expectedSignature := "BE0uILOruKfSXd4UzYlLJDfOq08=" |
|||
// assert that the default signer produces the expected HMAC-SHA1 digest
|
|||
method := a.signer().Name() |
|||
digest, err := a.signer().Sign("token_secret", "hello world") |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, "HMAC-SHA1", method) |
|||
assert.Equal(t, expectedSignature, digest) |
|||
} |
|||
|
|||
type identitySigner struct{} |
|||
|
|||
func (s *identitySigner) Name() string { |
|||
return "identity" |
|||
} |
|||
|
|||
func (s *identitySigner) Sign(tokenSecret, message string) (string, error) { |
|||
return message, nil |
|||
} |
|||
|
|||
func TestSigner_Custom(t *testing.T) { |
|||
config := &Config{ |
|||
ConsumerSecret: "consumer_secret", |
|||
Signer: &identitySigner{}, |
|||
} |
|||
a := newAuther(config) |
|||
// assert that the custom signer is used
|
|||
method := a.signer().Name() |
|||
digest, err := a.signer().Sign("secret", "hello world") |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, "identity", method) |
|||
assert.Equal(t, "hello world", digest) |
|||
} |
|||
|
|||
func TestAuthHeaderValue(t *testing.T) { |
|||
cases := []struct { |
|||
params map[string]string |
|||
authHeader string |
|||
}{ |
|||
{map[string]string{}, "OAuth "}, |
|||
{map[string]string{"a": "b"}, `OAuth a="b"`}, |
|||
{map[string]string{"a": "b", "c": "d", "e": "f", "1": "2"}, `OAuth 1="2", a="b", c="d", e="f"`}, |
|||
{map[string]string{"/= +doencode": "/= +doencode"}, `OAuth %2F%3D%20%2Bdoencode="%2F%3D%20%2Bdoencode"`}, |
|||
{map[string]string{"-._~dontencode": "-._~dontencode"}, `OAuth -._~dontencode="-._~dontencode"`}, |
|||
} |
|||
for _, c := range cases { |
|||
assert.Equal(t, c.authHeader, authHeaderValue(c.params)) |
|||
} |
|||
} |
|||
|
|||
func TestEncodeParameters(t *testing.T) { |
|||
input := map[string]string{ |
|||
"a": "Dogs, Cats & Mice", |
|||
"☃": "snowman", |
|||
"ル": "ル", |
|||
} |
|||
expected := map[string]string{ |
|||
"a": "Dogs%2C%20Cats%20%26%20Mice", |
|||
"%E2%98%83": "snowman", |
|||
"%E3%83%AB": "%E3%83%AB", |
|||
} |
|||
assert.Equal(t, expected, encodeParameters(input)) |
|||
} |
|||
|
|||
func TestSortParameters(t *testing.T) { |
|||
input := map[string]string{ |
|||
".": "ape", |
|||
"5.6": "bat", |
|||
"rsa": "cat", |
|||
"%20": "dog", |
|||
"%E3%83%AB": "eel", |
|||
"dup": "fox", |
|||
//"dup": "fix", // duplicate keys not supported
|
|||
} |
|||
expected := []string{ |
|||
"%20=dog", |
|||
"%E3%83%AB=eel", |
|||
".=ape", |
|||
"5.6=bat", |
|||
"dup=fox", |
|||
"rsa=cat", |
|||
} |
|||
assert.Equal(t, expected, sortParameters(input, "%s=%s")) |
|||
} |
|||
|
|||
func TestCollectParameters(t *testing.T) { |
|||
// example from RFC 5849 3.4.1.3.1
|
|||
oauthParams := map[string]string{ |
|||
"oauth_token": "kkk9d7dh3k39sjv7", |
|||
"oauth_consumer_key": "9djdj82h48djs9d2", |
|||
"oauth_signature_method": "HMAC-SHA1", |
|||
"oauth_timestamp": "137131201", |
|||
"oauth_nonce": "7d8f3e4a", |
|||
} |
|||
values := url.Values{} |
|||
values.Add("c2", "") |
|||
values.Add("plus", "2 q") // duplicate keys not supported, a3 -> plus
|
|||
req, err := http.NewRequest("POST", "/request?b5=%3D%253D&a3=a&c%40=&a2=r%20b", strings.NewReader(values.Encode())) |
|||
assert.Nil(t, err) |
|||
req.Header.Set(contentType, formContentType) |
|||
params, err := collectParameters(req, oauthParams) |
|||
// assert parameters were collected from oauthParams, the query, and form body
|
|||
expected := map[string]string{ |
|||
"b5": "=%3D", |
|||
"a3": "a", |
|||
"c@": "", |
|||
"a2": "r b", |
|||
"oauth_token": "kkk9d7dh3k39sjv7", |
|||
"oauth_consumer_key": "9djdj82h48djs9d2", |
|||
"oauth_signature_method": "HMAC-SHA1", |
|||
"oauth_timestamp": "137131201", |
|||
"oauth_nonce": "7d8f3e4a", |
|||
"c2": "", |
|||
"plus": "2 q", |
|||
} |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expected, params) |
|||
// RFC 5849 3.4.1.3.1 requires a {"a3"="2 q"} be form encoded to "a3=2+q" in
|
|||
// the application/x-www-form-urlencoded body. The parameter "2+q" should be
|
|||
// read as "2 q" and percent encoded to "2%20q".
|
|||
// In Go, data is form encoded by calling Encode on url.Values{} (URL
|
|||
// encoding) and decoded with url.ParseQuery to url.Values. So the encoding
|
|||
// of "2 q" to "2+q" and decoding back to "2 q" is handled and then params
|
|||
// are percent encoded.
|
|||
// http://golang.org/src/net/http/client.go#L496
|
|||
// http://golang.org/src/net/http/request.go#L837
|
|||
} |
|||
|
|||
func TestSignatureBase(t *testing.T) { |
|||
reqA, err := http.NewRequest("get", "HTTPS://HELLO.IO?q=test", nil) |
|||
assert.Nil(t, err) |
|||
reqB, err := http.NewRequest("POST", "http://hello.io:8080", nil) |
|||
assert.Nil(t, err) |
|||
cases := []struct { |
|||
req *http.Request |
|||
params map[string]string |
|||
signatureBase string |
|||
}{ |
|||
{reqA, map[string]string{"a": "b", "c": "d"}, "GET&https%3A%2F%2Fhello.io&a%3Db%26c%3Dd"}, |
|||
{reqB, map[string]string{"a": "b"}, "POST&http%3A%2F%2Fhello.io%3A8080&a%3Db"}, |
|||
} |
|||
// assert that method is uppercased, base uri rules applied, queries added, joined by &
|
|||
for _, c := range cases { |
|||
base := signatureBase(c.req, c.params) |
|||
assert.Equal(t, c.signatureBase, base) |
|||
} |
|||
} |
|||
|
|||
func TestBaseURI(t *testing.T) { |
|||
reqA, err := http.NewRequest("GET", "HTTP://EXAMPLE.COM:80/r%20v/X?id=123", nil) |
|||
assert.Nil(t, err) |
|||
reqB, err := http.NewRequest("POST", "https://www.example.net:8080/?q=1", nil) |
|||
assert.Nil(t, err) |
|||
reqC, err := http.NewRequest("POST", "https://example.com:443", nil) |
|||
cases := []struct { |
|||
req *http.Request |
|||
baseURI string |
|||
}{ |
|||
{reqA, "http://example.com/r%20v/X"}, |
|||
{reqB, "https://www.example.net:8080/"}, |
|||
{reqC, "https://example.com"}, |
|||
} |
|||
for _, c := range cases { |
|||
baseURI := baseURI(c.req) |
|||
assert.Equal(t, c.baseURI, baseURI) |
|||
} |
|||
} |
|||
|
|||
func TestNormalizedParameterString(t *testing.T) { |
|||
simple := map[string]string{ |
|||
"a": "b & c", |
|||
"☃": "snowman", |
|||
} |
|||
rfcExample := map[string]string{ |
|||
"b5": "=%3D", |
|||
"a3": "a", |
|||
"c@": "", |
|||
"a2": "r b", |
|||
"oauth_token": "kkk9d7dh3k39sjv7", |
|||
"oauth_consumer_key": "9djdj82h48djs9d2", |
|||
"oauth_signature_method": "HMAC-SHA1", |
|||
"oauth_timestamp": "137131201", |
|||
"oauth_nonce": "7d8f3e4a", |
|||
"c2": "", |
|||
"plus": "2 q", |
|||
} |
|||
cases := []struct { |
|||
params map[string]string |
|||
parameterStr string |
|||
}{ |
|||
{simple, "%E2%98%83=snowman&a=b%20%26%20c"}, |
|||
{rfcExample, "a2=r%20b&a3=a&b5=%3D%253D&c%40=&c2=&oauth_consumer_key=9djdj82h48djs9d2&oauth_nonce=7d8f3e4a&oauth_signature_method=HMAC-SHA1&oauth_timestamp=137131201&oauth_token=kkk9d7dh3k39sjv7&plus=2%20q"}, |
|||
} |
|||
for _, c := range cases { |
|||
assert.Equal(t, c.parameterStr, normalizedParameterString(c.params)) |
|||
} |
|||
} |
@ -0,0 +1,165 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"errors" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/url" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
const ( |
|||
oauthTokenSecretParam = "oauth_token_secret" |
|||
oauthCallbackConfirmedParam = "oauth_callback_confirmed" |
|||
) |
|||
|
|||
// Config represents an OAuth1 consumer's (client's) key and secret, the
|
|||
// callback URL, and the provider Endpoint to which the consumer corresponds.
|
|||
type Config struct { |
|||
// Consumer Key (Client Identifier)
|
|||
ConsumerKey string |
|||
// Consumer Secret (Client Shared-Secret)
|
|||
ConsumerSecret string |
|||
// Callback URL
|
|||
CallbackURL string |
|||
// Provider Endpoint specifying OAuth1 endpoint URLs
|
|||
Endpoint Endpoint |
|||
// OAuth1 Signer (defaults to HMAC-SHA1)
|
|||
Signer Signer |
|||
} |
|||
|
|||
// NewConfig returns a new Config with the given consumer key and secret.
|
|||
func NewConfig(consumerKey, consumerSecret string) *Config { |
|||
return &Config{ |
|||
ConsumerKey: consumerKey, |
|||
ConsumerSecret: consumerSecret, |
|||
} |
|||
} |
|||
|
|||
// Client returns an HTTP client which uses the provided ctx and access Token.
|
|||
func (c *Config) Client(ctx context.Context, t *Token) *http.Client { |
|||
return NewClient(ctx, c, t) |
|||
} |
|||
|
|||
// NewClient returns a new http Client which signs requests via OAuth1.
|
|||
func NewClient(ctx context.Context, config *Config, token *Token) *http.Client { |
|||
transport := &Transport{ |
|||
Base: contextTransport(ctx), |
|||
source: StaticTokenSource(token), |
|||
auther: newAuther(config), |
|||
} |
|||
return &http.Client{Transport: transport} |
|||
} |
|||
|
|||
// RequestToken obtains a Request token and secret (temporary credential) by
|
|||
// POSTing a request (with oauth_callback in the auth header) to the Endpoint
|
|||
// RequestTokenURL. The response body form is validated to ensure
|
|||
// oauth_callback_confirmed is true. Returns the request token and secret
|
|||
// (temporary credentials).
|
|||
// See RFC 5849 2.1 Temporary Credentials.
|
|||
func (c *Config) RequestToken() (requestToken, requestSecret string, err error) { |
|||
req, err := http.NewRequest("POST", c.Endpoint.RequestTokenURL, nil) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
err = newAuther(c).setRequestTokenAuthHeader(req) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
resp, err := http.DefaultClient.Do(req) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
// when err is nil, resp contains a non-nil resp.Body which must be closed
|
|||
defer resp.Body.Close() |
|||
body, err := ioutil.ReadAll(resp.Body) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
// ParseQuery to decode URL-encoded application/x-www-form-urlencoded body
|
|||
values, err := url.ParseQuery(string(body)) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
if values.Get(oauthCallbackConfirmedParam) != "true" { |
|||
return "", "", errors.New("oauth1: oauth_callback_confirmed was not true") |
|||
} |
|||
requestToken = values.Get(oauthTokenParam) |
|||
requestSecret = values.Get(oauthTokenSecretParam) |
|||
if requestToken == "" || requestSecret == "" { |
|||
return "", "", errors.New("oauth1: Response missing oauth_token or oauth_token_secret") |
|||
} |
|||
return requestToken, requestSecret, nil |
|||
} |
|||
|
|||
// AuthorizationURL accepts a request token and returns the *url.URL to the
|
|||
// Endpoint's authorization page that asks the user (resource owner) for to
|
|||
// authorize the consumer to act on his/her/its behalf.
|
|||
// See RFC 5849 2.2 Resource Owner Authorization.
|
|||
func (c *Config) AuthorizationURL(requestToken string) (*url.URL, error) { |
|||
authorizationURL, err := url.Parse(c.Endpoint.AuthorizeURL) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
values := authorizationURL.Query() |
|||
values.Add(oauthTokenParam, requestToken) |
|||
authorizationURL.RawQuery = values.Encode() |
|||
return authorizationURL, nil |
|||
} |
|||
|
|||
// ParseAuthorizationCallback parses an OAuth1 authorization callback request
|
|||
// from a provider server. The oauth_token and oauth_verifier parameters are
|
|||
// parsed to return the request token from earlier in the flow and the
|
|||
// verifier string.
|
|||
// See RFC 5849 2.2 Resource Owner Authorization.
|
|||
func ParseAuthorizationCallback(req *http.Request) (requestToken, verifier string, err error) { |
|||
// parse the raw query from the URL into req.Form
|
|||
err = req.ParseForm() |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
requestToken = req.Form.Get(oauthTokenParam) |
|||
verifier = req.Form.Get(oauthVerifierParam) |
|||
if requestToken == "" || verifier == "" { |
|||
return "", "", errors.New("oauth1: Request missing oauth_token or oauth_verifier") |
|||
} |
|||
return requestToken, verifier, nil |
|||
} |
|||
|
|||
// AccessToken obtains an access token (token credential) by POSTing a
|
|||
// request (with oauth_token and oauth_verifier in the auth header) to the
|
|||
// Endpoint AccessTokenURL. Returns the access token and secret (token
|
|||
// credentials).
|
|||
// See RFC 5849 2.3 Token Credentials.
|
|||
func (c *Config) AccessToken(requestToken, requestSecret, verifier string) (accessToken, accessSecret string, err error) { |
|||
req, err := http.NewRequest("POST", c.Endpoint.AccessTokenURL, nil) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
err = newAuther(c).setAccessTokenAuthHeader(req, requestToken, requestSecret, verifier) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
resp, err := http.DefaultClient.Do(req) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
// when err is nil, resp contains a non-nil resp.Body which must be closed
|
|||
defer resp.Body.Close() |
|||
body, err := ioutil.ReadAll(resp.Body) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
// ParseQuery to decode URL-encoded application/x-www-form-urlencoded body
|
|||
values, err := url.ParseQuery(string(body)) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
accessToken = values.Get(oauthTokenParam) |
|||
accessSecret = values.Get(oauthTokenSecretParam) |
|||
if accessToken == "" || accessSecret == "" { |
|||
return "", "", errors.New("oauth1: Response missing oauth_token or oauth_token_secret") |
|||
} |
|||
return accessToken, accessSecret, nil |
|||
} |
@ -0,0 +1,342 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"net/url" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
const expectedVerifier = "some_verifier" |
|||
|
|||
func TestNewConfig(t *testing.T) { |
|||
expectedConsumerKey := "consumer_key" |
|||
expectedConsumerSecret := "consumer_secret" |
|||
config := NewConfig(expectedConsumerKey, expectedConsumerSecret) |
|||
assert.Equal(t, expectedConsumerKey, config.ConsumerKey) |
|||
assert.Equal(t, expectedConsumerSecret, config.ConsumerSecret) |
|||
} |
|||
|
|||
func TestNewClient(t *testing.T) { |
|||
expectedToken := "access_token" |
|||
expectedConsumerKey := "consumer_key" |
|||
config := NewConfig(expectedConsumerKey, "consumer_secret") |
|||
token := NewToken(expectedToken, "access_secret") |
|||
client := config.Client(NoContext, token) |
|||
|
|||
server := newMockServer(func(w http.ResponseWriter, req *http.Request) { |
|||
assert.Equal(t, "GET", req.Method) |
|||
params := parseOAuthParamsOrFail(t, req.Header.Get(authorizationHeaderParam)) |
|||
assert.Equal(t, expectedToken, params[oauthTokenParam]) |
|||
assert.Equal(t, expectedConsumerKey, params[oauthConsumerKeyParam]) |
|||
}) |
|||
defer server.Close() |
|||
client.Get(server.URL) |
|||
} |
|||
|
|||
func TestNewClient_DefaultTransport(t *testing.T) { |
|||
client := NewClient(NoContext, NewConfig("t", "s"), NewToken("t", "s")) |
|||
// assert that the client uses the DefaultTransport
|
|||
transport, ok := client.Transport.(*Transport) |
|||
assert.True(t, ok) |
|||
assert.Equal(t, http.DefaultTransport, transport.base()) |
|||
} |
|||
|
|||
func TestNewClient_ContextClientTransport(t *testing.T) { |
|||
baseTransport := &http.Transport{} |
|||
baseClient := &http.Client{Transport: baseTransport} |
|||
ctx := context.WithValue(NoContext, HTTPClient, baseClient) |
|||
client := NewClient(ctx, NewConfig("t", "s"), NewToken("t", "s")) |
|||
// assert that the client uses the ctx client's Transport as its base RoundTripper
|
|||
transport, ok := client.Transport.(*Transport) |
|||
assert.True(t, ok) |
|||
assert.Equal(t, baseTransport, transport.base()) |
|||
} |
|||
|
|||
// newRequestTokenServer returns a new httptest.Server for an OAuth1 provider
|
|||
// request token endpoint.
|
|||
func newRequestTokenServer(t *testing.T, data url.Values) *httptest.Server { |
|||
return newMockServer(func(w http.ResponseWriter, req *http.Request) { |
|||
assert.Equal(t, "POST", req.Method) |
|||
assert.NotEmpty(t, req.Header.Get("Authorization")) |
|||
w.Header().Set(contentType, formContentType) |
|||
w.Write([]byte(data.Encode())) |
|||
}) |
|||
} |
|||
|
|||
// newAccessTokenServer returns a new httptest.Server for an OAuth1 provider
|
|||
// access token endpoint.
|
|||
func newAccessTokenServer(t *testing.T, data url.Values) *httptest.Server { |
|||
return newMockServer(func(w http.ResponseWriter, req *http.Request) { |
|||
assert.Equal(t, "POST", req.Method) |
|||
assert.NotEmpty(t, req.Header.Get("Authorization")) |
|||
params := parseOAuthParamsOrFail(t, req.Header.Get(authorizationHeaderParam)) |
|||
assert.Equal(t, expectedVerifier, params[oauthVerifierParam]) |
|||
w.Header().Set(contentType, formContentType) |
|||
w.Write([]byte(data.Encode())) |
|||
}) |
|||
} |
|||
|
|||
// newUnparseableBodyServer returns a new httptest.Server which writes
|
|||
// responses with bodies that error when parsed by url.ParseQuery.
|
|||
func newUnparseableBodyServer() *httptest.Server { |
|||
return newMockServer(func(w http.ResponseWriter, req *http.Request) { |
|||
w.Header().Set(contentType, formContentType) |
|||
// url.ParseQuery will error, https://golang.org/src/net/url/url_test.go#L1107
|
|||
w.Write([]byte("%gh&%ij")) |
|||
}) |
|||
} |
|||
|
|||
func TestConfigRequestToken(t *testing.T) { |
|||
expectedToken := "reqest_token" |
|||
expectedSecret := "request_secret" |
|||
data := url.Values{} |
|||
data.Add("oauth_token", expectedToken) |
|||
data.Add("oauth_token_secret", expectedSecret) |
|||
data.Add("oauth_callback_confirmed", "true") |
|||
server := newRequestTokenServer(t, data) |
|||
defer server.Close() |
|||
|
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
RequestTokenURL: server.URL, |
|||
}, |
|||
} |
|||
requestToken, requestSecret, err := config.RequestToken() |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expectedToken, requestToken) |
|||
assert.Equal(t, expectedSecret, requestSecret) |
|||
} |
|||
|
|||
func TestConfigRequestToken_InvalidRequestTokenURL(t *testing.T) { |
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
RequestTokenURL: "http://wrong.com/oauth/request_token", |
|||
}, |
|||
} |
|||
requestToken, requestSecret, err := config.RequestToken() |
|||
assert.NotNil(t, err) |
|||
assert.Equal(t, "", requestToken) |
|||
assert.Equal(t, "", requestSecret) |
|||
} |
|||
|
|||
func TestConfigRequestToken_CallbackNotConfirmed(t *testing.T) { |
|||
data := url.Values{} |
|||
data.Add("oauth_callback_confirmed", "false") |
|||
server := newRequestTokenServer(t, data) |
|||
defer server.Close() |
|||
|
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
RequestTokenURL: server.URL, |
|||
}, |
|||
} |
|||
requestToken, requestSecret, err := config.RequestToken() |
|||
if assert.Error(t, err) { |
|||
assert.Equal(t, "oauth1: oauth_callback_confirmed was not true", err.Error()) |
|||
} |
|||
assert.Equal(t, "", requestToken) |
|||
assert.Equal(t, "", requestSecret) |
|||
} |
|||
|
|||
func TestConfigRequestToken_CannotParseBody(t *testing.T) { |
|||
server := newUnparseableBodyServer() |
|||
defer server.Close() |
|||
|
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
RequestTokenURL: server.URL, |
|||
}, |
|||
} |
|||
requestToken, requestSecret, err := config.RequestToken() |
|||
if assert.Error(t, err) { |
|||
assert.Contains(t, err.Error(), "invalid URL escape") |
|||
} |
|||
assert.Equal(t, "", requestToken) |
|||
assert.Equal(t, "", requestSecret) |
|||
} |
|||
|
|||
func TestConfigRequestToken_MissingTokenOrSecret(t *testing.T) { |
|||
data := url.Values{} |
|||
data.Add("oauth_token", "any_token") |
|||
data.Add("oauth_callback_confirmed", "true") |
|||
server := newRequestTokenServer(t, data) |
|||
defer server.Close() |
|||
|
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
RequestTokenURL: server.URL, |
|||
}, |
|||
} |
|||
requestToken, requestSecret, err := config.RequestToken() |
|||
if assert.Error(t, err) { |
|||
assert.Equal(t, "oauth1: Response missing oauth_token or oauth_token_secret", err.Error()) |
|||
} |
|||
assert.Equal(t, "", requestToken) |
|||
assert.Equal(t, "", requestSecret) |
|||
} |
|||
|
|||
func TestAuthorizationURL(t *testing.T) { |
|||
expectedURL := "https://api.example.com/oauth/authorize?oauth_token=a%2Frequest_token" |
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
AuthorizeURL: "https://api.example.com/oauth/authorize", |
|||
}, |
|||
} |
|||
url, err := config.AuthorizationURL("a/request_token") |
|||
assert.Nil(t, err) |
|||
if assert.NotNil(t, url) { |
|||
assert.Equal(t, expectedURL, url.String()) |
|||
} |
|||
} |
|||
|
|||
func TestAuthorizationURL_CannotParseAuthorizeURL(t *testing.T) { |
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
AuthorizeURL: "%gh&%ij", |
|||
}, |
|||
} |
|||
url, err := config.AuthorizationURL("any_request_token") |
|||
assert.Nil(t, url) |
|||
if assert.Error(t, err) { |
|||
assert.Contains(t, err.Error(), "parse") |
|||
assert.Contains(t, err.Error(), "invalid URL") |
|||
} |
|||
} |
|||
|
|||
func TestConfigAccessToken(t *testing.T) { |
|||
expectedToken := "access_token" |
|||
expectedSecret := "access_secret" |
|||
data := url.Values{} |
|||
data.Add("oauth_token", expectedToken) |
|||
data.Add("oauth_token_secret", expectedSecret) |
|||
server := newAccessTokenServer(t, data) |
|||
defer server.Close() |
|||
|
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
AccessTokenURL: server.URL, |
|||
}, |
|||
} |
|||
accessToken, accessSecret, err := config.AccessToken("request_token", "request_secret", expectedVerifier) |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expectedToken, accessToken) |
|||
assert.Equal(t, expectedSecret, accessSecret) |
|||
} |
|||
|
|||
func TestConfigAccessToken_InvalidAccessTokenURL(t *testing.T) { |
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
AccessTokenURL: "http://wrong.com/oauth/access_token", |
|||
}, |
|||
} |
|||
accessToken, accessSecret, err := config.AccessToken("any_token", "any_secret", "any_verifier") |
|||
assert.NotNil(t, err) |
|||
assert.Equal(t, "", accessToken) |
|||
assert.Equal(t, "", accessSecret) |
|||
} |
|||
|
|||
func TestConfigAccessToken_CannotParseBody(t *testing.T) { |
|||
server := newUnparseableBodyServer() |
|||
defer server.Close() |
|||
|
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
AccessTokenURL: server.URL, |
|||
}, |
|||
} |
|||
accessToken, accessSecret, err := config.AccessToken("any_token", "any_secret", "any_verifier") |
|||
if assert.Error(t, err) { |
|||
assert.Contains(t, err.Error(), "invalid URL escape") |
|||
} |
|||
assert.Equal(t, "", accessToken) |
|||
assert.Equal(t, "", accessSecret) |
|||
} |
|||
|
|||
func TestConfigAccessToken_MissingTokenOrSecret(t *testing.T) { |
|||
data := url.Values{} |
|||
data.Add("oauth_token", "any_token") |
|||
server := newAccessTokenServer(t, data) |
|||
defer server.Close() |
|||
|
|||
config := &Config{ |
|||
Endpoint: Endpoint{ |
|||
AccessTokenURL: server.URL, |
|||
}, |
|||
} |
|||
accessToken, accessSecret, err := config.AccessToken("request_token", "request_secret", expectedVerifier) |
|||
if assert.Error(t, err) { |
|||
assert.Equal(t, "oauth1: Response missing oauth_token or oauth_token_secret", err.Error()) |
|||
} |
|||
assert.Equal(t, "", accessToken) |
|||
assert.Equal(t, "", accessSecret) |
|||
} |
|||
|
|||
func TestParseAuthorizationCallback_GET(t *testing.T) { |
|||
expectedToken := "token" |
|||
expectedVerifier := "verifier" |
|||
server := newMockServer(func(w http.ResponseWriter, req *http.Request) { |
|||
assert.Equal(t, "GET", req.Method) |
|||
// logic under test
|
|||
requestToken, verifier, err := ParseAuthorizationCallback(req) |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expectedToken, requestToken) |
|||
assert.Equal(t, expectedVerifier, verifier) |
|||
}) |
|||
defer server.Close() |
|||
|
|||
// OAuth1 provider calls callback url
|
|||
url, err := url.Parse(server.URL) |
|||
assert.Nil(t, err) |
|||
query := url.Query() |
|||
query.Add("oauth_token", expectedToken) |
|||
query.Add("oauth_verifier", expectedVerifier) |
|||
url.RawQuery = query.Encode() |
|||
http.Get(url.String()) |
|||
} |
|||
|
|||
func TestParseAuthorizationCallback_POST(t *testing.T) { |
|||
expectedToken := "token" |
|||
expectedVerifier := "verifier" |
|||
server := newMockServer(func(w http.ResponseWriter, req *http.Request) { |
|||
assert.Equal(t, "POST", req.Method) |
|||
// logic under test
|
|||
requestToken, verifier, err := ParseAuthorizationCallback(req) |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expectedToken, requestToken) |
|||
assert.Equal(t, expectedVerifier, verifier) |
|||
}) |
|||
defer server.Close() |
|||
|
|||
// OAuth1 provider calls callback url
|
|||
form := url.Values{} |
|||
form.Add("oauth_token", expectedToken) |
|||
form.Add("oauth_verifier", expectedVerifier) |
|||
http.PostForm(server.URL, form) |
|||
} |
|||
|
|||
func TestParseAuthorizationCallback_MissingTokenOrVerifier(t *testing.T) { |
|||
server := newMockServer(func(w http.ResponseWriter, req *http.Request) { |
|||
assert.Equal(t, "GET", req.Method) |
|||
// logic under test
|
|||
requestToken, verifier, err := ParseAuthorizationCallback(req) |
|||
if assert.Error(t, err) { |
|||
assert.Equal(t, "oauth1: Request missing oauth_token or oauth_verifier", err.Error()) |
|||
} |
|||
assert.Equal(t, "", requestToken) |
|||
assert.Equal(t, "", verifier) |
|||
}) |
|||
defer server.Close() |
|||
|
|||
// OAuth1 provider calls callback url
|
|||
url, err := url.Parse(server.URL) |
|||
assert.Nil(t, err) |
|||
query := url.Query() |
|||
query.Add("oauth_token", "any_token") |
|||
query.Add("oauth_verifier", "") // missing oauth_verifier
|
|||
url.RawQuery = query.Encode() |
|||
http.Get(url.String()) |
|||
} |
@ -0,0 +1,24 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"net/http" |
|||
|
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
type contextKey struct{} |
|||
|
|||
// HTTPClient is the context key to associate an *http.Client value with
|
|||
// a context.
|
|||
var HTTPClient contextKey |
|||
|
|||
// NoContext is the default context to use in most cases.
|
|||
var NoContext = context.TODO() |
|||
|
|||
// contextTransport gets the Transport from the context client or nil.
|
|||
func contextTransport(ctx context.Context) http.RoundTripper { |
|||
if client, ok := ctx.Value(HTTPClient).(*http.Client); ok { |
|||
return client.Transport |
|||
} |
|||
return nil |
|||
} |
@ -0,0 +1,21 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"net/http" |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
"golang.org/x/net/context" |
|||
) |
|||
|
|||
func TestContextTransport(t *testing.T) { |
|||
client := &http.Client{ |
|||
Transport: http.DefaultTransport, |
|||
} |
|||
ctx := context.WithValue(NoContext, HTTPClient, client) |
|||
assert.Equal(t, http.DefaultTransport, contextTransport(ctx)) |
|||
} |
|||
|
|||
func TestContextTransport_NoContextClient(t *testing.T) { |
|||
assert.Nil(t, contextTransport(NoContext)) |
|||
} |
@ -0,0 +1,97 @@ |
|||
/* |
|||
Package oauth1 is a Go implementation of the OAuth1 spec RFC 5849. |
|||
|
|||
It allows end-users to authorize a client (consumer) to access protected |
|||
resources on their behalf (e.g. login) and allows clients to make signed and |
|||
authorized requests on behalf of a user (e.g. API calls). |
|||
|
|||
It takes design cues from golang.org/x/oauth2, providing an http.Client which |
|||
handles request signing and authorization. |
|||
|
|||
Usage |
|||
|
|||
Package oauth1 implements the OAuth1 authorization flow and provides an |
|||
http.Client which can sign and authorize OAuth1 requests. |
|||
|
|||
To implement "Login with X", use the https://github.com/dghubble/gologin
|
|||
packages which provide login handlers for OAuth1 and OAuth2 providers. |
|||
|
|||
To call the Twitter, Digits, or Tumblr OAuth1 APIs, use the higher level Go API |
|||
clients. |
|||
|
|||
* https://github.com/dghubble/go-twitter
|
|||
* https://github.com/dghubble/go-digits
|
|||
* https://github.com/benfb/go-tumblr
|
|||
|
|||
Authorization Flow |
|||
|
|||
Perform the OAuth 1 authorization flow to ask a user to grant an application |
|||
access to his/her resources via an access token. |
|||
|
|||
import ( |
|||
"github.com/dghubble/oauth1" |
|||
"github.com/dghubble/oauth1/twitter"" |
|||
) |
|||
... |
|||
|
|||
config := oauth1.Config{ |
|||
ConsumerKey: "consumerKey", |
|||
ConsumerSecret: "consumerSecret", |
|||
CallbackURL: "http://mysite.com/oauth/twitter/callback", |
|||
Endpoint: twitter.AuthorizeEndpoint, |
|||
} |
|||
|
|||
1. When a user performs an action (e.g. "Login with X" button calls "/login" |
|||
route) get an OAuth1 request token (temporary credentials). |
|||
|
|||
requestToken, requestSecret, err = config.RequestToken() |
|||
// handle err
|
|||
|
|||
2. Obtain authorization from the user by redirecting them to the OAuth1 |
|||
provider's authorization URL to grant the application access. |
|||
|
|||
authorizationURL, err := config.AuthorizationURL(requestToken) |
|||
// handle err
|
|||
http.Redirect(w, req, authorizationURL.String(), htt.StatusFound) |
|||
|
|||
Receive the callback from the OAuth1 provider in a handler. |
|||
|
|||
requestToken, verifier, err := oauth1.ParseAuthorizationCallback(req) |
|||
// handle err
|
|||
|
|||
3. Acquire the access token (token credentials) which can later be used |
|||
to make requests on behalf of the user. |
|||
|
|||
accessToken, accessSecret, err := config.AccessToken(requestToken, requestSecret, verifier) |
|||
// handle error
|
|||
token := NewToken(accessToken, accessSecret) |
|||
|
|||
Check the examples to see this authorization flow in action from the command |
|||
line, with Twitter PIN-based login and Tumblr login. |
|||
|
|||
Authorized Requests |
|||
|
|||
Use an access Token to make authorized requests on behalf of a user. |
|||
|
|||
import ( |
|||
"github.com/dghubble/oauth1" |
|||
) |
|||
|
|||
func main() { |
|||
config := oauth1.NewConfig("consumerKey", "consumerSecret") |
|||
token := oauth1.NewToken("token", "tokenSecret") |
|||
|
|||
// httpClient will automatically authorize http.Request's
|
|||
httpClient := config.Client(token) |
|||
|
|||
// example Twitter API request
|
|||
path := "https://api.twitter.com/1.1/statuses/home_timeline.json?count=2" |
|||
resp, _ := httpClient.Get(path) |
|||
defer resp.Body.Close() |
|||
body, _ := ioutil.ReadAll(resp.Body) |
|||
fmt.Printf("Raw Response Body:\n%v\n", string(body)) |
|||
} |
|||
|
|||
Check the examples to see Twitter and Tumblr requests in action. |
|||
*/ |
|||
package oauth1 |
@ -0,0 +1,13 @@ |
|||
// Package dropbox provides constants for using OAuth1 to access Dropbox.
|
|||
package dropbox |
|||
|
|||
import ( |
|||
"github.com/dghubble/oauth1" |
|||
) |
|||
|
|||
// Endpoint is Dropbox's OAuth 1 endpoint.
|
|||
var Endpoint = oauth1.Endpoint{ |
|||
RequestTokenURL: "https://api.dropbox.com/1/oauth/request_token", |
|||
AuthorizeURL: "https://api.dropbox.com/1/oauth/authorize", |
|||
AccessTokenURL: "https://api.dropbox.com/1/oauth/access_token", |
|||
} |
@ -0,0 +1,36 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
) |
|||
|
|||
// PercentEncode percent encodes a string according to RFC 3986 2.1.
|
|||
func PercentEncode(input string) string { |
|||
var buf bytes.Buffer |
|||
for _, b := range []byte(input) { |
|||
// if in unreserved set
|
|||
if shouldEscape(b) { |
|||
buf.Write([]byte(fmt.Sprintf("%%%02X", b))) |
|||
} else { |
|||
// do not escape, write byte as-is
|
|||
buf.WriteByte(b) |
|||
} |
|||
} |
|||
return buf.String() |
|||
} |
|||
|
|||
// shouldEscape returns false if the byte is an unreserved character that
|
|||
// should not be escaped and true otherwise, according to RFC 3986 2.1.
|
|||
func shouldEscape(c byte) bool { |
|||
// RFC3986 2.3 unreserved characters
|
|||
if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' { |
|||
return false |
|||
} |
|||
switch c { |
|||
case '-', '.', '_', '~': |
|||
return false |
|||
} |
|||
// all other bytes must be escaped
|
|||
return true |
|||
} |
@ -0,0 +1,27 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"testing" |
|||
) |
|||
|
|||
func TestPercentEncode(t *testing.T) { |
|||
cases := []struct { |
|||
input string |
|||
expected string |
|||
}{ |
|||
{" ", "%20"}, |
|||
{"%", "%25"}, |
|||
{"&", "%26"}, |
|||
{"-._", "-._"}, |
|||
{" /=+", "%20%2F%3D%2B"}, |
|||
{"Ladies + Gentlemen", "Ladies%20%2B%20Gentlemen"}, |
|||
{"An encoded string!", "An%20encoded%20string%21"}, |
|||
{"Dogs, Cats & Mice", "Dogs%2C%20Cats%20%26%20Mice"}, |
|||
{"☃", "%E2%98%83"}, |
|||
} |
|||
for _, c := range cases { |
|||
if output := PercentEncode(c.input); output != c.expected { |
|||
t.Errorf("expected %s, got %s", c.expected, output) |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,12 @@ |
|||
package oauth1 |
|||
|
|||
// Endpoint represents an OAuth1 provider's (server's) request token,
|
|||
// owner authorization, and access token request URLs.
|
|||
type Endpoint struct { |
|||
// Request URL (Temporary Credential Request URI)
|
|||
RequestTokenURL string |
|||
// Authorize URL (Resource Owner Authorization URI)
|
|||
AuthorizeURL string |
|||
// Access Token URL (Token Request URI)
|
|||
AccessTokenURL string |
|||
} |
@ -0,0 +1,48 @@ |
|||
|
|||
# OAuth1 Examples |
|||
|
|||
## Twitter |
|||
|
|||
### Authorization Flow (PIN-based) |
|||
|
|||
An application can obtain a Twitter access `Token` for a user by requesting the user grant access via [3-legged](https://dev.twitter.com/oauth/3-legged) or [PIN-based](https://dev.twitter.com/oauth/pin-based) OAuth 1. Here is a command line example showing PIN-based authorization. |
|||
|
|||
export TWITTER_CONSUMER_KEY=xxx |
|||
export TWITTER_CONSUMER_SECRET=xxx |
|||
go run twitter-login.go |
|||
|
|||
The OAuth 1 flow can be used to implement Login with Twitter. Upon receiving an access token in a callback handler on your server, issue a user some form of unforgeable session identifier (i.e. cookie, token). Note that web backends should use a real `CallbackURL`, "oob" is for PIN-based agents such as the command line. |
|||
|
|||
### Authorized Requests |
|||
|
|||
Use the access `Token` to make requests on behalf of a Twitter user. |
|||
|
|||
export TWITTER_CONSUMER_KEY=xxx |
|||
export TWITTER_CONSUMER_SECRET=xxx |
|||
export TWITTER_ACCESS_TOKEN=xxx |
|||
export TWITTER_ACCESS_SECRET=xxx |
|||
go run twitter-request.go |
|||
|
|||
|
|||
## Tumblr |
|||
|
|||
### Authorization Flow |
|||
|
|||
An application can obtain a Tumblr access `Token` to act on behalf of a user. Here is a command line example which requests permission. |
|||
|
|||
export TUMBLR_CONSUMER_KEY=xxx |
|||
export TUMBLR_CONSUMER_SECRET=xxx |
|||
go run tumblr-login.go |
|||
|
|||
### Authorized Requests |
|||
|
|||
Use the access `Token` to make requests on behalf of a Tumblr user. |
|||
|
|||
export TUMBLR_CONSUMER_KEY=xxx |
|||
export TUMBLR_CONSUMER_SECRET=xxx |
|||
export TUMBLR_ACCESS_TOKEN=xxx |
|||
export TUMBLR_ACCESS_SECRET=xxx |
|||
go run tumblr-request.go |
|||
|
|||
Note that only some Tumblr endpoints require OAuth1 signed requests, other endpoints require a special consumer key query parameter or no authorization. |
|||
|
@ -0,0 +1,68 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"fmt" |
|||
"log" |
|||
"os" |
|||
|
|||
"github.com/dghubble/oauth1" |
|||
"github.com/dghubble/oauth1/tumblr" |
|||
) |
|||
|
|||
var config oauth1.Config |
|||
|
|||
// main performs the Tumblr OAuth1 user flow from the command line
|
|||
func main() { |
|||
// read credentials from environment variables
|
|||
consumerKey := os.Getenv("TUMBLR_CONSUMER_KEY") |
|||
consumerSecret := os.Getenv("TUMBLR_CONSUMER_SECRET") |
|||
if consumerKey == "" || consumerSecret == "" { |
|||
log.Fatal("Required environment variable missing.") |
|||
} |
|||
|
|||
config = oauth1.Config{ |
|||
ConsumerKey: consumerKey, |
|||
ConsumerSecret: consumerSecret, |
|||
// Tumblr does not support oob, uses consumer registered callback
|
|||
CallbackURL: "", |
|||
Endpoint: tumblr.Endpoint, |
|||
} |
|||
|
|||
requestToken, requestSecret, err := login() |
|||
if err != nil { |
|||
log.Fatalf("Request Token Phase: %s", err.Error()) |
|||
} |
|||
accessToken, err := receivePIN(requestToken, requestSecret) |
|||
if err != nil { |
|||
log.Fatalf("Access Token Phase: %s", err.Error()) |
|||
} |
|||
|
|||
fmt.Println("Consumer was granted an access token to act on behalf of a user.") |
|||
fmt.Printf("token: %s\nsecret: %s\n", accessToken.Token, accessToken.TokenSecret) |
|||
} |
|||
|
|||
func login() (requestToken, requestSecret string, err error) { |
|||
requestToken, requestSecret, err = config.RequestToken() |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
authorizationURL, err := config.AuthorizationURL(requestToken) |
|||
if err != nil { |
|||
return "", "", err |
|||
} |
|||
fmt.Printf("Open this URL in your browser:\n%s\n", authorizationURL.String()) |
|||
return requestToken, requestSecret, err |
|||
} |
|||
|
|||
func receivePIN(requestToken, requestSecret string) (*oauth1.Token, error) { |
|||
fmt.Printf("Choose whether to grant the application access.\nPaste " + |
|||
"the oauth_verifier parameter (excluding trailing #_=_) from the " + |
|||
"address bar: ") |
|||
var verifier string |
|||
_, err := fmt.Scanf("%s", &verifier) |
|||
accessToken, accessSecret, err := config.AccessToken(requestToken, requestSecret, verifier) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return oauth1.NewToken(accessToken, accessSecret), err |
|||
} |
@ -0,0 +1,37 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"fmt" |
|||
"io/ioutil" |
|||
"os" |
|||
|
|||
"github.com/dghubble/oauth1" |
|||
) |
|||
|
|||
// Tumblr access token (token credential) requests on behalf of a user
|
|||
func main() { |
|||
// read credentials from environment variables
|
|||
consumerKey := os.Getenv("TUMBLR_CONSUMER_KEY") |
|||
consumerSecret := os.Getenv("TUMBLR_CONSUMER_SECRET") |
|||
accessToken := os.Getenv("TUMBLR_ACCESS_TOKEN") |
|||
accessSecret := os.Getenv("TUMBLR_ACCESS_SECRET") |
|||
if consumerKey == "" || consumerSecret == "" || accessToken == "" || accessSecret == "" { |
|||
panic("Missing required environment variable") |
|||
} |
|||
|
|||
config := oauth1.NewConfig(consumerKey, consumerSecret) |
|||
token := oauth1.NewToken(accessToken, accessSecret) |
|||
|
|||
// httpClient will automatically authorize http.Request's
|
|||
httpClient := config.Client(oauth1.NoContext, token) |
|||
|
|||
// get information about the current authenticated user
|
|||
path := "https://api.tumblr.com/v2/user/info" |
|||
resp, _ := httpClient.Get(path) |
|||
defer resp.Body.Close() |
|||
body, _ := ioutil.ReadAll(resp.Body) |
|||
fmt.Printf("Raw Response Body:\n%v\n", string(body)) |
|||
|
|||
// note: Tumblr requires OAuth signed requests for particular endpoints,
|
|||
// others just need a consumer key query parameter (its janky).
|
|||
} |
@ -0,0 +1,75 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"fmt" |
|||
"log" |
|||
"os" |
|||
|
|||
"github.com/dghubble/oauth1" |
|||
twauth "github.com/dghubble/oauth1/twitter" |
|||
) |
|||
|
|||
const outOfBand = "oob" |
|||
|
|||
var config oauth1.Config |
|||
|
|||
// main performs Twitter PIN-based 3-legged OAuth 1 from the command line
|
|||
func main() { |
|||
// read credentials from environment variables
|
|||
consumerKey := os.Getenv("TWITTER_CONSUMER_KEY") |
|||
consumerSecret := os.Getenv("TWITTER_CONSUMER_SECRET") |
|||
if consumerKey == "" || consumerSecret == "" { |
|||
log.Fatal("Required environment variable missing.") |
|||
} |
|||
|
|||
config = oauth1.Config{ |
|||
ConsumerKey: consumerKey, |
|||
ConsumerSecret: consumerSecret, |
|||
CallbackURL: outOfBand, |
|||
Endpoint: twauth.AuthorizeEndpoint, |
|||
} |
|||
|
|||
requestToken, err := login() |
|||
if err != nil { |
|||
log.Fatalf("Request Token Phase: %s", err.Error()) |
|||
} |
|||
accessToken, err := receivePIN(requestToken) |
|||
if err != nil { |
|||
log.Fatalf("Access Token Phase: %s", err.Error()) |
|||
} |
|||
|
|||
fmt.Println("Consumer was granted an access token to act on behalf of a user.") |
|||
fmt.Printf("token: %s\nsecret: %s\n", accessToken.Token, accessToken.TokenSecret) |
|||
} |
|||
|
|||
func login() (requestToken string, err error) { |
|||
requestToken, _, err = config.RequestToken() |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
authorizationURL, err := config.AuthorizationURL(requestToken) |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
fmt.Printf("Open this URL in your browser:\n%s\n", authorizationURL.String()) |
|||
return requestToken, err |
|||
} |
|||
|
|||
func receivePIN(requestToken string) (*oauth1.Token, error) { |
|||
fmt.Printf("Paste your PIN here: ") |
|||
var verifier string |
|||
_, err := fmt.Scanf("%s", &verifier) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
// Twitter ignores the oauth_signature on the access token request. The user
|
|||
// to which the request (temporary) token corresponds is already known on the
|
|||
// server. The request for a request token earlier was validated signed by
|
|||
// the consumer. Consumer applications can avoid keeping request token state
|
|||
// between authorization granting and callback handling.
|
|||
accessToken, accessSecret, err := config.AccessToken(requestToken, "secret does not matter", verifier) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return oauth1.NewToken(accessToken, accessSecret), err |
|||
} |
@ -0,0 +1,39 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"fmt" |
|||
"io/ioutil" |
|||
"os" |
|||
|
|||
"github.com/dghubble/go-twitter/twitter" |
|||
"github.com/dghubble/oauth1" |
|||
) |
|||
|
|||
// Twitter user-auth requests with an Access Token (token credential)
|
|||
func main() { |
|||
// read credentials from environment variables
|
|||
consumerKey := os.Getenv("TWITTER_CONSUMER_KEY") |
|||
consumerSecret := os.Getenv("TWITTER_CONSUMER_SECRET") |
|||
accessToken := os.Getenv("TWITTER_ACCESS_TOKEN") |
|||
accessSecret := os.Getenv("TWITTER_ACCESS_SECRET") |
|||
if consumerKey == "" || consumerSecret == "" || accessToken == "" || accessSecret == "" { |
|||
panic("Missing required environment variable") |
|||
} |
|||
|
|||
config := oauth1.NewConfig(consumerKey, consumerSecret) |
|||
token := oauth1.NewToken(accessToken, accessSecret) |
|||
|
|||
// httpClient will automatically authorize http.Request's
|
|||
httpClient := config.Client(oauth1.NoContext, token) |
|||
|
|||
path := "https://api.twitter.com/1.1/statuses/home_timeline.json?count=2" |
|||
resp, _ := httpClient.Get(path) |
|||
defer resp.Body.Close() |
|||
body, _ := ioutil.ReadAll(resp.Body) |
|||
fmt.Printf("Raw Response Body:\n%v\n", string(body)) |
|||
|
|||
// Nicer: Pass OAuth1 client to go-twitter API
|
|||
api := twitter.NewClient(httpClient) |
|||
tweets, _, _ := api.Timelines.HomeTimeline(nil) |
|||
fmt.Printf("User's HOME TIMELINE:\n%+v\n", tweets) |
|||
} |
@ -0,0 +1,202 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
"net/url" |
|||
"strings" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
const ( |
|||
expectedVersion = "1.0" |
|||
expectedSignatureMethod = "HMAC-SHA1" |
|||
) |
|||
|
|||
func TestTwitterRequestTokenAuthHeader(t *testing.T) { |
|||
// example from https://dev.twitter.com/web/sign-in/implementing
|
|||
var unixTimestamp int64 = 1318467427 |
|||
expectedConsumerKey := "cChZNFj6T5R0TigYB9yd1w" |
|||
expectedCallback := "http%3A%2F%2Flocalhost%2Fsign-in-with-twitter%2F" |
|||
expectedSignature := "F1Li3tvehgcraF8DMJ7OyxO4w9Y%3D" |
|||
expectedTimestamp := "1318467427" |
|||
expectedNonce := "ea9ec8429b68d6b77cd5600adbbb0456" |
|||
config := &Config{ |
|||
ConsumerKey: expectedConsumerKey, |
|||
ConsumerSecret: "L8qq9PZyRg6ieKGEKhZolGC0vJWLw8iEJ88DRdyOg", |
|||
CallbackURL: "http://localhost/sign-in-with-twitter/", |
|||
Endpoint: Endpoint{ |
|||
RequestTokenURL: "https://api.twitter.com/oauth/request_token", |
|||
AuthorizeURL: "https://api.twitter.com/oauth/authorize", |
|||
AccessTokenURL: "https://api.twitter.com/oauth/access_token", |
|||
}, |
|||
} |
|||
|
|||
auther := &auther{config, &fixedClock{time.Unix(unixTimestamp, 0)}, &fixedNoncer{expectedNonce}} |
|||
req, err := http.NewRequest("POST", config.Endpoint.RequestTokenURL, nil) |
|||
assert.Nil(t, err) |
|||
err = auther.setRequestTokenAuthHeader(req) |
|||
// assert the request for a request token is signed and has an oauth_callback
|
|||
assert.Nil(t, err) |
|||
params := parseOAuthParamsOrFail(t, req.Header.Get(authorizationHeaderParam)) |
|||
assert.Equal(t, expectedCallback, params[oauthCallbackParam]) |
|||
assert.Equal(t, expectedSignature, params[oauthSignatureParam]) |
|||
// additional OAuth parameters
|
|||
assert.Equal(t, expectedConsumerKey, params[oauthConsumerKeyParam]) |
|||
assert.Equal(t, expectedNonce, params[oauthNonceParam]) |
|||
assert.Equal(t, expectedTimestamp, params[oauthTimestampParam]) |
|||
assert.Equal(t, expectedVersion, params[oauthVersionParam]) |
|||
assert.Equal(t, expectedSignatureMethod, params[oauthSignatureMethodParam]) |
|||
} |
|||
|
|||
func TestTwitterAccessTokenAuthHeader(t *testing.T) { |
|||
// example from https://dev.twitter.com/web/sign-in/implementing
|
|||
var unixTimestamp int64 = 1318467427 |
|||
expectedConsumerKey := "cChZNFj6T5R0TigYB9yd1w" |
|||
expectedRequestToken := "NPcudxy0yU5T3tBzho7iCotZ3cnetKwcTIRlX0iwRl0" |
|||
requestTokenSecret := "veNRnAWe6inFuo8o2u8SLLZLjolYDmDP7SzL0YfYI" |
|||
expectedVerifier := "uw7NjWHT6OJ1MpJOXsHfNxoAhPKpgI8BlYDhxEjIBY" |
|||
expectedSignature := "39cipBtIOHEEnybAR4sATQTpl2I%3D" |
|||
expectedTimestamp := "1318467427" |
|||
expectedNonce := "a9900fe68e2573b27a37f10fbad6a755" |
|||
config := &Config{ |
|||
ConsumerKey: expectedConsumerKey, |
|||
ConsumerSecret: "L8qq9PZyRg6ieKGEKhZolGC0vJWLw8iEJ88DRdyOg", |
|||
Endpoint: Endpoint{ |
|||
RequestTokenURL: "https://api.twitter.com/oauth/request_token", |
|||
AuthorizeURL: "https://api.twitter.com/oauth/authorize", |
|||
AccessTokenURL: "https://api.twitter.com/oauth/access_token", |
|||
}, |
|||
} |
|||
|
|||
auther := &auther{config, &fixedClock{time.Unix(unixTimestamp, 0)}, &fixedNoncer{expectedNonce}} |
|||
req, err := http.NewRequest("POST", config.Endpoint.AccessTokenURL, nil) |
|||
assert.Nil(t, err) |
|||
err = auther.setAccessTokenAuthHeader(req, expectedRequestToken, requestTokenSecret, expectedVerifier) |
|||
// assert the request for an access token is signed and has an oauth_token and verifier
|
|||
assert.Nil(t, err) |
|||
params := parseOAuthParamsOrFail(t, req.Header.Get(authorizationHeaderParam)) |
|||
assert.Equal(t, expectedRequestToken, params[oauthTokenParam]) |
|||
assert.Equal(t, expectedVerifier, params[oauthVerifierParam]) |
|||
assert.Equal(t, expectedSignature, params[oauthSignatureParam]) |
|||
// additional OAuth parameters
|
|||
assert.Equal(t, expectedConsumerKey, params[oauthConsumerKeyParam]) |
|||
assert.Equal(t, expectedNonce, params[oauthNonceParam]) |
|||
assert.Equal(t, expectedTimestamp, params[oauthTimestampParam]) |
|||
assert.Equal(t, expectedVersion, params[oauthVersionParam]) |
|||
assert.Equal(t, expectedSignatureMethod, params[oauthSignatureMethodParam]) |
|||
} |
|||
|
|||
// example from https://dev.twitter.com/oauth/overview/authorizing-requests,
|
|||
// https://dev.twitter.com/oauth/overview/creating-signatures, and
|
|||
// https://dev.twitter.com/oauth/application-only
|
|||
var unixTimestampOfRequest int64 = 1318622958 |
|||
var expectedTwitterConsumerKey = "xvz1evFS4wEEPTGEFPHBog" |
|||
var expectedTwitterOAuthToken = "370773112-GmHxMAgYyLbNEtIKZeRNFsMKPR9EyMZeS9weJAEb" |
|||
var expectedNonce = "kYjzVBB8Y0ZFabxSWbWovY3uYSQ2pTgmZeNu2VS4cg" |
|||
var twitterConfig = &Config{ |
|||
ConsumerKey: expectedTwitterConsumerKey, |
|||
ConsumerSecret: "kAcSOqF21Fu85e7zjz7ZN2U4ZRhfV3WpwPAoE3Z7kBw", |
|||
Endpoint: Endpoint{ |
|||
RequestTokenURL: "https://api.twitter.com/oauth/request_token", |
|||
AuthorizeURL: "https://api.twitter.com/oauth/authorize", |
|||
AccessTokenURL: "https://api.twitter.com/oauth/access_token", |
|||
}, |
|||
} |
|||
|
|||
func TestTwitterParameterString(t *testing.T) { |
|||
auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}, &fixedNoncer{expectedNonce}} |
|||
values := url.Values{} |
|||
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!") |
|||
// note: the reference example is old and uses api v1 in the URL
|
|||
req, err := http.NewRequest("post", "https://api.twitter.com/1/statuses/update.json?include_entities=true", strings.NewReader(values.Encode())) |
|||
assert.Nil(t, err) |
|||
req.Header.Set(contentType, formContentType) |
|||
oauthParams := auther.commonOAuthParams() |
|||
oauthParams[oauthTokenParam] = expectedTwitterOAuthToken |
|||
params, err := collectParameters(req, oauthParams) |
|||
// assert that the parameter string matches the reference
|
|||
expectedParameterString := "include_entities=true&oauth_consumer_key=xvz1evFS4wEEPTGEFPHBog&oauth_nonce=kYjzVBB8Y0ZFabxSWbWovY3uYSQ2pTgmZeNu2VS4cg&oauth_signature_method=HMAC-SHA1&oauth_timestamp=1318622958&oauth_token=370773112-GmHxMAgYyLbNEtIKZeRNFsMKPR9EyMZeS9weJAEb&oauth_version=1.0&status=Hello%20Ladies%20%2B%20Gentlemen%2C%20a%20signed%20OAuth%20request%21" |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expectedParameterString, normalizedParameterString(params)) |
|||
} |
|||
|
|||
func TestTwitterSignatureBase(t *testing.T) { |
|||
auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}, &fixedNoncer{expectedNonce}} |
|||
values := url.Values{} |
|||
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!") |
|||
// note: the reference example is old and uses api v1 in the URL
|
|||
req, err := http.NewRequest("post", "https://api.twitter.com/1/statuses/update.json?include_entities=true", strings.NewReader(values.Encode())) |
|||
assert.Nil(t, err) |
|||
req.Header.Set(contentType, formContentType) |
|||
oauthParams := auther.commonOAuthParams() |
|||
oauthParams[oauthTokenParam] = expectedTwitterOAuthToken |
|||
params, err := collectParameters(req, oauthParams) |
|||
signatureBase := signatureBase(req, params) |
|||
// assert that the signature base string matches the reference
|
|||
// checks that method is uppercased, url is encoded, parameter string is added, all joined by &
|
|||
expectedSignatureBase := "POST&https%3A%2F%2Fapi.twitter.com%2F1%2Fstatuses%2Fupdate.json&include_entities%3Dtrue%26oauth_consumer_key%3Dxvz1evFS4wEEPTGEFPHBog%26oauth_nonce%3DkYjzVBB8Y0ZFabxSWbWovY3uYSQ2pTgmZeNu2VS4cg%26oauth_signature_method%3DHMAC-SHA1%26oauth_timestamp%3D1318622958%26oauth_token%3D370773112-GmHxMAgYyLbNEtIKZeRNFsMKPR9EyMZeS9weJAEb%26oauth_version%3D1.0%26status%3DHello%2520Ladies%2520%252B%2520Gentlemen%252C%2520a%2520signed%2520OAuth%2520request%2521" |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, expectedSignatureBase, signatureBase) |
|||
} |
|||
|
|||
func TestTwitterRequestAuthHeader(t *testing.T) { |
|||
oauthTokenSecret := "LswwdoUaIvS8ltyTt5jkRh4J50vUPVVHtR2YPi5kE" |
|||
expectedSignature := PercentEncode("tnnArxj06cWHq44gCs1OSKk/jLY=") |
|||
expectedTimestamp := "1318622958" |
|||
|
|||
auther := &auther{twitterConfig, &fixedClock{time.Unix(unixTimestampOfRequest, 0)}, &fixedNoncer{expectedNonce}} |
|||
values := url.Values{} |
|||
values.Add("status", "Hello Ladies + Gentlemen, a signed OAuth request!") |
|||
|
|||
accessToken := &Token{expectedTwitterOAuthToken, oauthTokenSecret} |
|||
req, err := http.NewRequest("POST", "https://api.twitter.com/1/statuses/update.json?include_entities=true", strings.NewReader(values.Encode())) |
|||
assert.Nil(t, err) |
|||
req.Header.Set(contentType, formContentType) |
|||
err = auther.setRequestAuthHeader(req, accessToken) |
|||
// assert that request is signed and has an access token token
|
|||
assert.Nil(t, err) |
|||
params := parseOAuthParamsOrFail(t, req.Header.Get(authorizationHeaderParam)) |
|||
assert.Equal(t, expectedTwitterOAuthToken, params[oauthTokenParam]) |
|||
assert.Equal(t, expectedSignature, params[oauthSignatureParam]) |
|||
// additional OAuth parameters
|
|||
assert.Equal(t, expectedTwitterConsumerKey, params[oauthConsumerKeyParam]) |
|||
assert.Equal(t, expectedNonce, params[oauthNonceParam]) |
|||
assert.Equal(t, expectedSignatureMethod, params[oauthSignatureMethodParam]) |
|||
assert.Equal(t, expectedTimestamp, params[oauthTimestampParam]) |
|||
assert.Equal(t, expectedVersion, params[oauthVersionParam]) |
|||
} |
|||
|
|||
func parseOAuthParamsOrFail(t *testing.T, authHeader string) map[string]string { |
|||
if !strings.HasPrefix(authHeader, authorizationPrefix) { |
|||
assert.Fail(t, fmt.Sprintf("Expected Authorization header to start with \"%s\", got \"%s\"", authorizationPrefix, authHeader[:len(authorizationPrefix)+1])) |
|||
} |
|||
params := map[string]string{} |
|||
for _, pairStr := range strings.Split(authHeader[len(authorizationPrefix):], ", ") { |
|||
pair := strings.Split(pairStr, "=") |
|||
if len(pair) != 2 { |
|||
assert.Fail(t, "Error parsing OAuth parameter %s", pairStr) |
|||
} |
|||
params[pair[0]] = strings.Replace(pair[1], "\"", "", -1) |
|||
} |
|||
return params |
|||
} |
|||
|
|||
type fixedClock struct { |
|||
now time.Time |
|||
} |
|||
|
|||
func (c *fixedClock) Now() time.Time { |
|||
return c.now |
|||
} |
|||
|
|||
type fixedNoncer struct { |
|||
nonce string |
|||
} |
|||
|
|||
func (n *fixedNoncer) Nonce() string { |
|||
return n.nonce |
|||
} |
@ -0,0 +1,62 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"crypto" |
|||
"crypto/hmac" |
|||
"crypto/rand" |
|||
"crypto/rsa" |
|||
"crypto/sha1" |
|||
"encoding/base64" |
|||
"strings" |
|||
) |
|||
|
|||
// A Signer signs messages to create signed OAuth1 Requests.
|
|||
type Signer interface { |
|||
// Name returns the name of the signing method.
|
|||
Name() string |
|||
// Sign signs the message using the given secret key.
|
|||
Sign(key string, message string) (string, error) |
|||
} |
|||
|
|||
// HMACSigner signs messages with an HMAC SHA1 digest, using the concatenated
|
|||
// consumer secret and token secret as the key.
|
|||
type HMACSigner struct { |
|||
ConsumerSecret string |
|||
} |
|||
|
|||
// Name returns the HMAC-SHA1 method.
|
|||
func (s *HMACSigner) Name() string { |
|||
return "HMAC-SHA1" |
|||
} |
|||
|
|||
// Sign creates a concatenated consumer and token secret key and calculates
|
|||
// the HMAC digest of the message. Returns the base64 encoded digest bytes.
|
|||
func (s *HMACSigner) Sign(tokenSecret, message string) (string, error) { |
|||
signingKey := strings.Join([]string{s.ConsumerSecret, tokenSecret}, "&") |
|||
mac := hmac.New(sha1.New, []byte(signingKey)) |
|||
mac.Write([]byte(message)) |
|||
signatureBytes := mac.Sum(nil) |
|||
return base64.StdEncoding.EncodeToString(signatureBytes), nil |
|||
} |
|||
|
|||
// RSASigner RSA PKCS1-v1_5 signs SHA1 digests of messages using the given
|
|||
// RSA private key.
|
|||
type RSASigner struct { |
|||
PrivateKey *rsa.PrivateKey |
|||
} |
|||
|
|||
// Name returns the RSA-SHA1 method.
|
|||
func (s *RSASigner) Name() string { |
|||
return "RSA-SHA1" |
|||
} |
|||
|
|||
// Sign uses RSA PKCS1-v1_5 to sign a SHA1 digest of the given message. The
|
|||
// tokenSecret is not used with this signing scheme.
|
|||
func (s *RSASigner) Sign(tokenSecret, message string) (string, error) { |
|||
digest := sha1.Sum([]byte(message)) |
|||
signature, err := rsa.SignPKCS1v15(rand.Reader, s.PrivateKey, crypto.SHA1, digest[:]) |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
return base64.StdEncoding.EncodeToString(signature), nil |
|||
} |
@ -0,0 +1,19 @@ |
|||
#!/bin/bash -e |
|||
|
|||
go test . -cover |
|||
go vet ./... |
|||
|
|||
echo "Checking gofmt..." |
|||
FORMATTABLE="$(find . -type f -name '*.go')" |
|||
fmtRes=$(gofmt -l $FORMATTABLE) |
|||
if [ -n "${fmtRes}" ]; then |
|||
echo -e "gofmt checking failed:\n${fmtRes}" |
|||
exit 2 |
|||
fi |
|||
|
|||
echo "Checking golint..." |
|||
lintRes=$(go list ./... | xargs -n 1 golint) |
|||
if [ -n "${lintRes}" ]; then |
|||
echo -e "golint checking failed:\n${lintRes}" |
|||
exit 2 |
|||
fi |
@ -0,0 +1,43 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"errors" |
|||
) |
|||
|
|||
// A TokenSource can return a Token.
|
|||
type TokenSource interface { |
|||
Token() (*Token, error) |
|||
} |
|||
|
|||
// Token is an AccessToken (token credential) which allows a consumer (client)
|
|||
// to access resources from an OAuth1 provider server.
|
|||
type Token struct { |
|||
Token string |
|||
TokenSecret string |
|||
} |
|||
|
|||
// NewToken returns a new Token with the given token and token secret.
|
|||
func NewToken(token, tokenSecret string) *Token { |
|||
return &Token{ |
|||
Token: token, |
|||
TokenSecret: tokenSecret, |
|||
} |
|||
} |
|||
|
|||
// StaticTokenSource returns a TokenSource which always returns the same Token.
|
|||
// This is appropriate for tokens which do not have a time expiration.
|
|||
func StaticTokenSource(token *Token) TokenSource { |
|||
return staticTokenSource{token} |
|||
} |
|||
|
|||
// staticTokenSource is a TokenSource that always returns the same Token.
|
|||
type staticTokenSource struct { |
|||
token *Token |
|||
} |
|||
|
|||
func (s staticTokenSource) Token() (*Token, error) { |
|||
if s.token == nil { |
|||
return nil, errors.New("oauth1: Token is nil") |
|||
} |
|||
return s.token, nil |
|||
} |
@ -0,0 +1,31 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestNewToken(t *testing.T) { |
|||
expectedToken := "token" |
|||
expectedSecret := "secret" |
|||
tk := NewToken(expectedToken, expectedSecret) |
|||
assert.Equal(t, expectedToken, tk.Token) |
|||
assert.Equal(t, expectedSecret, tk.TokenSecret) |
|||
} |
|||
|
|||
func TestStaticTokenSource(t *testing.T) { |
|||
ts := StaticTokenSource(NewToken("t", "s")) |
|||
tk, err := ts.Token() |
|||
assert.Nil(t, err) |
|||
assert.Equal(t, "t", tk.Token) |
|||
} |
|||
|
|||
func TestStaticTokenSourceEmpty(t *testing.T) { |
|||
ts := StaticTokenSource(nil) |
|||
tk, err := ts.Token() |
|||
assert.Nil(t, tk) |
|||
if assert.Error(t, err) { |
|||
assert.Equal(t, "oauth1: Token is nil", err.Error()) |
|||
} |
|||
} |
@ -0,0 +1,65 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net/http" |
|||
) |
|||
|
|||
// Transport is an http.RoundTripper which makes OAuth1 HTTP requests. It
|
|||
// wraps a base RoundTripper and adds an Authorization header using the
|
|||
// token from a TokenSource.
|
|||
//
|
|||
// Transport is a low-level component, most users should use Config to create
|
|||
// an http.Client instead.
|
|||
type Transport struct { |
|||
// Base is the base RoundTripper used to make HTTP requests. If nil, then
|
|||
// http.DefaultTransport is used
|
|||
Base http.RoundTripper |
|||
// source supplies the token to use when signing a request
|
|||
source TokenSource |
|||
// auther adds OAuth1 Authorization headers to requests
|
|||
auther *auther |
|||
} |
|||
|
|||
// RoundTrip authorizes the request with a signed OAuth1 Authorization header
|
|||
// using the auther and TokenSource.
|
|||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { |
|||
if t.source == nil { |
|||
return nil, fmt.Errorf("oauth1: Transport's source is nil") |
|||
} |
|||
accessToken, err := t.source.Token() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if t.auther == nil { |
|||
return nil, fmt.Errorf("oauth1: Transport's auther is nil") |
|||
} |
|||
// RoundTripper should not modify the given request, clone it
|
|||
req2 := cloneRequest(req) |
|||
err = t.auther.setRequestAuthHeader(req2, accessToken) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return t.base().RoundTrip(req2) |
|||
} |
|||
|
|||
func (t *Transport) base() http.RoundTripper { |
|||
if t.Base != nil { |
|||
return t.Base |
|||
} |
|||
return http.DefaultTransport |
|||
} |
|||
|
|||
// cloneRequest returns a clone of the given *http.Request with a shallow
|
|||
// copy of struct fields and a deep copy of the Header map.
|
|||
func cloneRequest(req *http.Request) *http.Request { |
|||
// shallow copy the struct
|
|||
r2 := new(http.Request) |
|||
*r2 = *req |
|||
// deep copy Header so setting a header on the clone does not affect original
|
|||
r2.Header = make(http.Header, len(req.Header)) |
|||
for k, s := range req.Header { |
|||
r2.Header[k] = append([]string(nil), s...) |
|||
} |
|||
return r2 |
|||
} |
@ -0,0 +1,117 @@ |
|||
package oauth1 |
|||
|
|||
import ( |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestTransport(t *testing.T) { |
|||
const ( |
|||
expectedToken = "access_token" |
|||
expectedConsumerKey = "consumer_key" |
|||
expectedNonce = "some_nonce" |
|||
expectedSignatureMethod = "HMAC-SHA1" |
|||
expectedTimestamp = "123456789" |
|||
) |
|||
server := newMockServer(func(w http.ResponseWriter, req *http.Request) { |
|||
params := parseOAuthParamsOrFail(t, req.Header.Get("Authorization")) |
|||
assert.Equal(t, expectedToken, params[oauthTokenParam]) |
|||
assert.Equal(t, expectedConsumerKey, params[oauthConsumerKeyParam]) |
|||
assert.Equal(t, expectedNonce, params[oauthNonceParam]) |
|||
assert.Equal(t, expectedSignatureMethod, params[oauthSignatureMethodParam]) |
|||
assert.Equal(t, expectedTimestamp, params[oauthTimestampParam]) |
|||
assert.Equal(t, defaultOauthVersion, params[oauthVersionParam]) |
|||
// oauth_signature will vary, httptest.Server uses a random port
|
|||
}) |
|||
defer server.Close() |
|||
|
|||
config := &Config{ |
|||
ConsumerKey: expectedConsumerKey, |
|||
ConsumerSecret: "consumer_secret", |
|||
} |
|||
auther := &auther{ |
|||
config: config, |
|||
clock: &fixedClock{time.Unix(123456789, 0)}, |
|||
noncer: &fixedNoncer{expectedNonce}, |
|||
} |
|||
tr := &Transport{ |
|||
source: StaticTokenSource(NewToken(expectedToken, "some_secret")), |
|||
auther: auther, |
|||
} |
|||
client := &http.Client{Transport: tr} |
|||
|
|||
req, err := http.NewRequest("GET", server.URL, nil) |
|||
assert.Nil(t, err) |
|||
_, err = client.Do(req) |
|||
assert.Nil(t, err) |
|||
} |
|||
|
|||
func TestTransport_defaultBaseTransport(t *testing.T) { |
|||
tr := &Transport{ |
|||
Base: nil, |
|||
} |
|||
assert.Equal(t, http.DefaultTransport, tr.base()) |
|||
} |
|||
|
|||
func TestTransport_customBaseTransport(t *testing.T) { |
|||
expected := &http.Transport{} |
|||
tr := &Transport{ |
|||
Base: expected, |
|||
} |
|||
assert.Equal(t, expected, tr.base()) |
|||
} |
|||
|
|||
func TestTransport_nilSource(t *testing.T) { |
|||
tr := &Transport{ |
|||
source: nil, |
|||
auther: &auther{ |
|||
config: &Config{}, |
|||
clock: &fixedClock{time.Unix(123456789, 0)}, |
|||
noncer: &fixedNoncer{"any_nonce"}, |
|||
}, |
|||
} |
|||
client := &http.Client{Transport: tr} |
|||
resp, err := client.Get("http://example.com") |
|||
assert.Nil(t, resp) |
|||
if assert.Error(t, err) { |
|||
assert.Equal(t, "Get http://example.com: oauth1: Transport's source is nil", err.Error()) |
|||
} |
|||
} |
|||
|
|||
func TestTransport_emptySource(t *testing.T) { |
|||
tr := &Transport{ |
|||
source: StaticTokenSource(nil), |
|||
auther: &auther{ |
|||
config: &Config{}, |
|||
clock: &fixedClock{time.Unix(123456789, 0)}, |
|||
noncer: &fixedNoncer{"any_nonce"}, |
|||
}, |
|||
} |
|||
client := &http.Client{Transport: tr} |
|||
resp, err := client.Get("http://example.com") |
|||
assert.Nil(t, resp) |
|||
if assert.Error(t, err) { |
|||
assert.Equal(t, "Get http://example.com: oauth1: Token is nil", err.Error()) |
|||
} |
|||
} |
|||
|
|||
func TestTransport_nilAuther(t *testing.T) { |
|||
tr := &Transport{ |
|||
source: StaticTokenSource(&Token{}), |
|||
auther: nil, |
|||
} |
|||
client := &http.Client{Transport: tr} |
|||
resp, err := client.Get("http://example.com") |
|||
assert.Nil(t, resp) |
|||
if assert.Error(t, err) { |
|||
assert.Equal(t, "Get http://example.com: oauth1: Transport's auther is nil", err.Error()) |
|||
} |
|||
} |
|||
|
|||
func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server { |
|||
return httptest.NewServer(http.HandlerFunc(handler)) |
|||
} |
@ -0,0 +1,13 @@ |
|||
// Package tumblr provides constants for using OAuth 1 to access Tumblr.
|
|||
package tumblr |
|||
|
|||
import ( |
|||
"github.com/dghubble/oauth1" |
|||
) |
|||
|
|||
// Endpoint is Tumblr's OAuth 1a endpoint.
|
|||
var Endpoint = oauth1.Endpoint{ |
|||
RequestTokenURL: "http://www.tumblr.com/oauth/request_token", |
|||
AuthorizeURL: "http://www.tumblr.com/oauth/authorize", |
|||
AccessTokenURL: "http://www.tumblr.com/oauth/access_token", |
|||
} |
@ -0,0 +1,25 @@ |
|||
// Package twitter provides constants for using OAuth1 to access Twitter.
|
|||
package twitter |
|||
|
|||
import ( |
|||
"github.com/dghubble/oauth1" |
|||
) |
|||
|
|||
// AuthenticateEndpoint is Twitter's OAuth 1 endpoint which uses the
|
|||
// oauth/authenticate AuthorizeURL redirect. Logged in users who have granted
|
|||
// access are immediately authenticated and redirected to the callback URL.
|
|||
var AuthenticateEndpoint = oauth1.Endpoint{ |
|||
RequestTokenURL: "https://api.twitter.com/oauth/request_token", |
|||
AuthorizeURL: "https://api.twitter.com/oauth/authenticate", |
|||
AccessTokenURL: "https://api.twitter.com/oauth/access_token", |
|||
} |
|||
|
|||
// AuthorizeEndpoint is Twitter's OAuth 1 endpoint which uses the
|
|||
// oauth/authorize AuthorizeURL redirect. Note that this requires users who
|
|||
// have granted access previously, to re-grant access at AuthorizeURL.
|
|||
// Prefer AuthenticateEndpoint over AuthorizeEndpoint if you are unsure.
|
|||
var AuthorizeEndpoint = oauth1.Endpoint{ |
|||
RequestTokenURL: "https://api.twitter.com/oauth/request_token", |
|||
AuthorizeURL: "https://api.twitter.com/oauth/authorize", |
|||
AccessTokenURL: "https://api.twitter.com/oauth/access_token", |
|||
} |
@ -0,0 +1,52 @@ |
|||
# Sling Changelog |
|||
|
|||
Notable changes between releases. |
|||
|
|||
## latest |
|||
|
|||
* Added Sling `Body` setter to set an `io.Reader` on the Request |
|||
|
|||
## v1.0.0 (2015-05-23) |
|||
|
|||
* Added support for receiving and decoding error JSON structs |
|||
* Renamed Sling `JsonBody` setter to `BodyJSON` (breaking) |
|||
* Renamed Sling `BodyStruct` setter to `BodyForm` (breaking) |
|||
* Renamed Sling fields `httpClient`, `method`, `rawURL`, and `header` to be internal (breaking) |
|||
* Changed `Do` and `Receive` to skip response JSON decoding if "application/json" Content-Type is missing |
|||
* Changed `Sling.Receive(v interface{})` to `Sling.Receive(successV, failureV interface{})` (breaking) |
|||
* Previously `Receive` attempted to decode the response Body in all cases |
|||
* Updated `Receive` will decode the response Body into successV for 2XX responses or decode the Body into failureV for other status codes. Pass a nil `successV` or `failureV` to skip JSON decoding into that value. |
|||
* To upgrade, pass nil for the `failureV` argument or consider defining a JSON tagged struct appropriate for the API endpoint. (e.g. `s.Receive(&issue, nil)`, `s.Receive(&issue, &githubError)`) |
|||
* To retain the old behavior, duplicate the first argument (e.g. s.Receive(&tweet, &tweet)) |
|||
* Changed `Sling.Do(http.Request, v interface{})` to `Sling.Do(http.Request, successV, failureV interface{})` (breaking) |
|||
* See the changelog entry about `Receive`, the upgrade path is the same. |
|||
* Removed HEAD, GET, POST, PUT, PATCH, DELETE constants, no reason to export them (breaking) |
|||
|
|||
## v0.4.0 (2015-04-26) |
|||
|
|||
* Improved golint compliance |
|||
* Fixed typos and test printouts |
|||
|
|||
## v0.3.0 (2015-04-21) |
|||
|
|||
* Added BodyStruct method for setting a url encoded form body on the Request |
|||
* Added Add and Set methods for adding or setting Request Headers |
|||
* Added JsonBody method for setting JSON Request Body |
|||
* Improved examples and documentation |
|||
|
|||
## v0.2.0 (2015-04-05) |
|||
|
|||
* Added http.Client setter |
|||
* Added Sling.New() method to return a copy of a Sling |
|||
* Added Base setter and Path extension support |
|||
* Added method setters (Get, Post, Put, Patch, Delete, Head) |
|||
* Added support for encoding URL Query parameters |
|||
* Added example tiny Github API |
|||
* Changed v0.1.0 method signatures and names (breaking) |
|||
* Removed Go 1.0 support |
|||
|
|||
## v0.1.0 (2015-04-01) |
|||
|
|||
* Support decoding JSON responses. |
|||
|
|||
|
@ -0,0 +1,21 @@ |
|||
The MIT License (MIT) |
|||
|
|||
Copyright (c) 2015 Dalton Hubble |
|||
|
|||
Permission is hereby granted, free of charge, to any person obtaining a copy |
|||
of this software and associated documentation files (the "Software"), to deal |
|||
in the Software without restriction, including without limitation the rights |
|||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|||
copies of the Software, and to permit persons to whom the Software is |
|||
furnished to do so, subject to the following conditions: |
|||
|
|||
The above copyright notice and this permission notice shall be included in |
|||
all copies or substantial portions of the Software. |
|||
|
|||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
|||
THE SOFTWARE. |
@ -0,0 +1,273 @@ |
|||
|
|||
# Sling [![Build Status](https://travis-ci.org/dghubble/sling.png?branch=master)](https://travis-ci.org/dghubble/sling) [![GoDoc](https://godoc.org/github.com/dghubble/sling?status.png)](https://godoc.org/github.com/dghubble/sling) |
|||
<img align="right" src="https://s3.amazonaws.com/dghubble/small-gopher-with-sling.png"> |
|||
|
|||
Sling is a Go HTTP client library for creating and sending API requests. |
|||
|
|||
Slings store HTTP Request properties to simplify sending requests and decoding responses. Check [usage](#usage) or the [examples](examples) to learn how to compose a Sling into your API client. |
|||
|
|||
### Features |
|||
|
|||
* Method Setters: Get/Post/Put/Patch/Delete/Head |
|||
* Add or Set Request Headers |
|||
* Base/Path: Extend a Sling for different endpoints |
|||
* Encode structs into URL query parameters |
|||
* Encode a form or JSON into the Request Body |
|||
* Receive JSON success or failure responses |
|||
|
|||
## Install |
|||
|
|||
go get github.com/dghubble/sling |
|||
|
|||
## Documentation |
|||
|
|||
Read [GoDoc](https://godoc.org/github.com/dghubble/sling) |
|||
|
|||
## Usage |
|||
|
|||
Use a Sling to set path, method, header, query, or body properties and create an `http.Request`. |
|||
|
|||
```go |
|||
type Params struct { |
|||
Count int `url:"count,omitempty"` |
|||
} |
|||
params := &Params{Count: 5} |
|||
|
|||
req, err := sling.New().Get("https://example.com").QueryStruct(params).Request() |
|||
client.Do(req) |
|||
``` |
|||
|
|||
### Path |
|||
|
|||
Use `Path` to set or extend the URL for created Requests. Extension means the path will be resolved relative to the existing URL. |
|||
|
|||
```go |
|||
// creates a GET request to https://example.com/foo/bar |
|||
req, err := sling.New().Base("https://example.com/").Path("foo/").Path("bar").Request() |
|||
``` |
|||
|
|||
Use `Get`, `Post`, `Put`, `Patch`, `Delete`, or `Head` which are exactly the same as `Path` except they set the HTTP method too. |
|||
|
|||
```go |
|||
req, err := sling.New().Post("http://upload.com/gophers") |
|||
``` |
|||
|
|||
### Headers |
|||
|
|||
`Add` or `Set` headers for requests created by a Sling. |
|||
|
|||
```go |
|||
s := sling.New().Base(baseUrl).Set("User-Agent", "Gophergram API Client") |
|||
req, err := s.New().Get("gophergram/list").Request() |
|||
``` |
|||
|
|||
### Query |
|||
|
|||
#### QueryStruct |
|||
|
|||
Define [url tagged structs](https://godoc.org/github.com/google/go-querystring/query). Use `QueryStruct` to encode a struct as query parameters on requests. |
|||
|
|||
```go |
|||
// Github Issue Parameters |
|||
type IssueParams struct { |
|||
Filter string `url:"filter,omitempty"` |
|||
State string `url:"state,omitempty"` |
|||
Labels string `url:"labels,omitempty"` |
|||
Sort string `url:"sort,omitempty"` |
|||
Direction string `url:"direction,omitempty"` |
|||
Since string `url:"since,omitempty"` |
|||
} |
|||
``` |
|||
|
|||
```go |
|||
githubBase := sling.New().Base("https://api.github.com/").Client(httpClient) |
|||
|
|||
path := fmt.Sprintf("repos/%s/%s/issues", owner, repo) |
|||
params := &IssueParams{Sort: "updated", State: "open"} |
|||
req, err := githubBase.New().Get(path).QueryStruct(params).Request() |
|||
``` |
|||
|
|||
### Body |
|||
|
|||
#### JSON Body |
|||
|
|||
Define [JSON tagged structs](https://golang.org/pkg/encoding/json/). Use `BodyJSON` to JSON encode a struct as the Body on requests. |
|||
|
|||
```go |
|||
type IssueRequest struct { |
|||
Title string `json:"title,omitempty"` |
|||
Body string `json:"body,omitempty"` |
|||
Assignee string `json:"assignee,omitempty"` |
|||
Milestone int `json:"milestone,omitempty"` |
|||
Labels []string `json:"labels,omitempty"` |
|||
} |
|||
``` |
|||
|
|||
```go |
|||
githubBase := sling.New().Base("https://api.github.com/").Client(httpClient) |
|||
path := fmt.Sprintf("repos/%s/%s/issues", owner, repo) |
|||
|
|||
body := &IssueRequest{ |
|||
Title: "Test title", |
|||
Body: "Some issue", |
|||
} |
|||
req, err := githubBase.New().Post(path).BodyJSON(body).Request() |
|||
``` |
|||
|
|||
Requests will include an `application/json` Content-Type header. |
|||
|
|||
#### Form Body |
|||
|
|||
Define [url tagged structs](https://godoc.org/github.com/google/go-querystring/query). Use `BodyForm` to form url encode a struct as the Body on requests. |
|||
|
|||
```go |
|||
type StatusUpdateParams struct { |
|||
Status string `url:"status,omitempty"` |
|||
InReplyToStatusId int64 `url:"in_reply_to_status_id,omitempty"` |
|||
MediaIds []int64 `url:"media_ids,omitempty,comma"` |
|||
} |
|||
``` |
|||
|
|||
```go |
|||
tweetParams := &StatusUpdateParams{Status: "writing some Go"} |
|||
req, err := twitterBase.New().Post(path).BodyForm(tweetParams).Request() |
|||
``` |
|||
|
|||
Requests will include an `application/x-www-form-urlencoded` Content-Type header. |
|||
|
|||
#### Plain Body |
|||
|
|||
Use `Body` to set a plain `io.Reader` on requests created by a Sling. |
|||
|
|||
```go |
|||
body := strings.NewReader("raw body") |
|||
req, err := sling.New().Base("https://example.com").Body(body).Request() |
|||
``` |
|||
|
|||
Set a content type header, if desired (e.g. `Set("Content-Type", "text/plain")`). |
|||
|
|||
### Extend a Sling |
|||
|
|||
Each Sling creates a standard `http.Request` (e.g. with some path and query |
|||
params) each time `Request()` is called. You may wish to extend an existing Sling to minimize duplication (e.g. a common client or base url). |
|||
|
|||
Each Sling instance provides a `New()` method which creates an independent copy, so setting properties on the child won't mutate the parent Sling. |
|||
|
|||
```go |
|||
const twitterApi = "https://api.twitter.com/1.1/" |
|||
base := sling.New().Base(twitterApi).Client(authClient) |
|||
|
|||
// statuses/show.json Sling |
|||
tweetShowSling := base.New().Get("statuses/show.json").QueryStruct(params) |
|||
req, err := tweetShowSling.Request() |
|||
|
|||
// statuses/update.json Sling |
|||
tweetPostSling := base.New().Post("statuses/update.json").BodyForm(params) |
|||
req, err := tweetPostSling.Request() |
|||
``` |
|||
|
|||
Without the calls to `base.New()`, `tweetShowSling` and `tweetPostSling` would reference the base Sling and POST to |
|||
"https://api.twitter.com/1.1/statuses/show.json/statuses/update.json", which |
|||
is undesired. |
|||
|
|||
Recap: If you wish to *extend* a Sling, create a new child copy with `New()`. |
|||
|
|||
### Sending |
|||
|
|||
#### Receive |
|||
|
|||
Define a JSON struct to decode a type from 2XX success responses. Use `ReceiveSuccess(successV interface{})` to send a new Request and decode the response body into `successV` if it succeeds. |
|||
|
|||
```go |
|||
// Github Issue (abbreviated) |
|||
type Issue struct { |
|||
Title string `json:"title"` |
|||
Body string `json:"body"` |
|||
} |
|||
``` |
|||
|
|||
```go |
|||
issues := new([]Issue) |
|||
resp, err := githubBase.New().Get(path).QueryStruct(params).ReceiveSuccess(issues) |
|||
fmt.Println(issues, resp, err) |
|||
``` |
|||
|
|||
Most APIs return failure responses with JSON error details. To decode these, define success and failure JSON structs. Use `Receive(successV, failureV interface{})` to send a new Request that will automatically decode the response into the `successV` for 2XX responses or into `failureV` for non-2XX responses. |
|||
|
|||
```go |
|||
type GithubError struct { |
|||
Message string `json:"message"` |
|||
Errors []struct { |
|||
Resource string `json:"resource"` |
|||
Field string `json:"field"` |
|||
Code string `json:"code"` |
|||
} `json:"errors"` |
|||
DocumentationURL string `json:"documentation_url"` |
|||
} |
|||
``` |
|||
|
|||
```go |
|||
issues := new([]Issue) |
|||
githubError := new(GithubError) |
|||
resp, err := githubBase.New().Get(path).QueryStruct(params).Receive(issues, githubError) |
|||
fmt.Println(issues, githubError, resp, err) |
|||
``` |
|||
|
|||
Pass a nil `successV` or `failureV` argument to skip JSON decoding into that value. |
|||
|
|||
### Build an API |
|||
|
|||
APIs typically define an endpoint (also called a service) for each type of resource. For example, here is a tiny Github IssueService which [lists](https://developer.github.com/v3/issues/#list-issues-for-a-repository) repository issues. |
|||
|
|||
```go |
|||
const baseURL = "https://api.github.com/" |
|||
|
|||
type IssueService struct { |
|||
sling *sling.Sling |
|||
} |
|||
|
|||
func NewIssueService(httpClient *http.Client) *IssueService { |
|||
return &IssueService{ |
|||
sling: sling.New().Client(httpClient).Base(baseURL), |
|||
} |
|||
} |
|||
|
|||
func (s *IssueService) ListByRepo(owner, repo string, params *IssueListParams) ([]Issue, *http.Response, error) { |
|||
issues := new([]Issue) |
|||
githubError := new(GithubError) |
|||
path := fmt.Sprintf("repos/%s/%s/issues", owner, repo) |
|||
resp, err := s.sling.New().Get(path).QueryStruct(params).Receive(issues, githubError) |
|||
if err == nil { |
|||
err = githubError |
|||
} |
|||
return *issues, resp, err |
|||
} |
|||
``` |
|||
|
|||
## Example APIs using Sling |
|||
|
|||
* Digits [dghubble/go-digits](https://github.com/dghubble/go-digits) |
|||
* GoSquared [drinkin/go-gosquared](https://github.com/drinkin/go-gosquared) |
|||
* Kala [ajvb/kala](https://github.com/ajvb/kala) |
|||
* Parse [fergstar/go-parse](https://github.com/fergstar/go-parse) |
|||
* Rdio [apriendeau/shares](https://github.com/apriendeau/shares) |
|||
* Swagger Generator [swagger-api/swagger-codegen](https://github.com/swagger-api/swagger-codegen) |
|||
* Twitter [dghubble/go-twitter](https://github.com/dghubble/go-twitter) |
|||
* Hacker News [mirceamironenco/go-hackernews](https://github.com/mirceamironenco/go-hackernews) |
|||
* Stacksmith [jesustinoco/go-smith](https://github.com/jesustinoco/go-smith) |
|||
|
|||
Create a Pull Request to add a link to your own API. |
|||
|
|||
## Motivation |
|||
|
|||
Many client libraries follow the lead of [google/go-github](https://github.com/google/go-github) (our inspiration!), but do so by reimplementing logic common to all clients. |
|||
|
|||
This project borrows and abstracts those ideas into a Sling, an agnostic component any API client can use for creating and sending requests. |
|||
|
|||
## Contributing |
|||
|
|||
See the [Contributing Guide](https://gist.github.com/dghubble/be682c123727f70bcfe7). |
|||
|
|||
## License |
|||
|
|||
[MIT License](LICENSE) |
@ -0,0 +1,179 @@ |
|||
/* |
|||
Package sling is a Go HTTP client library for creating and sending API requests. |
|||
|
|||
Slings store HTTP Request properties to simplify sending requests and decoding |
|||
responses. Check the examples to learn how to compose a Sling into your API |
|||
client. |
|||
|
|||
Usage |
|||
|
|||
Use a Sling to set path, method, header, query, or body properties and create an |
|||
http.Request. |
|||
|
|||
type Params struct { |
|||
Count int `url:"count,omitempty"` |
|||
} |
|||
params := &Params{Count: 5} |
|||
|
|||
req, err := sling.New().Get("https://example.com").QueryStruct(params).Request() |
|||
client.Do(req) |
|||
|
|||
Path |
|||
|
|||
Use Path to set or extend the URL for created Requests. Extension means the |
|||
path will be resolved relative to the existing URL. |
|||
|
|||
// creates a GET request to https://example.com/foo/bar
|
|||
req, err := sling.New().Base("https://example.com/").Path("foo/").Path("bar").Request() |
|||
|
|||
Use Get, Post, Put, Patch, Delete, or Head which are exactly the same as Path |
|||
except they set the HTTP method too. |
|||
|
|||
req, err := sling.New().Post("http://upload.com/gophers") |
|||
|
|||
Headers |
|||
|
|||
Add or Set headers for requests created by a Sling. |
|||
|
|||
s := sling.New().Base(baseUrl).Set("User-Agent", "Gophergram API Client") |
|||
req, err := s.New().Get("gophergram/list").Request() |
|||
|
|||
QueryStruct |
|||
|
|||
Define url parameter structs (https://godoc.org/github.com/google/go-querystring/query).
|
|||
Use QueryStruct to encode a struct as query parameters on requests. |
|||
|
|||
// Github Issue Parameters
|
|||
type IssueParams struct { |
|||
Filter string `url:"filter,omitempty"` |
|||
State string `url:"state,omitempty"` |
|||
Labels string `url:"labels,omitempty"` |
|||
Sort string `url:"sort,omitempty"` |
|||
Direction string `url:"direction,omitempty"` |
|||
Since string `url:"since,omitempty"` |
|||
} |
|||
|
|||
githubBase := sling.New().Base("https://api.github.com/").Client(httpClient) |
|||
|
|||
path := fmt.Sprintf("repos/%s/%s/issues", owner, repo) |
|||
params := &IssueParams{Sort: "updated", State: "open"} |
|||
req, err := githubBase.New().Get(path).QueryStruct(params).Request() |
|||
|
|||
Json Body |
|||
|
|||
Define JSON tagged structs (https://golang.org/pkg/encoding/json/).
|
|||
Use BodyJSON to JSON encode a struct as the Body on requests. |
|||
|
|||
type IssueRequest struct { |
|||
Title string `json:"title,omitempty"` |
|||
Body string `json:"body,omitempty"` |
|||
Assignee string `json:"assignee,omitempty"` |
|||
Milestone int `json:"milestone,omitempty"` |
|||
Labels []string `json:"labels,omitempty"` |
|||
} |
|||
|
|||
githubBase := sling.New().Base("https://api.github.com/").Client(httpClient) |
|||
path := fmt.Sprintf("repos/%s/%s/issues", owner, repo) |
|||
|
|||
body := &IssueRequest{ |
|||
Title: "Test title", |
|||
Body: "Some issue", |
|||
} |
|||
req, err := githubBase.New().Post(path).BodyJSON(body).Request() |
|||
|
|||
Requests will include an "application/json" Content-Type header. |
|||
|
|||
Form Body |
|||
|
|||
Define url tagged structs (https://godoc.org/github.com/google/go-querystring/query).
|
|||
Use BodyForm to form url encode a struct as the Body on requests. |
|||
|
|||
type StatusUpdateParams struct { |
|||
Status string `url:"status,omitempty"` |
|||
InReplyToStatusId int64 `url:"in_reply_to_status_id,omitempty"` |
|||
MediaIds []int64 `url:"media_ids,omitempty,comma"` |
|||
} |
|||
|
|||
tweetParams := &StatusUpdateParams{Status: "writing some Go"} |
|||
req, err := twitterBase.New().Post(path).BodyForm(tweetParams).Request() |
|||
|
|||
Requests will include an "application/x-www-form-urlencoded" Content-Type |
|||
header. |
|||
|
|||
Plain Body |
|||
|
|||
Use Body to set a plain io.Reader on requests created by a Sling. |
|||
|
|||
body := strings.NewReader("raw body") |
|||
req, err := sling.New().Base("https://example.com").Body(body).Request() |
|||
|
|||
Set a content type header, if desired (e.g. Set("Content-Type", "text/plain")). |
|||
|
|||
Extend a Sling |
|||
|
|||
Each Sling generates an http.Request (say with some path and query params) |
|||
each time Request() is called, based on its state. When creating |
|||
different slings, you may wish to extend an existing Sling to minimize |
|||
duplication (e.g. a common client). |
|||
|
|||
Each Sling instance provides a New() method which creates an independent copy, |
|||
so setting properties on the child won't mutate the parent Sling. |
|||
|
|||
const twitterApi = "https://api.twitter.com/1.1/" |
|||
base := sling.New().Base(twitterApi).Client(authClient) |
|||
|
|||
// statuses/show.json Sling
|
|||
tweetShowSling := base.New().Get("statuses/show.json").QueryStruct(params) |
|||
req, err := tweetShowSling.Request() |
|||
|
|||
// statuses/update.json Sling
|
|||
tweetPostSling := base.New().Post("statuses/update.json").BodyForm(params) |
|||
req, err := tweetPostSling.Request() |
|||
|
|||
Without the calls to base.New(), tweetShowSling and tweetPostSling would |
|||
reference the base Sling and POST to |
|||
"https://api.twitter.com/1.1/statuses/show.json/statuses/update.json", which |
|||
is undesired. |
|||
|
|||
Recap: If you wish to extend a Sling, create a new child copy with New(). |
|||
|
|||
Receive |
|||
|
|||
Define a JSON struct to decode a type from 2XX success responses. Use |
|||
ReceiveSuccess(successV interface{}) to send a new Request and decode the |
|||
response body into successV if it succeeds. |
|||
|
|||
// Github Issue (abbreviated)
|
|||
type Issue struct { |
|||
Title string `json:"title"` |
|||
Body string `json:"body"` |
|||
} |
|||
|
|||
issues := new([]Issue) |
|||
resp, err := githubBase.New().Get(path).QueryStruct(params).ReceiveSuccess(issues) |
|||
fmt.Println(issues, resp, err) |
|||
|
|||
Most APIs return failure responses with JSON error details. To decode these, |
|||
define success and failure JSON structs. Use |
|||
Receive(successV, failureV interface{}) to send a new Request that will |
|||
automatically decode the response into the successV for 2XX responses or into |
|||
failureV for non-2XX responses. |
|||
|
|||
type GithubError struct { |
|||
Message string `json:"message"` |
|||
Errors []struct { |
|||
Resource string `json:"resource"` |
|||
Field string `json:"field"` |
|||
Code string `json:"code"` |
|||
} `json:"errors"` |
|||
DocumentationURL string `json:"documentation_url"` |
|||
} |
|||
|
|||
issues := new([]Issue) |
|||
githubError := new(GithubError) |
|||
resp, err := githubBase.New().Get(path).QueryStruct(params).Receive(issues, githubError) |
|||
fmt.Println(issues, githubError, resp, err) |
|||
|
|||
Pass a nil successV or failureV argument to skip JSON decoding into that value. |
|||
*/ |
|||
package sling |
@ -0,0 +1,19 @@ |
|||
|
|||
## Example API Client with Sling |
|||
|
|||
Try the example Github API Client. |
|||
|
|||
cd examples |
|||
go get . |
|||
|
|||
List the public issues on the [github.com/golang/go](https://github.com/golang/go) repository. |
|||
|
|||
go run github.go |
|||
|
|||
To list your public and private Github issues, pass your [Github Access Token](https://github.com/settings/tokens) |
|||
|
|||
go run github.go -access-token=xxx |
|||
|
|||
or set the `GITHUB_ACCESS_TOKEN` environment variable. |
|||
|
|||
For a complete Github API, see the excellent [google/go-github](https://github.com/google/go-github) package. |
@ -0,0 +1,161 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"flag" |
|||
"fmt" |
|||
"log" |
|||
"net/http" |
|||
"os" |
|||
|
|||
"github.com/coreos/pkg/flagutil" |
|||
"github.com/dghubble/sling" |
|||
"golang.org/x/oauth2" |
|||
) |
|||
|
|||
const baseURL = "https://api.github.com/" |
|||
|
|||
// Issue is a simplified Github issue
|
|||
// https://developer.github.com/v3/issues/#response
|
|||
type Issue struct { |
|||
ID int `json:"id"` |
|||
URL string `json:"url"` |
|||
Number int `json:"number"` |
|||
State string `json:"state"` |
|||
Title string `json:"title"` |
|||
Body string `json:"body"` |
|||
} |
|||
|
|||
// GithubError represents a Github API error response
|
|||
// https://developer.github.com/v3/#client-errors
|
|||
type GithubError struct { |
|||
Message string `json:"message"` |
|||
Errors []struct { |
|||
Resource string `json:"resource"` |
|||
Field string `json:"field"` |
|||
Code string `json:"code"` |
|||
} `json:"errors"` |
|||
DocumentationURL string `json:"documentation_url"` |
|||
} |
|||
|
|||
func (e GithubError) Error() string { |
|||
return fmt.Sprintf("github: %v %+v %v", e.Message, e.Errors, e.DocumentationURL) |
|||
} |
|||
|
|||
// IssueRequest is a simplified issue request
|
|||
// https://developer.github.com/v3/issues/#create-an-issue
|
|||
type IssueRequest struct { |
|||
Title string `json:"title,omitempty"` |
|||
Body string `json:"body,omitempty"` |
|||
Assignee string `json:"assignee,omitempty"` |
|||
Milestone int `json:"milestone,omitempty"` |
|||
Labels []string `json:"labels,omitempty"` |
|||
} |
|||
|
|||
// IssueListParams are the params for IssueService.List
|
|||
// https://developer.github.com/v3/issues/#parameters
|
|||
type IssueListParams struct { |
|||
Filter string `url:"filter,omitempty"` |
|||
State string `url:"state,omitempty"` |
|||
Labels string `url:"labels,omitempty"` |
|||
Sort string `url:"sort,omitempty"` |
|||
Direction string `url:"direction,omitempty"` |
|||
Since string `url:"since,omitempty"` |
|||
} |
|||
|
|||
// Services
|
|||
|
|||
// IssueService provides methods for creating and reading issues.
|
|||
type IssueService struct { |
|||
sling *sling.Sling |
|||
} |
|||
|
|||
// NewIssueService returns a new IssueService.
|
|||
func NewIssueService(httpClient *http.Client) *IssueService { |
|||
return &IssueService{ |
|||
sling: sling.New().Client(httpClient).Base(baseURL), |
|||
} |
|||
} |
|||
|
|||
// List returns the authenticated user's issues across repos and orgs.
|
|||
func (s *IssueService) List(params *IssueListParams) ([]Issue, *http.Response, error) { |
|||
issues := new([]Issue) |
|||
githubError := new(GithubError) |
|||
resp, err := s.sling.New().Path("issues").QueryStruct(params).Receive(issues, githubError) |
|||
if err == nil { |
|||
err = githubError |
|||
} |
|||
return *issues, resp, err |
|||
} |
|||
|
|||
// ListByRepo returns a repository's issues.
|
|||
func (s *IssueService) ListByRepo(owner, repo string, params *IssueListParams) ([]Issue, *http.Response, error) { |
|||
issues := new([]Issue) |
|||
githubError := new(GithubError) |
|||
path := fmt.Sprintf("repos/%s/%s/issues", owner, repo) |
|||
resp, err := s.sling.New().Get(path).QueryStruct(params).Receive(issues, githubError) |
|||
if err == nil { |
|||
err = githubError |
|||
} |
|||
return *issues, resp, err |
|||
} |
|||
|
|||
// Create creates a new issue on the specified repository.
|
|||
func (s *IssueService) Create(owner, repo string, issueBody *IssueRequest) (*Issue, *http.Response, error) { |
|||
issue := new(Issue) |
|||
githubError := new(GithubError) |
|||
path := fmt.Sprintf("repos/%s/%s/issues", owner, repo) |
|||
resp, err := s.sling.New().Post(path).BodyJSON(issueBody).Receive(issue, githubError) |
|||
if err == nil { |
|||
err = githubError |
|||
} |
|||
return issue, resp, err |
|||
} |
|||
|
|||
// Client to wrap services
|
|||
|
|||
// Client is a tiny Github client
|
|||
type Client struct { |
|||
IssueService *IssueService |
|||
// other service endpoints...
|
|||
} |
|||
|
|||
// NewClient returns a new Client
|
|||
func NewClient(httpClient *http.Client) *Client { |
|||
return &Client{ |
|||
IssueService: NewIssueService(httpClient), |
|||
} |
|||
} |
|||
|
|||
func main() { |
|||
// Github Unauthenticated API
|
|||
client := NewClient(nil) |
|||
params := &IssueListParams{Sort: "updated"} |
|||
issues, _, _ := client.IssueService.ListByRepo("golang", "go", params) |
|||
fmt.Printf("Public golang/go Issues:\n%v\n", issues) |
|||
|
|||
// Github OAuth2 API
|
|||
flags := flag.NewFlagSet("github-example", flag.ExitOnError) |
|||
// -access-token=xxx or GITHUB_ACCESS_TOKEN env var
|
|||
accessToken := flags.String("access-token", "", "Github Access Token") |
|||
flags.Parse(os.Args[1:]) |
|||
flagutil.SetFlagsFromEnv(flags, "GITHUB") |
|||
|
|||
if *accessToken == "" { |
|||
log.Fatal("Github Access Token required to list private issues") |
|||
} |
|||
|
|||
config := &oauth2.Config{} |
|||
token := &oauth2.Token{AccessToken: *accessToken} |
|||
httpClient := config.Client(oauth2.NoContext, token) |
|||
|
|||
client = NewClient(httpClient) |
|||
issues, _, _ = client.IssueService.List(params) |
|||
fmt.Printf("Your Github Issues:\n%v\n", issues) |
|||
|
|||
// body := &IssueRequest{
|
|||
// Title: "Test title",
|
|||
// Body: "Some test issue",
|
|||
// }
|
|||
// issue, _, _ := client.IssueService.Create("dghubble", "temp", body)
|
|||
// fmt.Println(issue)
|
|||
} |
@ -0,0 +1,421 @@ |
|||
package sling |
|||
|
|||
import ( |
|||
"bytes" |
|||
"encoding/base64" |
|||
"encoding/json" |
|||
"io" |
|||
"io/ioutil" |
|||
"net/http" |
|||
"net/url" |
|||
"strings" |
|||
|
|||
goquery "github.com/google/go-querystring/query" |
|||
) |
|||
|
|||
const ( |
|||
contentType = "Content-Type" |
|||
jsonContentType = "application/json" |
|||
formContentType = "application/x-www-form-urlencoded" |
|||
) |
|||
|
|||
// Doer executes http requests. It is implemented by *http.Client. You can
|
|||
// wrap *http.Client with layers of Doers to form a stack of client-side
|
|||
// middleware.
|
|||
type Doer interface { |
|||
Do(req *http.Request) (*http.Response, error) |
|||
} |
|||
|
|||
// Sling is an HTTP Request builder and sender.
|
|||
type Sling struct { |
|||
// http Client for doing requests
|
|||
httpClient Doer |
|||
// HTTP method (GET, POST, etc.)
|
|||
method string |
|||
// raw url string for requests
|
|||
rawURL string |
|||
// stores key-values pairs to add to request's Headers
|
|||
header http.Header |
|||
// url tagged query structs
|
|||
queryStructs []interface{} |
|||
// json tagged body struct
|
|||
bodyJSON interface{} |
|||
// url tagged body struct (form)
|
|||
bodyForm interface{} |
|||
// simply assigned body
|
|||
body io.ReadCloser |
|||
} |
|||
|
|||
// New returns a new Sling with an http DefaultClient.
|
|||
func New() *Sling { |
|||
return &Sling{ |
|||
httpClient: http.DefaultClient, |
|||
method: "GET", |
|||
header: make(http.Header), |
|||
queryStructs: make([]interface{}, 0), |
|||
} |
|||
} |
|||
|
|||
// New returns a copy of a Sling for creating a new Sling with properties
|
|||
// from a parent Sling. For example,
|
|||
//
|
|||
// parentSling := sling.New().Client(client).Base("https://api.io/")
|
|||
// fooSling := parentSling.New().Get("foo/")
|
|||
// barSling := parentSling.New().Get("bar/")
|
|||
//
|
|||
// fooSling and barSling will both use the same client, but send requests to
|
|||
// https://api.io/foo/ and https://api.io/bar/ respectively.
|
|||
//
|
|||
// Note that query and body values are copied so if pointer values are used,
|
|||
// mutating the original value will mutate the value within the child Sling.
|
|||
func (s *Sling) New() *Sling { |
|||
// copy Headers pairs into new Header map
|
|||
headerCopy := make(http.Header) |
|||
for k, v := range s.header { |
|||
headerCopy[k] = v |
|||
} |
|||
return &Sling{ |
|||
httpClient: s.httpClient, |
|||
method: s.method, |
|||
rawURL: s.rawURL, |
|||
header: headerCopy, |
|||
queryStructs: append([]interface{}{}, s.queryStructs...), |
|||
bodyJSON: s.bodyJSON, |
|||
bodyForm: s.bodyForm, |
|||
body: s.body, |
|||
} |
|||
} |
|||
|
|||
// Http Client
|
|||
|
|||
// Client sets the http Client used to do requests. If a nil client is given,
|
|||
// the http.DefaultClient will be used.
|
|||
func (s *Sling) Client(httpClient *http.Client) *Sling { |
|||
if httpClient == nil { |
|||
return s.Doer(http.DefaultClient) |
|||
} |
|||
return s.Doer(httpClient) |
|||
} |
|||
|
|||
// Doer sets the custom Doer implementation used to do requests.
|
|||
// If a nil client is given, the http.DefaultClient will be used.
|
|||
func (s *Sling) Doer(doer Doer) *Sling { |
|||
if doer == nil { |
|||
s.httpClient = http.DefaultClient |
|||
} else { |
|||
s.httpClient = doer |
|||
} |
|||
return s |
|||
} |
|||
|
|||
// Method
|
|||
|
|||
// Head sets the Sling method to HEAD and sets the given pathURL.
|
|||
func (s *Sling) Head(pathURL string) *Sling { |
|||
s.method = "HEAD" |
|||
return s.Path(pathURL) |
|||
} |
|||
|
|||
// Get sets the Sling method to GET and sets the given pathURL.
|
|||
func (s *Sling) Get(pathURL string) *Sling { |
|||
s.method = "GET" |
|||
return s.Path(pathURL) |
|||
} |
|||
|
|||
// Post sets the Sling method to POST and sets the given pathURL.
|
|||
func (s *Sling) Post(pathURL string) *Sling { |
|||
s.method = "POST" |
|||
return s.Path(pathURL) |
|||
} |
|||
|
|||
// Put sets the Sling method to PUT and sets the given pathURL.
|
|||
func (s *Sling) Put(pathURL string) *Sling { |
|||
s.method = "PUT" |
|||
return s.Path(pathURL) |
|||
} |
|||
|
|||
// Patch sets the Sling method to PATCH and sets the given pathURL.
|
|||
func (s *Sling) Patch(pathURL string) *Sling { |
|||
s.method = "PATCH" |
|||
return s.Path(pathURL) |
|||
} |
|||
|
|||
// Delete sets the Sling method to DELETE and sets the given pathURL.
|
|||
func (s *Sling) Delete(pathURL string) *Sling { |
|||
s.method = "DELETE" |
|||
return s.Path(pathURL) |
|||
} |
|||
|
|||
// Header
|
|||
|
|||
// Add adds the key, value pair in Headers, appending values for existing keys
|
|||
// to the key's values. Header keys are canonicalized.
|
|||
func (s *Sling) Add(key, value string) *Sling { |
|||
s.header.Add(key, value) |
|||
return s |
|||
} |
|||
|
|||
// Set sets the key, value pair in Headers, replacing existing values
|
|||
// associated with key. Header keys are canonicalized.
|
|||
func (s *Sling) Set(key, value string) *Sling { |
|||
s.header.Set(key, value) |
|||
return s |
|||
} |
|||
|
|||
// SetBasicAuth sets the Authorization header to use HTTP Basic Authentication
|
|||
// with the provided username and password. With HTTP Basic Authentication
|
|||
// the provided username and password are not encrypted.
|
|||
func (s *Sling) SetBasicAuth(username, password string) *Sling { |
|||
return s.Set("Authorization", "Basic "+basicAuth(username, password)) |
|||
} |
|||
|
|||
// basicAuth returns the base64 encoded username:password for basic auth copied
|
|||
// from net/http.
|
|||
func basicAuth(username, password string) string { |
|||
auth := username + ":" + password |
|||
return base64.StdEncoding.EncodeToString([]byte(auth)) |
|||
} |
|||
|
|||
// Url
|
|||
|
|||
// Base sets the rawURL. If you intend to extend the url with Path,
|
|||
// baseUrl should be specified with a trailing slash.
|
|||
func (s *Sling) Base(rawURL string) *Sling { |
|||
s.rawURL = rawURL |
|||
return s |
|||
} |
|||
|
|||
// Path extends the rawURL with the given path by resolving the reference to
|
|||
// an absolute URL. If parsing errors occur, the rawURL is left unmodified.
|
|||
func (s *Sling) Path(path string) *Sling { |
|||
baseURL, baseErr := url.Parse(s.rawURL) |
|||
pathURL, pathErr := url.Parse(path) |
|||
if baseErr == nil && pathErr == nil { |
|||
s.rawURL = baseURL.ResolveReference(pathURL).String() |
|||
return s |
|||
} |
|||
return s |
|||
} |
|||
|
|||
// QueryStruct appends the queryStruct to the Sling's queryStructs. The value
|
|||
// pointed to by each queryStruct will be encoded as url query parameters on
|
|||
// new requests (see Request()).
|
|||
// The queryStruct argument should be a pointer to a url tagged struct. See
|
|||
// https://godoc.org/github.com/google/go-querystring/query for details.
|
|||
func (s *Sling) QueryStruct(queryStruct interface{}) *Sling { |
|||
if queryStruct != nil { |
|||
s.queryStructs = append(s.queryStructs, queryStruct) |
|||
} |
|||
return s |
|||
} |
|||
|
|||
// Body
|
|||
|
|||
// BodyJSON sets the Sling's bodyJSON. The value pointed to by the bodyJSON
|
|||
// will be JSON encoded as the Body on new requests (see Request()).
|
|||
// The bodyJSON argument should be a pointer to a JSON tagged struct. See
|
|||
// https://golang.org/pkg/encoding/json/#MarshalIndent for details.
|
|||
func (s *Sling) BodyJSON(bodyJSON interface{}) *Sling { |
|||
if bodyJSON != nil { |
|||
s.bodyJSON = bodyJSON |
|||
s.Set(contentType, jsonContentType) |
|||
} |
|||
return s |
|||
} |
|||
|
|||
// BodyForm sets the Sling's bodyForm. The value pointed to by the bodyForm
|
|||
// will be url encoded as the Body on new requests (see Request()).
|
|||
// The bodyStruct argument should be a pointer to a url tagged struct. See
|
|||
// https://godoc.org/github.com/google/go-querystring/query for details.
|
|||
func (s *Sling) BodyForm(bodyForm interface{}) *Sling { |
|||
if bodyForm != nil { |
|||
s.bodyForm = bodyForm |
|||
s.Set(contentType, formContentType) |
|||
} |
|||
return s |
|||
} |
|||
|
|||
// Body sets the Sling's body. The body value will be set as the Body on new
|
|||
// requests (see Request()).
|
|||
// If the provided body is also an io.Closer, the request Body will be closed
|
|||
// by http.Client methods.
|
|||
func (s *Sling) Body(body io.Reader) *Sling { |
|||
rc, ok := body.(io.ReadCloser) |
|||
if !ok && body != nil { |
|||
rc = ioutil.NopCloser(body) |
|||
} |
|||
if rc != nil { |
|||
s.body = rc |
|||
} |
|||
return s |
|||
} |
|||
|
|||
// Requests
|
|||
|
|||
// Request returns a new http.Request created with the Sling properties.
|
|||
// Returns any errors parsing the rawURL, encoding query structs, encoding
|
|||
// the body, or creating the http.Request.
|
|||
func (s *Sling) Request() (*http.Request, error) { |
|||
reqURL, err := url.Parse(s.rawURL) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
err = addQueryStructs(reqURL, s.queryStructs) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
body, err := s.getRequestBody() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
req, err := http.NewRequest(s.method, reqURL.String(), body) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
addHeaders(req, s.header) |
|||
return req, err |
|||
} |
|||
|
|||
// addQueryStructs parses url tagged query structs using go-querystring to
|
|||
// encode them to url.Values and format them onto the url.RawQuery. Any
|
|||
// query parsing or encoding errors are returned.
|
|||
func addQueryStructs(reqURL *url.URL, queryStructs []interface{}) error { |
|||
urlValues, err := url.ParseQuery(reqURL.RawQuery) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
// encodes query structs into a url.Values map and merges maps
|
|||
for _, queryStruct := range queryStructs { |
|||
queryValues, err := goquery.Values(queryStruct) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
for key, values := range queryValues { |
|||
for _, value := range values { |
|||
urlValues.Add(key, value) |
|||
} |
|||
} |
|||
} |
|||
// url.Values format to a sorted "url encoded" string, e.g. "key=val&foo=bar"
|
|||
reqURL.RawQuery = urlValues.Encode() |
|||
return nil |
|||
} |
|||
|
|||
// getRequestBody returns the io.Reader which should be used as the body
|
|||
// of new Requests.
|
|||
func (s *Sling) getRequestBody() (body io.Reader, err error) { |
|||
if s.bodyJSON != nil && s.header.Get(contentType) == jsonContentType { |
|||
body, err = encodeBodyJSON(s.bodyJSON) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
} else if s.bodyForm != nil && s.header.Get(contentType) == formContentType { |
|||
body, err = encodeBodyForm(s.bodyForm) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
} else if s.body != nil { |
|||
body = s.body |
|||
} |
|||
return body, nil |
|||
} |
|||
|
|||
// encodeBodyJSON JSON encodes the value pointed to by bodyJSON into an
|
|||
// io.Reader, typically for use as a Request Body.
|
|||
func encodeBodyJSON(bodyJSON interface{}) (io.Reader, error) { |
|||
var buf = new(bytes.Buffer) |
|||
if bodyJSON != nil { |
|||
buf = &bytes.Buffer{} |
|||
err := json.NewEncoder(buf).Encode(bodyJSON) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
} |
|||
return buf, nil |
|||
} |
|||
|
|||
// encodeBodyForm url encodes the value pointed to by bodyForm into an
|
|||
// io.Reader, typically for use as a Request Body.
|
|||
func encodeBodyForm(bodyForm interface{}) (io.Reader, error) { |
|||
values, err := goquery.Values(bodyForm) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return strings.NewReader(values.Encode()), nil |
|||
} |
|||
|
|||
// addHeaders adds the key, value pairs from the given http.Header to the
|
|||
// request. Values for existing keys are appended to the keys values.
|
|||
func addHeaders(req *http.Request, header http.Header) { |
|||
for key, values := range header { |
|||
for _, value := range values { |
|||
req.Header.Add(key, value) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// Sending
|
|||
|
|||
// ReceiveSuccess creates a new HTTP request and returns the response. Success
|
|||
// responses (2XX) are JSON decoded into the value pointed to by successV.
|
|||
// Any error creating the request, sending it, or decoding a 2XX response
|
|||
// is returned.
|
|||
func (s *Sling) ReceiveSuccess(successV interface{}) (*http.Response, error) { |
|||
return s.Receive(successV, nil) |
|||
} |
|||
|
|||
// Receive creates a new HTTP request and returns the response. Success
|
|||
// responses (2XX) are JSON decoded into the value pointed to by successV and
|
|||
// other responses are JSON decoded into the value pointed to by failureV.
|
|||
// Any error creating the request, sending it, or decoding the response is
|
|||
// returned.
|
|||
// Receive is shorthand for calling Request and Do.
|
|||
func (s *Sling) Receive(successV, failureV interface{}) (*http.Response, error) { |
|||
req, err := s.Request() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return s.Do(req, successV, failureV) |
|||
} |
|||
|
|||
// Do sends an HTTP request and returns the response. Success responses (2XX)
|
|||
// are JSON decoded into the value pointed to by successV and other responses
|
|||
// are JSON decoded into the value pointed to by failureV.
|
|||
// Any error sending the request or decoding the response is returned.
|
|||
func (s *Sling) Do(req *http.Request, successV, failureV interface{}) (*http.Response, error) { |
|||
resp, err := s.httpClient.Do(req) |
|||
if err != nil { |
|||
return resp, err |
|||
} |
|||
// when err is nil, resp contains a non-nil resp.Body which must be closed
|
|||
defer resp.Body.Close() |
|||
if strings.Contains(resp.Header.Get(contentType), jsonContentType) { |
|||
err = decodeResponseJSON(resp, successV, failureV) |
|||
} |
|||
return resp, err |
|||
} |
|||
|
|||
// decodeResponse decodes response Body into the value pointed to by successV
|
|||
// if the response is a success (2XX) or into the value pointed to by failureV
|
|||
// otherwise. If the successV or failureV argument to decode into is nil,
|
|||
// decoding is skipped.
|
|||
// Caller is responsible for closing the resp.Body.
|
|||
func decodeResponseJSON(resp *http.Response, successV, failureV interface{}) error { |
|||
if code := resp.StatusCode; 200 <= code && code <= 299 { |
|||
if successV != nil { |
|||
return decodeResponseBodyJSON(resp, successV) |
|||
} |
|||
} else { |
|||
if failureV != nil { |
|||
return decodeResponseBodyJSON(resp, failureV) |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// decodeResponseBodyJSON JSON decodes a Response Body into the value pointed
|
|||
// to by v.
|
|||
// Caller must provide a non-nil v and close the resp.Body.
|
|||
func decodeResponseBodyJSON(resp *http.Response, v interface{}) error { |
|||
return json.NewDecoder(resp.Body).Decode(v) |
|||
} |
@ -0,0 +1,863 @@ |
|||
package sling |
|||
|
|||
import ( |
|||
"bytes" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"io/ioutil" |
|||
"math" |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"net/url" |
|||
"reflect" |
|||
"strings" |
|||
"testing" |
|||
) |
|||
|
|||
type FakeParams struct { |
|||
KindName string `url:"kind_name"` |
|||
Count int `url:"count"` |
|||
} |
|||
|
|||
// Url-tagged query struct
|
|||
var paramsA = struct { |
|||
Limit int `url:"limit"` |
|||
}{ |
|||
30, |
|||
} |
|||
var paramsB = FakeParams{KindName: "recent", Count: 25} |
|||
|
|||
// Json-tagged model struct
|
|||
type FakeModel struct { |
|||
Text string `json:"text,omitempty"` |
|||
FavoriteCount int64 `json:"favorite_count,omitempty"` |
|||
Temperature float64 `json:"temperature,omitempty"` |
|||
} |
|||
|
|||
var modelA = FakeModel{Text: "note", FavoriteCount: 12} |
|||
|
|||
func TestNew(t *testing.T) { |
|||
sling := New() |
|||
if sling.httpClient != http.DefaultClient { |
|||
t.Errorf("expected %v, got %v", http.DefaultClient, sling.httpClient) |
|||
} |
|||
if sling.header == nil { |
|||
t.Errorf("Header map not initialized with make") |
|||
} |
|||
if sling.queryStructs == nil { |
|||
t.Errorf("queryStructs not initialized with make") |
|||
} |
|||
} |
|||
|
|||
func TestSlingNew(t *testing.T) { |
|||
cases := []*Sling{ |
|||
&Sling{httpClient: &http.Client{}, method: "GET", rawURL: "http://example.com"}, |
|||
&Sling{httpClient: nil, method: "", rawURL: "http://example.com"}, |
|||
&Sling{queryStructs: make([]interface{}, 0)}, |
|||
&Sling{queryStructs: []interface{}{paramsA}}, |
|||
&Sling{queryStructs: []interface{}{paramsA, paramsB}}, |
|||
&Sling{bodyJSON: &FakeModel{Text: "a"}}, |
|||
&Sling{bodyJSON: FakeModel{Text: "a"}}, |
|||
&Sling{bodyJSON: nil}, |
|||
New().Add("Content-Type", "application/json"), |
|||
New().Add("A", "B").Add("a", "c").New(), |
|||
New().Add("A", "B").New().Add("a", "c"), |
|||
New().BodyForm(paramsB), |
|||
New().BodyForm(paramsB).New(), |
|||
} |
|||
for _, sling := range cases { |
|||
child := sling.New() |
|||
if child.httpClient != sling.httpClient { |
|||
t.Errorf("expected %v, got %v", sling.httpClient, child.httpClient) |
|||
} |
|||
if child.method != sling.method { |
|||
t.Errorf("expected %s, got %s", sling.method, child.method) |
|||
} |
|||
if child.rawURL != sling.rawURL { |
|||
t.Errorf("expected %s, got %s", sling.rawURL, child.rawURL) |
|||
} |
|||
// Header should be a copy of parent Sling header. For example, calling
|
|||
// baseSling.Add("k","v") should not mutate previously created child Slings
|
|||
if sling.header != nil { |
|||
// struct literal cases don't init Header in usual way, skip header check
|
|||
if !reflect.DeepEqual(sling.header, child.header) { |
|||
t.Errorf("not DeepEqual: expected %v, got %v", sling.header, child.header) |
|||
} |
|||
sling.header.Add("K", "V") |
|||
if child.header.Get("K") != "" { |
|||
t.Errorf("child.header was a reference to original map, should be copy") |
|||
} |
|||
} |
|||
// queryStruct slice should be a new slice with a copy of the contents
|
|||
if len(sling.queryStructs) > 0 { |
|||
// mutating one slice should not mutate the other
|
|||
child.queryStructs[0] = nil |
|||
if sling.queryStructs[0] == nil { |
|||
t.Errorf("child.queryStructs was a re-slice, expected slice with copied contents") |
|||
} |
|||
} |
|||
// bodyJSON should be copied
|
|||
if child.bodyJSON != sling.bodyJSON { |
|||
t.Errorf("expected %v, got %v", sling.bodyJSON, child.bodyJSON) |
|||
} |
|||
// bodyForm should be copied
|
|||
if child.bodyForm != sling.bodyForm { |
|||
t.Errorf("expected %v, got %v", sling.bodyForm, child.bodyForm) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestClientSetter(t *testing.T) { |
|||
developerClient := &http.Client{} |
|||
cases := []struct { |
|||
input *http.Client |
|||
expected *http.Client |
|||
}{ |
|||
{nil, http.DefaultClient}, |
|||
{developerClient, developerClient}, |
|||
} |
|||
for _, c := range cases { |
|||
sling := New() |
|||
sling.Client(c.input) |
|||
if sling.httpClient != c.expected { |
|||
t.Errorf("input %v, expected %v, got %v", c.input, c.expected, sling.httpClient) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestDoerSetter(t *testing.T) { |
|||
developerClient := &http.Client{} |
|||
cases := []struct { |
|||
input Doer |
|||
expected Doer |
|||
}{ |
|||
{nil, http.DefaultClient}, |
|||
{developerClient, developerClient}, |
|||
} |
|||
for _, c := range cases { |
|||
sling := New() |
|||
sling.Doer(c.input) |
|||
if sling.httpClient != c.expected { |
|||
t.Errorf("input %v, expected %v, got %v", c.input, c.expected, sling.httpClient) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestBaseSetter(t *testing.T) { |
|||
cases := []string{"http://a.io/", "http://b.io", "/path", "path", ""} |
|||
for _, base := range cases { |
|||
sling := New().Base(base) |
|||
if sling.rawURL != base { |
|||
t.Errorf("expected %s, got %s", base, sling.rawURL) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestPathSetter(t *testing.T) { |
|||
cases := []struct { |
|||
rawURL string |
|||
path string |
|||
expectedRawURL string |
|||
}{ |
|||
{"http://a.io/", "foo", "http://a.io/foo"}, |
|||
{"http://a.io/", "/foo", "http://a.io/foo"}, |
|||
{"http://a.io", "foo", "http://a.io/foo"}, |
|||
{"http://a.io", "/foo", "http://a.io/foo"}, |
|||
{"http://a.io/foo/", "bar", "http://a.io/foo/bar"}, |
|||
// rawURL should end in trailing slash if it is to be Path extended
|
|||
{"http://a.io/foo", "bar", "http://a.io/bar"}, |
|||
{"http://a.io/foo", "/bar", "http://a.io/bar"}, |
|||
// path extension is absolute
|
|||
{"http://a.io", "http://b.io/", "http://b.io/"}, |
|||
{"http://a.io/", "http://b.io/", "http://b.io/"}, |
|||
{"http://a.io", "http://b.io", "http://b.io"}, |
|||
{"http://a.io/", "http://b.io", "http://b.io"}, |
|||
// empty base, empty path
|
|||
{"", "http://b.io", "http://b.io"}, |
|||
{"http://a.io", "", "http://a.io"}, |
|||
{"", "", ""}, |
|||
} |
|||
for _, c := range cases { |
|||
sling := New().Base(c.rawURL).Path(c.path) |
|||
if sling.rawURL != c.expectedRawURL { |
|||
t.Errorf("expected %s, got %s", c.expectedRawURL, sling.rawURL) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestMethodSetters(t *testing.T) { |
|||
cases := []struct { |
|||
sling *Sling |
|||
expectedMethod string |
|||
}{ |
|||
{New().Path("http://a.io"), "GET"}, |
|||
{New().Head("http://a.io"), "HEAD"}, |
|||
{New().Get("http://a.io"), "GET"}, |
|||
{New().Post("http://a.io"), "POST"}, |
|||
{New().Put("http://a.io"), "PUT"}, |
|||
{New().Patch("http://a.io"), "PATCH"}, |
|||
{New().Delete("http://a.io"), "DELETE"}, |
|||
} |
|||
for _, c := range cases { |
|||
if c.sling.method != c.expectedMethod { |
|||
t.Errorf("expected method %s, got %s", c.expectedMethod, c.sling.method) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestAddHeader(t *testing.T) { |
|||
cases := []struct { |
|||
sling *Sling |
|||
expectedHeader map[string][]string |
|||
}{ |
|||
{New().Add("authorization", "OAuth key=\"value\""), map[string][]string{"Authorization": []string{"OAuth key=\"value\""}}}, |
|||
// header keys should be canonicalized
|
|||
{New().Add("content-tYPE", "application/json").Add("User-AGENT", "sling"), map[string][]string{"Content-Type": []string{"application/json"}, "User-Agent": []string{"sling"}}}, |
|||
// values for existing keys should be appended
|
|||
{New().Add("A", "B").Add("a", "c"), map[string][]string{"A": []string{"B", "c"}}}, |
|||
// Add should add to values for keys added by parent Slings
|
|||
{New().Add("A", "B").Add("a", "c").New(), map[string][]string{"A": []string{"B", "c"}}}, |
|||
{New().Add("A", "B").New().Add("a", "c"), map[string][]string{"A": []string{"B", "c"}}}, |
|||
} |
|||
for _, c := range cases { |
|||
// type conversion from header to alias'd map for deep equality comparison
|
|||
headerMap := map[string][]string(c.sling.header) |
|||
if !reflect.DeepEqual(c.expectedHeader, headerMap) { |
|||
t.Errorf("not DeepEqual: expected %v, got %v", c.expectedHeader, headerMap) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestSetHeader(t *testing.T) { |
|||
cases := []struct { |
|||
sling *Sling |
|||
expectedHeader map[string][]string |
|||
}{ |
|||
// should replace existing values associated with key
|
|||
{New().Add("A", "B").Set("a", "c"), map[string][]string{"A": []string{"c"}}}, |
|||
{New().Set("content-type", "A").Set("Content-Type", "B"), map[string][]string{"Content-Type": []string{"B"}}}, |
|||
// Set should replace values received by copying parent Slings
|
|||
{New().Set("A", "B").Add("a", "c").New(), map[string][]string{"A": []string{"B", "c"}}}, |
|||
{New().Add("A", "B").New().Set("a", "c"), map[string][]string{"A": []string{"c"}}}, |
|||
} |
|||
for _, c := range cases { |
|||
// type conversion from Header to alias'd map for deep equality comparison
|
|||
headerMap := map[string][]string(c.sling.header) |
|||
if !reflect.DeepEqual(c.expectedHeader, headerMap) { |
|||
t.Errorf("not DeepEqual: expected %v, got %v", c.expectedHeader, headerMap) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestBasicAuth(t *testing.T) { |
|||
cases := []struct { |
|||
sling *Sling |
|||
expectedAuth []string |
|||
}{ |
|||
// basic auth: username & password
|
|||
{New().SetBasicAuth("Aladdin", "open sesame"), []string{"Aladdin", "open sesame"}}, |
|||
// empty username
|
|||
{New().SetBasicAuth("", "secret"), []string{"", "secret"}}, |
|||
// empty password
|
|||
{New().SetBasicAuth("admin", ""), []string{"admin", ""}}, |
|||
} |
|||
for _, c := range cases { |
|||
req, err := c.sling.Request() |
|||
if err != nil { |
|||
t.Errorf("unexpected error when building Request with .SetBasicAuth()") |
|||
} |
|||
username, password, ok := req.BasicAuth() |
|||
if !ok { |
|||
t.Errorf("basic auth missing when expected") |
|||
} |
|||
auth := []string{username, password} |
|||
if !reflect.DeepEqual(c.expectedAuth, auth) { |
|||
t.Errorf("not DeepEqual: expected %v, got %v", c.expectedAuth, auth) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestQueryStructSetter(t *testing.T) { |
|||
cases := []struct { |
|||
sling *Sling |
|||
expectedStructs []interface{} |
|||
}{ |
|||
{New(), []interface{}{}}, |
|||
{New().QueryStruct(nil), []interface{}{}}, |
|||
{New().QueryStruct(paramsA), []interface{}{paramsA}}, |
|||
{New().QueryStruct(paramsA).QueryStruct(paramsA), []interface{}{paramsA, paramsA}}, |
|||
{New().QueryStruct(paramsA).QueryStruct(paramsB), []interface{}{paramsA, paramsB}}, |
|||
{New().QueryStruct(paramsA).New(), []interface{}{paramsA}}, |
|||
{New().QueryStruct(paramsA).New().QueryStruct(paramsB), []interface{}{paramsA, paramsB}}, |
|||
} |
|||
|
|||
for _, c := range cases { |
|||
if count := len(c.sling.queryStructs); count != len(c.expectedStructs) { |
|||
t.Errorf("expected length %d, got %d", len(c.expectedStructs), count) |
|||
} |
|||
check: |
|||
for _, expected := range c.expectedStructs { |
|||
for _, param := range c.sling.queryStructs { |
|||
if param == expected { |
|||
continue check |
|||
} |
|||
} |
|||
t.Errorf("expected to find %v in %v", expected, c.sling.queryStructs) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestBodyJSONSetter(t *testing.T) { |
|||
fakeModel := &FakeModel{} |
|||
cases := []struct { |
|||
initial interface{} |
|||
input interface{} |
|||
expected interface{} |
|||
}{ |
|||
// json tagged struct is set as bodyJSON
|
|||
{nil, fakeModel, fakeModel}, |
|||
// nil argument to bodyJSON does not replace existing bodyJSON
|
|||
{fakeModel, nil, fakeModel}, |
|||
// nil bodyJSON remains nil
|
|||
{nil, nil, nil}, |
|||
} |
|||
for _, c := range cases { |
|||
sling := New() |
|||
sling.bodyJSON = c.initial |
|||
sling.BodyJSON(c.input) |
|||
if sling.bodyJSON != c.expected { |
|||
t.Errorf("expected %v, got %v", c.expected, sling.bodyJSON) |
|||
} |
|||
// Header Content-Type should be application/json if bodyJSON arg was non-nil
|
|||
if c.input != nil && sling.header.Get(contentType) != jsonContentType { |
|||
t.Errorf("Incorrect or missing header, expected %s, got %s", jsonContentType, sling.header.Get(contentType)) |
|||
} else if c.input == nil && sling.header.Get(contentType) != "" { |
|||
t.Errorf("did not expect a Content-Type header, got %s", sling.header.Get(contentType)) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestBodyFormSetter(t *testing.T) { |
|||
cases := []struct { |
|||
initial interface{} |
|||
input interface{} |
|||
expected interface{} |
|||
}{ |
|||
// url tagged struct is set as bodyStruct
|
|||
{nil, paramsB, paramsB}, |
|||
// nil argument to bodyStruct does not replace existing bodyStruct
|
|||
{paramsB, nil, paramsB}, |
|||
// nil bodyStruct remains nil
|
|||
{nil, nil, nil}, |
|||
} |
|||
for _, c := range cases { |
|||
sling := New() |
|||
sling.bodyForm = c.initial |
|||
sling.BodyForm(c.input) |
|||
if sling.bodyForm != c.expected { |
|||
t.Errorf("expected %v, got %v", c.expected, sling.bodyForm) |
|||
} |
|||
// Content-Type should be application/x-www-form-urlencoded if bodyStruct was non-nil
|
|||
if c.input != nil && sling.header.Get(contentType) != formContentType { |
|||
t.Errorf("Incorrect or missing header, expected %s, got %s", formContentType, sling.header.Get(contentType)) |
|||
} else if c.input == nil && sling.header.Get(contentType) != "" { |
|||
t.Errorf("did not expect a Content-Type header, got %s", sling.header.Get(contentType)) |
|||
} |
|||
} |
|||
|
|||
} |
|||
|
|||
func TestBodySetter(t *testing.T) { |
|||
var testInput = ioutil.NopCloser(strings.NewReader("test")) |
|||
cases := []struct { |
|||
initial io.ReadCloser |
|||
input io.Reader |
|||
expected io.Reader |
|||
}{ |
|||
// nil body is overriden by a set body
|
|||
{nil, testInput, testInput}, |
|||
// initial body is not overriden by nil body
|
|||
{testInput, nil, testInput}, |
|||
// nil body is returned unaltered
|
|||
{nil, nil, nil}, |
|||
} |
|||
for _, c := range cases { |
|||
sling := New() |
|||
sling.body = c.initial |
|||
sling.Body(c.input) |
|||
body, err := sling.getRequestBody() |
|||
if err != nil { |
|||
t.Errorf("expected nil, got %v", err) |
|||
} |
|||
if body != c.expected { |
|||
t.Errorf("expected %v, got %v", c.expected, body) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestRequest_urlAndMethod(t *testing.T) { |
|||
cases := []struct { |
|||
sling *Sling |
|||
expectedMethod string |
|||
expectedURL string |
|||
expectedErr error |
|||
}{ |
|||
{New().Base("http://a.io"), "GET", "http://a.io", nil}, |
|||
{New().Path("http://a.io"), "GET", "http://a.io", nil}, |
|||
{New().Get("http://a.io"), "GET", "http://a.io", nil}, |
|||
{New().Put("http://a.io"), "PUT", "http://a.io", nil}, |
|||
{New().Base("http://a.io/").Path("foo"), "GET", "http://a.io/foo", nil}, |
|||
{New().Base("http://a.io/").Post("foo"), "POST", "http://a.io/foo", nil}, |
|||
// if relative path is an absolute url, base is ignored
|
|||
{New().Base("http://a.io").Path("http://b.io"), "GET", "http://b.io", nil}, |
|||
{New().Path("http://a.io").Path("http://b.io"), "GET", "http://b.io", nil}, |
|||
// last method setter takes priority
|
|||
{New().Get("http://b.io").Post("http://a.io"), "POST", "http://a.io", nil}, |
|||
{New().Post("http://a.io/").Put("foo/").Delete("bar"), "DELETE", "http://a.io/foo/bar", nil}, |
|||
// last Base setter takes priority
|
|||
{New().Base("http://a.io").Base("http://b.io"), "GET", "http://b.io", nil}, |
|||
// Path setters are additive
|
|||
{New().Base("http://a.io/").Path("foo/").Path("bar"), "GET", "http://a.io/foo/bar", nil}, |
|||
{New().Path("http://a.io/").Path("foo/").Path("bar"), "GET", "http://a.io/foo/bar", nil}, |
|||
// removes extra '/' between base and ref url
|
|||
{New().Base("http://a.io/").Get("/foo"), "GET", "http://a.io/foo", nil}, |
|||
} |
|||
for _, c := range cases { |
|||
req, err := c.sling.Request() |
|||
if err != c.expectedErr { |
|||
t.Errorf("expected error %v, got %v for %+v", c.expectedErr, err, c.sling) |
|||
} |
|||
if req.URL.String() != c.expectedURL { |
|||
t.Errorf("expected url %s, got %s for %+v", c.expectedURL, req.URL.String(), c.sling) |
|||
} |
|||
if req.Method != c.expectedMethod { |
|||
t.Errorf("expected method %s, got %s for %+v", c.expectedMethod, req.Method, c.sling) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestRequest_queryStructs(t *testing.T) { |
|||
cases := []struct { |
|||
sling *Sling |
|||
expectedURL string |
|||
}{ |
|||
{New().Base("http://a.io").QueryStruct(paramsA), "http://a.io?limit=30"}, |
|||
{New().Base("http://a.io").QueryStruct(paramsA).QueryStruct(paramsB), "http://a.io?count=25&kind_name=recent&limit=30"}, |
|||
{New().Base("http://a.io/").Path("foo?path=yes").QueryStruct(paramsA), "http://a.io/foo?limit=30&path=yes"}, |
|||
{New().Base("http://a.io").QueryStruct(paramsA).New(), "http://a.io?limit=30"}, |
|||
{New().Base("http://a.io").QueryStruct(paramsA).New().QueryStruct(paramsB), "http://a.io?count=25&kind_name=recent&limit=30"}, |
|||
} |
|||
for _, c := range cases { |
|||
req, _ := c.sling.Request() |
|||
if req.URL.String() != c.expectedURL { |
|||
t.Errorf("expected url %s, got %s for %+v", c.expectedURL, req.URL.String(), c.sling) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestRequest_body(t *testing.T) { |
|||
cases := []struct { |
|||
sling *Sling |
|||
expectedBody string // expected Body io.Reader as a string
|
|||
expectedContentType string |
|||
}{ |
|||
// BodyJSON
|
|||
{New().BodyJSON(modelA), "{\"text\":\"note\",\"favorite_count\":12}\n", jsonContentType}, |
|||
{New().BodyJSON(&modelA), "{\"text\":\"note\",\"favorite_count\":12}\n", jsonContentType}, |
|||
{New().BodyJSON(&FakeModel{}), "{}\n", jsonContentType}, |
|||
{New().BodyJSON(FakeModel{}), "{}\n", jsonContentType}, |
|||
// BodyJSON overrides existing values
|
|||
{New().BodyJSON(&FakeModel{}).BodyJSON(&FakeModel{Text: "msg"}), "{\"text\":\"msg\"}\n", jsonContentType}, |
|||
// BodyForm
|
|||
{New().BodyForm(paramsA), "limit=30", formContentType}, |
|||
{New().BodyForm(paramsB), "count=25&kind_name=recent", formContentType}, |
|||
{New().BodyForm(¶msB), "count=25&kind_name=recent", formContentType}, |
|||
// BodyForm overrides existing values
|
|||
{New().BodyForm(paramsA).New().BodyForm(paramsB), "count=25&kind_name=recent", formContentType}, |
|||
// Mixture of BodyJSON and BodyForm prefers body setter called last with a non-nil argument
|
|||
{New().BodyForm(paramsB).New().BodyJSON(modelA), "{\"text\":\"note\",\"favorite_count\":12}\n", jsonContentType}, |
|||
{New().BodyJSON(modelA).New().BodyForm(paramsB), "count=25&kind_name=recent", formContentType}, |
|||
{New().BodyForm(paramsB).New().BodyJSON(nil), "count=25&kind_name=recent", formContentType}, |
|||
{New().BodyJSON(modelA).New().BodyForm(nil), "{\"text\":\"note\",\"favorite_count\":12}\n", jsonContentType}, |
|||
// Body
|
|||
{New().Body(strings.NewReader("this-is-a-test")), "this-is-a-test", ""}, |
|||
{New().Body(strings.NewReader("a")).Body(strings.NewReader("b")), "b", ""}, |
|||
} |
|||
for _, c := range cases { |
|||
req, _ := c.sling.Request() |
|||
buf := new(bytes.Buffer) |
|||
buf.ReadFrom(req.Body) |
|||
// req.Body should have contained the expectedBody string
|
|||
if value := buf.String(); value != c.expectedBody { |
|||
t.Errorf("expected Request.Body %s, got %s", c.expectedBody, value) |
|||
} |
|||
// Header Content-Type should be expectedContentType ("" means no contentType expected)
|
|||
if actualHeader := req.Header.Get(contentType); actualHeader != c.expectedContentType && c.expectedContentType != "" { |
|||
t.Errorf("Incorrect or missing header, expected %s, got %s", c.expectedContentType, actualHeader) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestRequest_bodyNoData(t *testing.T) { |
|||
// test that Body is left nil when no bodyJSON or bodyStruct set
|
|||
slings := []*Sling{ |
|||
New(), |
|||
New().BodyJSON(nil), |
|||
New().BodyForm(nil), |
|||
} |
|||
for _, sling := range slings { |
|||
req, _ := sling.Request() |
|||
if req.Body != nil { |
|||
t.Errorf("expected nil Request.Body, got %v", req.Body) |
|||
} |
|||
// Header Content-Type should not be set when bodyJSON argument was nil or never called
|
|||
if actualHeader := req.Header.Get(contentType); actualHeader != "" { |
|||
t.Errorf("did not expect a Content-Type header, got %s", actualHeader) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestRequest_bodyEncodeErrors(t *testing.T) { |
|||
cases := []struct { |
|||
sling *Sling |
|||
expectedErr error |
|||
}{ |
|||
// check that Encode errors are propagated, illegal JSON field
|
|||
{New().BodyJSON(FakeModel{Temperature: math.Inf(1)}), errors.New("json: unsupported value: +Inf")}, |
|||
} |
|||
for _, c := range cases { |
|||
req, err := c.sling.Request() |
|||
if err == nil || err.Error() != c.expectedErr.Error() { |
|||
t.Errorf("expected error %v, got %v", c.expectedErr, err) |
|||
} |
|||
if req != nil { |
|||
t.Errorf("expected nil Request, got %+v", req) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestRequest_headers(t *testing.T) { |
|||
cases := []struct { |
|||
sling *Sling |
|||
expectedHeader map[string][]string |
|||
}{ |
|||
{New().Add("authorization", "OAuth key=\"value\""), map[string][]string{"Authorization": []string{"OAuth key=\"value\""}}}, |
|||
// header keys should be canonicalized
|
|||
{New().Add("content-tYPE", "application/json").Add("User-AGENT", "sling"), map[string][]string{"Content-Type": []string{"application/json"}, "User-Agent": []string{"sling"}}}, |
|||
// values for existing keys should be appended
|
|||
{New().Add("A", "B").Add("a", "c"), map[string][]string{"A": []string{"B", "c"}}}, |
|||
// Add should add to values for keys added by parent Slings
|
|||
{New().Add("A", "B").Add("a", "c").New(), map[string][]string{"A": []string{"B", "c"}}}, |
|||
{New().Add("A", "B").New().Add("a", "c"), map[string][]string{"A": []string{"B", "c"}}}, |
|||
// Add and Set
|
|||
{New().Add("A", "B").Set("a", "c"), map[string][]string{"A": []string{"c"}}}, |
|||
{New().Set("content-type", "A").Set("Content-Type", "B"), map[string][]string{"Content-Type": []string{"B"}}}, |
|||
// Set should replace values received by copying parent Slings
|
|||
{New().Set("A", "B").Add("a", "c").New(), map[string][]string{"A": []string{"B", "c"}}}, |
|||
{New().Add("A", "B").New().Set("a", "c"), map[string][]string{"A": []string{"c"}}}, |
|||
} |
|||
for _, c := range cases { |
|||
req, _ := c.sling.Request() |
|||
// type conversion from Header to alias'd map for deep equality comparison
|
|||
headerMap := map[string][]string(req.Header) |
|||
if !reflect.DeepEqual(c.expectedHeader, headerMap) { |
|||
t.Errorf("not DeepEqual: expected %v, got %v", c.expectedHeader, headerMap) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestAddQueryStructs(t *testing.T) { |
|||
cases := []struct { |
|||
rawurl string |
|||
queryStructs []interface{} |
|||
expected string |
|||
}{ |
|||
{"http://a.io", []interface{}{}, "http://a.io"}, |
|||
{"http://a.io", []interface{}{paramsA}, "http://a.io?limit=30"}, |
|||
{"http://a.io", []interface{}{paramsA, paramsA}, "http://a.io?limit=30&limit=30"}, |
|||
{"http://a.io", []interface{}{paramsA, paramsB}, "http://a.io?count=25&kind_name=recent&limit=30"}, |
|||
// don't blow away query values on the rawURL (parsed into RawQuery)
|
|||
{"http://a.io?initial=7", []interface{}{paramsA}, "http://a.io?initial=7&limit=30"}, |
|||
} |
|||
for _, c := range cases { |
|||
reqURL, _ := url.Parse(c.rawurl) |
|||
addQueryStructs(reqURL, c.queryStructs) |
|||
if reqURL.String() != c.expected { |
|||
t.Errorf("expected %s, got %s", c.expected, reqURL.String()) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// Sending
|
|||
|
|||
type APIError struct { |
|||
Message string `json:"message"` |
|||
Code int `json:"code"` |
|||
} |
|||
|
|||
func TestDo_onSuccess(t *testing.T) { |
|||
const expectedText = "Some text" |
|||
const expectedFavoriteCount int64 = 24 |
|||
|
|||
client, mux, server := testServer() |
|||
defer server.Close() |
|||
mux.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"text": "Some text", "favorite_count": 24}`) |
|||
}) |
|||
|
|||
sling := New().Client(client) |
|||
req, _ := http.NewRequest("GET", "http://example.com/success", nil) |
|||
|
|||
model := new(FakeModel) |
|||
apiError := new(APIError) |
|||
resp, err := sling.Do(req, model, apiError) |
|||
|
|||
if err != nil { |
|||
t.Errorf("expected nil, got %v", err) |
|||
} |
|||
if resp.StatusCode != 200 { |
|||
t.Errorf("expected %d, got %d", 200, resp.StatusCode) |
|||
} |
|||
if model.Text != expectedText { |
|||
t.Errorf("expected %s, got %s", expectedText, model.Text) |
|||
} |
|||
if model.FavoriteCount != expectedFavoriteCount { |
|||
t.Errorf("expected %d, got %d", expectedFavoriteCount, model.FavoriteCount) |
|||
} |
|||
} |
|||
|
|||
func TestDo_onSuccessWithNilValue(t *testing.T) { |
|||
client, mux, server := testServer() |
|||
defer server.Close() |
|||
mux.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"text": "Some text", "favorite_count": 24}`) |
|||
}) |
|||
|
|||
sling := New().Client(client) |
|||
req, _ := http.NewRequest("GET", "http://example.com/success", nil) |
|||
|
|||
apiError := new(APIError) |
|||
resp, err := sling.Do(req, nil, apiError) |
|||
|
|||
if err != nil { |
|||
t.Errorf("expected nil, got %v", err) |
|||
} |
|||
if resp.StatusCode != 200 { |
|||
t.Errorf("expected %d, got %d", 200, resp.StatusCode) |
|||
} |
|||
expected := &APIError{} |
|||
if !reflect.DeepEqual(expected, apiError) { |
|||
t.Errorf("failureV should not be populated, exepcted %v, got %v", expected, apiError) |
|||
} |
|||
} |
|||
|
|||
func TestDo_onFailure(t *testing.T) { |
|||
const expectedMessage = "Invalid argument" |
|||
const expectedCode int = 215 |
|||
|
|||
client, mux, server := testServer() |
|||
defer server.Close() |
|||
mux.HandleFunc("/failure", func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.WriteHeader(400) |
|||
fmt.Fprintf(w, `{"message": "Invalid argument", "code": 215}`) |
|||
}) |
|||
|
|||
sling := New().Client(client) |
|||
req, _ := http.NewRequest("GET", "http://example.com/failure", nil) |
|||
|
|||
model := new(FakeModel) |
|||
apiError := new(APIError) |
|||
resp, err := sling.Do(req, model, apiError) |
|||
|
|||
if err != nil { |
|||
t.Errorf("expected nil, got %v", err) |
|||
} |
|||
if resp.StatusCode != 400 { |
|||
t.Errorf("expected %d, got %d", 400, resp.StatusCode) |
|||
} |
|||
if apiError.Message != expectedMessage { |
|||
t.Errorf("expected %s, got %s", expectedMessage, apiError.Message) |
|||
} |
|||
if apiError.Code != expectedCode { |
|||
t.Errorf("expected %d, got %d", expectedCode, apiError.Code) |
|||
} |
|||
} |
|||
|
|||
func TestDo_onFailureWithNilValue(t *testing.T) { |
|||
client, mux, server := testServer() |
|||
defer server.Close() |
|||
mux.HandleFunc("/failure", func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.WriteHeader(420) |
|||
fmt.Fprintf(w, `{"message": "Enhance your calm", "code": 88}`) |
|||
}) |
|||
|
|||
sling := New().Client(client) |
|||
req, _ := http.NewRequest("GET", "http://example.com/failure", nil) |
|||
|
|||
model := new(FakeModel) |
|||
resp, err := sling.Do(req, model, nil) |
|||
|
|||
if err != nil { |
|||
t.Errorf("expected nil, got %v", err) |
|||
} |
|||
if resp.StatusCode != 420 { |
|||
t.Errorf("expected %d, got %d", 420, resp.StatusCode) |
|||
} |
|||
expected := &FakeModel{} |
|||
if !reflect.DeepEqual(expected, model) { |
|||
t.Errorf("successV should not be populated, exepcted %v, got %v", expected, model) |
|||
} |
|||
} |
|||
|
|||
func TestDo_skipDecodingIfContentTypeWrong(t *testing.T) { |
|||
client, mux, server := testServer() |
|||
defer server.Close() |
|||
mux.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) { |
|||
w.Header().Set("Content-Type", "text/html") |
|||
fmt.Fprintf(w, `{"text": "Some text", "favorite_count": 24}`) |
|||
}) |
|||
|
|||
sling := New().Client(client) |
|||
req, _ := http.NewRequest("GET", "http://example.com/success", nil) |
|||
|
|||
model := new(FakeModel) |
|||
sling.Do(req, model, nil) |
|||
|
|||
expectedModel := &FakeModel{} |
|||
if !reflect.DeepEqual(expectedModel, model) { |
|||
t.Errorf("decoding should have been skipped, Content-Type was incorrect") |
|||
} |
|||
} |
|||
|
|||
func TestReceive_success(t *testing.T) { |
|||
client, mux, server := testServer() |
|||
defer server.Close() |
|||
mux.HandleFunc("/foo/submit", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "POST", r) |
|||
assertQuery(t, map[string]string{"kind_name": "vanilla", "count": "11"}, r) |
|||
assertPostForm(t, map[string]string{"kind_name": "vanilla", "count": "11"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
fmt.Fprintf(w, `{"text": "Some text", "favorite_count": 24}`) |
|||
}) |
|||
|
|||
endpoint := New().Client(client).Base("http://example.com/").Path("foo/").Post("submit") |
|||
// encode url-tagged struct in query params and as post body for testing purposes
|
|||
params := FakeParams{KindName: "vanilla", Count: 11} |
|||
model := new(FakeModel) |
|||
apiError := new(APIError) |
|||
resp, err := endpoint.New().QueryStruct(params).BodyForm(params).Receive(model, apiError) |
|||
|
|||
if err != nil { |
|||
t.Errorf("expected nil, got %v", err) |
|||
} |
|||
if resp.StatusCode != 200 { |
|||
t.Errorf("expected %d, got %d", 200, resp.StatusCode) |
|||
} |
|||
expectedModel := &FakeModel{Text: "Some text", FavoriteCount: 24} |
|||
if !reflect.DeepEqual(expectedModel, model) { |
|||
t.Errorf("expected %v, got %v", expectedModel, model) |
|||
} |
|||
expectedAPIError := &APIError{} |
|||
if !reflect.DeepEqual(expectedAPIError, apiError) { |
|||
t.Errorf("failureV should be zero valued, exepcted %v, got %v", expectedAPIError, apiError) |
|||
} |
|||
} |
|||
|
|||
func TestReceive_failure(t *testing.T) { |
|||
client, mux, server := testServer() |
|||
defer server.Close() |
|||
mux.HandleFunc("/foo/submit", func(w http.ResponseWriter, r *http.Request) { |
|||
assertMethod(t, "POST", r) |
|||
assertQuery(t, map[string]string{"kind_name": "vanilla", "count": "11"}, r) |
|||
assertPostForm(t, map[string]string{"kind_name": "vanilla", "count": "11"}, r) |
|||
w.Header().Set("Content-Type", "application/json") |
|||
w.WriteHeader(429) |
|||
fmt.Fprintf(w, `{"message": "Rate limit exceeded", "code": 88}`) |
|||
}) |
|||
|
|||
endpoint := New().Client(client).Base("http://example.com/").Path("foo/").Post("submit") |
|||
// encode url-tagged struct in query params and as post body for testing purposes
|
|||
params := FakeParams{KindName: "vanilla", Count: 11} |
|||
model := new(FakeModel) |
|||
apiError := new(APIError) |
|||
resp, err := endpoint.New().QueryStruct(params).BodyForm(params).Receive(model, apiError) |
|||
|
|||
if err != nil { |
|||
t.Errorf("expected nil, got %v", err) |
|||
} |
|||
if resp.StatusCode != 429 { |
|||
t.Errorf("expected %d, got %d", 429, resp.StatusCode) |
|||
} |
|||
expectedAPIError := &APIError{Message: "Rate limit exceeded", Code: 88} |
|||
if !reflect.DeepEqual(expectedAPIError, apiError) { |
|||
t.Errorf("expected %v, got %v", expectedAPIError, apiError) |
|||
} |
|||
expectedModel := &FakeModel{} |
|||
if !reflect.DeepEqual(expectedModel, model) { |
|||
t.Errorf("successV should not be zero valued, expected %v, got %v", expectedModel, model) |
|||
} |
|||
} |
|||
|
|||
func TestReceive_errorCreatingRequest(t *testing.T) { |
|||
expectedErr := errors.New("json: unsupported value: +Inf") |
|||
resp, err := New().BodyJSON(FakeModel{Temperature: math.Inf(1)}).Receive(nil, nil) |
|||
if err == nil || err.Error() != expectedErr.Error() { |
|||
t.Errorf("expected %v, got %v", expectedErr, err) |
|||
} |
|||
if resp != nil { |
|||
t.Errorf("expected nil resp, got %v", resp) |
|||
} |
|||
} |
|||
|
|||
// Testing Utils
|
|||
|
|||
// testServer returns an http Client, ServeMux, and Server. The client proxies
|
|||
// requests to the server and handlers can be registered on the mux to handle
|
|||
// requests. The caller must close the test server.
|
|||
func testServer() (*http.Client, *http.ServeMux, *httptest.Server) { |
|||
mux := http.NewServeMux() |
|||
server := httptest.NewServer(mux) |
|||
transport := &http.Transport{ |
|||
Proxy: func(req *http.Request) (*url.URL, error) { |
|||
return url.Parse(server.URL) |
|||
}, |
|||
} |
|||
client := &http.Client{Transport: transport} |
|||
return client, mux, server |
|||
} |
|||
|
|||
func assertMethod(t *testing.T, expectedMethod string, req *http.Request) { |
|||
if actualMethod := req.Method; actualMethod != expectedMethod { |
|||
t.Errorf("expected method %s, got %s", expectedMethod, actualMethod) |
|||
} |
|||
} |
|||
|
|||
// assertQuery tests that the Request has the expected url query key/val pairs
|
|||
func assertQuery(t *testing.T, expected map[string]string, req *http.Request) { |
|||
queryValues := req.URL.Query() // net/url Values is a map[string][]string
|
|||
expectedValues := url.Values{} |
|||
for key, value := range expected { |
|||
expectedValues.Add(key, value) |
|||
} |
|||
if !reflect.DeepEqual(expectedValues, queryValues) { |
|||
t.Errorf("expected parameters %v, got %v", expected, req.URL.RawQuery) |
|||
} |
|||
} |
|||
|
|||
// assertPostForm tests that the Request has the expected key values pairs url
|
|||
// encoded in its Body
|
|||
func assertPostForm(t *testing.T, expected map[string]string, req *http.Request) { |
|||
req.ParseForm() // parses request Body to put url.Values in r.Form/r.PostForm
|
|||
expectedValues := url.Values{} |
|||
for key, value := range expected { |
|||
expectedValues.Add(key, value) |
|||
} |
|||
if !reflect.DeepEqual(expectedValues, req.PostForm) { |
|||
t.Errorf("expected parameters %v, got %v", expected, req.PostForm) |
|||
} |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue