mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore(coderd/httpmw): extract HTTPRoute middleware (#21498)
Extracts part of the prometheus middleware that stores the route information in the request context into its own middleware. Also adds request method information to context. Relates to https://github.com/coder/internal/issues/1214
This commit is contained in:
@@ -881,6 +881,7 @@ func New(options *Options) *API {
|
||||
loggermw.Logger(api.Logger),
|
||||
singleSlashMW,
|
||||
rolestore.CustomRoleMW,
|
||||
httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware.
|
||||
prometheusMW,
|
||||
// Build-Version is helpful for debugging.
|
||||
func(next http.Handler) http.Handler {
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type (
|
||||
httpRouteInfoKey struct{}
|
||||
)
|
||||
|
||||
type httpRouteInfo struct {
|
||||
Route string
|
||||
Method string
|
||||
}
|
||||
|
||||
// ExtractHTTPRoute retrieves just the HTTP route pattern from context.
|
||||
// Returns empty string if not set.
|
||||
func ExtractHTTPRoute(ctx context.Context) string {
|
||||
ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo)
|
||||
return ri.Route
|
||||
}
|
||||
|
||||
// ExtractHTTPMethod retrieves just the HTTP method from context.
|
||||
// Returns empty string if not set.
|
||||
func ExtractHTTPMethod(ctx context.Context) string {
|
||||
ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo)
|
||||
return ri.Method
|
||||
}
|
||||
|
||||
// HTTPRoute is middleware that stores the HTTP route pattern and method in
|
||||
// context for use by downstream handlers and services (e.g. prometheus).
|
||||
func HTTPRoute(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
route := getRoutePattern(r)
|
||||
ctx := context.WithValue(r.Context(), httpRouteInfoKey{}, httpRouteInfo{
|
||||
Route: route,
|
||||
Method: r.Method,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func getRoutePattern(r *http.Request) string {
|
||||
rctx := chi.RouteContext(r.Context())
|
||||
if rctx == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
routePath := r.URL.Path
|
||||
if r.URL.RawPath != "" {
|
||||
routePath = r.URL.RawPath
|
||||
}
|
||||
|
||||
tctx := chi.NewRouteContext()
|
||||
routes := rctx.Routes
|
||||
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
|
||||
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
|
||||
// All other ones will be matched as "STATIC".
|
||||
if strings.HasPrefix(routePath, "/api/") {
|
||||
return "UNKNOWN"
|
||||
}
|
||||
return "STATIC"
|
||||
}
|
||||
|
||||
// tctx has the updated pattern, since Match mutates it
|
||||
return tctx.RoutePattern()
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestHTTPRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
reqFn func() *http.Request
|
||||
registerRoutes map[string]string
|
||||
mws []func(http.Handler) http.Handler
|
||||
expectedRoute string
|
||||
expectedMethod string
|
||||
}{
|
||||
{
|
||||
name: "without middleware",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||
mws: []func(http.Handler) http.Handler{},
|
||||
expectedRoute: "",
|
||||
expectedMethod: "",
|
||||
},
|
||||
{
|
||||
name: "root",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "/",
|
||||
expectedMethod: http.MethodGet,
|
||||
},
|
||||
{
|
||||
name: "parameterized route",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodPut, "/users/123", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodPut: "/users/{id}"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "/users/{id}",
|
||||
expectedMethod: http.MethodPut,
|
||||
},
|
||||
{
|
||||
name: "unknown",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/api/a", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/api/b"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "UNKNOWN",
|
||||
expectedMethod: http.MethodGet,
|
||||
},
|
||||
{
|
||||
name: "static",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/some/static/file.png", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "STATIC",
|
||||
expectedMethod: http.MethodGet,
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
r := chi.NewRouter()
|
||||
done := make(chan string)
|
||||
for _, mw := range tc.mws {
|
||||
r.Use(mw)
|
||||
}
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer close(done)
|
||||
method := httpmw.ExtractHTTPMethod(r.Context())
|
||||
route := httpmw.ExtractHTTPRoute(r.Context())
|
||||
assert.Equal(t, tc.expectedMethod, method, "expected method mismatch")
|
||||
assert.Equal(t, tc.expectedRoute, route, "expected route mismatch")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
for method, route := range tc.registerRoutes {
|
||||
r.MethodFunc(method, route, func(w http.ResponseWriter, r *http.Request) {})
|
||||
}
|
||||
req := tc.reqFn()
|
||||
r.ServeHTTP(httptest.NewRecorder(), req)
|
||||
_ = testutil.TryReceive(ctx, t, done)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,10 +3,8 @@ package httpmw
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
|
||||
@@ -71,7 +69,7 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||
var (
|
||||
dist *prometheus.HistogramVec
|
||||
distOpts []string
|
||||
path = getRoutePattern(r)
|
||||
path = ExtractHTTPRoute(r.Context())
|
||||
)
|
||||
|
||||
// We want to count WebSockets separately.
|
||||
@@ -98,29 +96,3 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getRoutePattern(r *http.Request) string {
|
||||
rctx := chi.RouteContext(r.Context())
|
||||
if rctx == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
routePath := r.URL.Path
|
||||
if r.URL.RawPath != "" {
|
||||
routePath = r.URL.RawPath
|
||||
}
|
||||
|
||||
tctx := chi.NewRouteContext()
|
||||
routes := rctx.Routes
|
||||
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
|
||||
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
|
||||
// All other ones will be matched as "STATIC".
|
||||
if strings.HasPrefix(routePath, "/api/") {
|
||||
return "UNKNOWN"
|
||||
}
|
||||
return "STATIC"
|
||||
}
|
||||
|
||||
// tctx has the updated pattern, since Match mutates it
|
||||
return tctx.RoutePattern()
|
||||
}
|
||||
|
||||
@@ -29,9 +29,9 @@ func TestPrometheus(t *testing.T) {
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
||||
res := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
||||
reg := prometheus.NewRegistry()
|
||||
httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
httpmw.HTTPRoute(httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(res, req)
|
||||
}))).ServeHTTP(res, req)
|
||||
metrics, err := reg.Gather()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(metrics), 0)
|
||||
@@ -57,7 +57,7 @@ func TestPrometheus(t *testing.T) {
|
||||
wrappedHandler := promMW(testHandler)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(tracing.StatusWriterMiddleware, promMW)
|
||||
r.Use(tracing.StatusWriterMiddleware, httpmw.HTTPRoute, promMW)
|
||||
r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) {
|
||||
wrappedHandler.ServeHTTP(rw, r)
|
||||
})
|
||||
@@ -85,7 +85,7 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
|
||||
r.With(httpmw.HTTPRoute).With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v2/users/john", nil)
|
||||
|
||||
@@ -115,6 +115,7 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(httpmw.HTTPRoute)
|
||||
r.Use(promMW)
|
||||
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -145,6 +146,7 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(httpmw.HTTPRoute)
|
||||
r.Use(promMW)
|
||||
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -173,6 +175,7 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(httpmw.HTTPRoute)
|
||||
r.Use(promMW)
|
||||
r.Get("/api/v2/workspaceagents/{workspaceagent}/pty", func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user