-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcontext_test.go
More file actions
146 lines (121 loc) · 3.56 KB
/
context_test.go
File metadata and controls
146 lines (121 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package dvls
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// TestContextCancellation verifies that RequestWithContext returns an error
// when its context is cancelled while the server is still handling the request.
func TestContextCancellation(t *testing.T) {
	// requestReceived signals that the handler has been entered. A buffered
	// channel with a non-blocking send is used instead of close() so that a
	// second handler invocation cannot panic with "close of closed channel".
	requestReceived := make(chan struct{}, 1)
	// allowResponse is closed during subtest cleanup to release the handler.
	allowResponse := make(chan struct{})

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		select {
		case requestReceived <- struct{}{}:
		default:
		}
		// Hold the response until the test releases it or the client abandons
		// the request, so the handler never outlives the test.
		select {
		case <-allowResponse:
		case <-r.Context().Done():
			return
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"result":1,"response":{}}`))
	}))
	defer server.Close()

	client := &Client{
		baseUri: server.URL,
		client:  server.Client(),
		credential: credentials{
			token: "test-token",
		},
	}

	t.Run("RequestWithContext respects cancellation", func(t *testing.T) {
		t.Cleanup(func() { close(allowResponse) })

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		errCh := make(chan error, 1)
		go func() {
			_, err := client.RequestWithContext(ctx, server.URL+"/test", http.MethodGet, nil)
			errCh <- err
		}()

		// Bound every wait so a regression fails the test instead of hanging it
		// until the global `go test` timeout.
		const waitLimit = 5 * time.Second

		select {
		case <-requestReceived:
		case err := <-errCh:
			t.Fatalf("request finished before reaching the server: %v", err)
		case <-time.After(waitLimit):
			t.Fatal("timed out waiting for the request to reach the server")
		}

		cancel()

		select {
		case err := <-errCh:
			if err == nil {
				t.Fatal("expected error due to context cancellation, got nil")
			}
		case <-time.After(waitLimit):
			t.Fatal("RequestWithContext did not return after its context was cancelled")
		}
	})
}
// TestContextSucceedsWithoutCancellation covers the happy path. With a live
// context and a server that answers immediately, RequestWithContext must
// return no error.
func TestContextSucceedsWithoutCancellation(t *testing.T) {
	okHandler := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"result":1,"response":{}}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(okHandler))
	defer srv.Close()

	c := &Client{
		baseUri:    srv.URL,
		client:     srv.Client(),
		credential: credentials{token: "test-token"},
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if _, err := c.RequestWithContext(ctx, srv.URL+"/test", http.MethodGet, nil); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}
// TestContextTimeout verifies that RequestWithContext returns an error once its
// context deadline expires while the server is still working on the request.
func TestContextTimeout(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Stay slower than the client's 10ms deadline, but stop as soon as the
		// client abandons the request. httptest.Server.Close waits for
		// outstanding handlers, so a fixed sleep would delay teardown on every
		// run. If the client ignores its deadline, it eventually gets a
		// successful response and the assertion below fails.
		select {
		case <-r.Context().Done():
			return
		case <-time.After(2 * time.Second):
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"result":1,"response":{}}`))
	}))
	defer server.Close()

	client := &Client{
		baseUri: server.URL,
		client:  server.Client(),
		credential: credentials{
			token: "test-token",
		},
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	_, err := client.RequestWithContext(ctx, server.URL+"/test", http.MethodGet, nil)
	if err == nil {
		t.Fatal("expected error due to context timeout, got nil")
	}
}
// TestContextPropagation verifies that the context passed to RequestWithContext
// is attached to the outgoing *http.Request. It checks this by inspecting
// req.Context() in a wrapping RoundTripper.
func TestContextPropagation(t *testing.T) {
	type contextKey string
	const testKey contextKey = "test-key"

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"result":1,"response":{"id":"test-id"}}`))
	}))
	defer server.Close()

	contextKeyReceived := false
	customTransport := &contextCheckTransport{
		base: http.DefaultTransport,
		checkFunc: func(ctx context.Context) {
			if ctx.Value(testKey) == "test-value" {
				contextKeyReceived = true
			}
		},
	}

	client := &Client{
		baseUri: server.URL,
		// The timeout keeps a stuck request from hanging the test.
		client: &http.Client{Transport: customTransport, Timeout: 5 * time.Second},
		credential: credentials{
			token: "test-token",
		},
	}

	ctx := context.WithValue(context.Background(), testKey, "test-value")
	// Report a request failure without aborting, so the propagation check
	// below still runs and both problems are visible in one run.
	if _, err := client.RequestWithContext(ctx, server.URL+"/test", http.MethodGet, nil); err != nil {
		t.Errorf("request failed: %v", err)
	}
	if !contextKeyReceived {
		t.Fatal("context was not properly propagated through the request")
	}
}
// contextCheckTransport is an http.RoundTripper test double. It passes each
// outgoing request's context to checkFunc, then forwards the request to base
// unchanged.
type contextCheckTransport struct {
	// base performs the actual round trip.
	base http.RoundTripper
	// checkFunc is called with req.Context() before the request is forwarded.
	checkFunc func(context.Context)
}

// Compile-time check that contextCheckTransport satisfies http.RoundTripper.
var _ http.RoundTripper = (*contextCheckTransport)(nil)

// RoundTrip reports req.Context() to checkFunc, then delegates to base and
// returns its response and error unchanged.
func (c *contextCheckTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	c.checkFunc(req.Context())
	return c.base.RoundTrip(req)
}