mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
32354261d3
Extracts the part of the Prometheus middleware that stores route information in the request context into its own middleware. Also adds the request method to the context. Relates to https://github.com/coder/internal/issues/1214
105 lines
3.0 KiB
Go
105 lines
3.0 KiB
Go
package httpmw_test
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/coder/coder/v2/coderd/httpmw"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestHTTPRoute(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
for _, tc := range []struct {
|
|
name string
|
|
reqFn func() *http.Request
|
|
registerRoutes map[string]string
|
|
mws []func(http.Handler) http.Handler
|
|
expectedRoute string
|
|
expectedMethod string
|
|
}{
|
|
{
|
|
name: "without middleware",
|
|
reqFn: func() *http.Request {
|
|
return httptest.NewRequest(http.MethodGet, "/", nil)
|
|
},
|
|
registerRoutes: map[string]string{http.MethodGet: "/"},
|
|
mws: []func(http.Handler) http.Handler{},
|
|
expectedRoute: "",
|
|
expectedMethod: "",
|
|
},
|
|
{
|
|
name: "root",
|
|
reqFn: func() *http.Request {
|
|
return httptest.NewRequest(http.MethodGet, "/", nil)
|
|
},
|
|
registerRoutes: map[string]string{http.MethodGet: "/"},
|
|
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
|
expectedRoute: "/",
|
|
expectedMethod: http.MethodGet,
|
|
},
|
|
{
|
|
name: "parameterized route",
|
|
reqFn: func() *http.Request {
|
|
return httptest.NewRequest(http.MethodPut, "/users/123", nil)
|
|
},
|
|
registerRoutes: map[string]string{http.MethodPut: "/users/{id}"},
|
|
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
|
expectedRoute: "/users/{id}",
|
|
expectedMethod: http.MethodPut,
|
|
},
|
|
{
|
|
name: "unknown",
|
|
reqFn: func() *http.Request {
|
|
return httptest.NewRequest(http.MethodGet, "/api/a", nil)
|
|
},
|
|
registerRoutes: map[string]string{http.MethodGet: "/api/b"},
|
|
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
|
expectedRoute: "UNKNOWN",
|
|
expectedMethod: http.MethodGet,
|
|
},
|
|
{
|
|
name: "static",
|
|
reqFn: func() *http.Request {
|
|
return httptest.NewRequest(http.MethodGet, "/some/static/file.png", nil)
|
|
},
|
|
registerRoutes: map[string]string{http.MethodGet: "/"},
|
|
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
|
expectedRoute: "STATIC",
|
|
expectedMethod: http.MethodGet,
|
|
},
|
|
} {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
r := chi.NewRouter()
|
|
done := make(chan string)
|
|
for _, mw := range tc.mws {
|
|
r.Use(mw)
|
|
}
|
|
r.Use(func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
defer close(done)
|
|
method := httpmw.ExtractHTTPMethod(r.Context())
|
|
route := httpmw.ExtractHTTPRoute(r.Context())
|
|
assert.Equal(t, tc.expectedMethod, method, "expected method mismatch")
|
|
assert.Equal(t, tc.expectedRoute, route, "expected route mismatch")
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
})
|
|
for method, route := range tc.registerRoutes {
|
|
r.MethodFunc(method, route, func(w http.ResponseWriter, r *http.Request) {})
|
|
}
|
|
req := tc.reqFn()
|
|
r.ServeHTTP(httptest.NewRecorder(), req)
|
|
_ = testutil.TryReceive(ctx, t, done)
|
|
})
|
|
}
|
|
}
|