From 32354261d3206ce6c5b79daac60f90d9c4401e4d Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 15 Jan 2026 10:26:50 +0000 Subject: [PATCH] chore(coderd/httpmw): extract HTTPRoute middleware (#21498) Extracts part of the prometheus middleware that stores the route information in the request context into its own middleware. Also adds request method information to context. Relates to https://github.com/coder/internal/issues/1214 --- coderd/coderd.go | 1 + coderd/httpmw/httproute.go | 71 +++++++++++++++++++++ coderd/httpmw/httproute_test.go | 104 +++++++++++++++++++++++++++++++ coderd/httpmw/prometheus.go | 30 +-------- coderd/httpmw/prometheus_test.go | 11 ++-- 5 files changed, 184 insertions(+), 33 deletions(-) create mode 100644 coderd/httpmw/httproute.go create mode 100644 coderd/httpmw/httproute_test.go diff --git a/coderd/coderd.go b/coderd/coderd.go index 351ddde805..48483234a3 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -881,6 +881,7 @@ func New(options *Options) *API { loggermw.Logger(api.Logger), singleSlashMW, rolestore.CustomRoleMW, + httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware. prometheusMW, // Build-Version is helpful for debugging. func(next http.Handler) http.Handler { diff --git a/coderd/httpmw/httproute.go b/coderd/httpmw/httproute.go new file mode 100644 index 0000000000..373835274d --- /dev/null +++ b/coderd/httpmw/httproute.go @@ -0,0 +1,71 @@ +package httpmw + +import ( + "context" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" +) + +type ( + httpRouteInfoKey struct{} +) + +type httpRouteInfo struct { + Route string + Method string +} + +// ExtractHTTPRoute retrieves just the HTTP route pattern from context. +// Returns empty string if not set. +func ExtractHTTPRoute(ctx context.Context) string { + ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo) + return ri.Route +} + +// ExtractHTTPMethod retrieves just the HTTP method from context. +// Returns empty string if not set. +func ExtractHTTPMethod(ctx context.Context) string { + ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo) + return ri.Method +} + +// HTTPRoute is middleware that stores the HTTP route pattern and method in +// context for use by downstream handlers and services (e.g. prometheus). +func HTTPRoute(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + route := getRoutePattern(r) + ctx := context.WithValue(r.Context(), httpRouteInfoKey{}, httpRouteInfo{ + Route: route, + Method: r.Method, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func getRoutePattern(r *http.Request) string { + rctx := chi.RouteContext(r.Context()) + if rctx == nil { + return "" + } + + routePath := r.URL.Path + if r.URL.RawPath != "" { + routePath = r.URL.RawPath + } + + tctx := chi.NewRouteContext() + routes := rctx.Routes + if routes != nil && !routes.Match(tctx, r.Method, routePath) { + // No matching pattern. /api/* requests will be matched as "UNKNOWN" + // All other ones will be matched as "STATIC". + if strings.HasPrefix(routePath, "/api/") { + return "UNKNOWN" + } + return "STATIC" + } + + // tctx has the updated pattern, since Match mutates it + return tctx.RoutePattern() +} diff --git a/coderd/httpmw/httproute_test.go b/coderd/httpmw/httproute_test.go new file mode 100644 index 0000000000..8c908df47f --- /dev/null +++ b/coderd/httpmw/httproute_test.go @@ -0,0 +1,104 @@ +package httpmw_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/testutil" +) + +func TestHTTPRoute(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + reqFn func() *http.Request + registerRoutes map[string]string + mws []func(http.Handler) http.Handler + expectedRoute string + expectedMethod string + }{ + { + name: "without middleware", + reqFn: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/", nil) + }, + registerRoutes: map[string]string{http.MethodGet: "/"}, + mws: []func(http.Handler) http.Handler{}, + expectedRoute: "", + expectedMethod: "", + }, + { + name: "root", + reqFn: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/", nil) + }, + registerRoutes: map[string]string{http.MethodGet: "/"}, + mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute}, + expectedRoute: "/", + expectedMethod: http.MethodGet, + }, + { + name: "parameterized route", + reqFn: func() *http.Request { + return httptest.NewRequest(http.MethodPut, "/users/123", nil) + }, + registerRoutes: map[string]string{http.MethodPut: "/users/{id}"}, + mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute}, + expectedRoute: "/users/{id}", + expectedMethod: http.MethodPut, + }, + { + name: "unknown", + reqFn: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/api/a", nil) + }, + registerRoutes: map[string]string{http.MethodGet: "/api/b"}, + mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute}, + expectedRoute: "UNKNOWN", + expectedMethod: http.MethodGet, + }, + { + name: "static", + reqFn: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "/some/static/file.png", nil) + }, + registerRoutes: map[string]string{http.MethodGet: "/"}, + mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute}, + expectedRoute: "STATIC", + expectedMethod: http.MethodGet, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + r := chi.NewRouter() + done := make(chan string) + for _, mw := range tc.mws { + r.Use(mw) + } + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(done) + method := httpmw.ExtractHTTPMethod(r.Context()) + route := httpmw.ExtractHTTPRoute(r.Context()) + assert.Equal(t, tc.expectedMethod, method, "expected method mismatch") + assert.Equal(t, tc.expectedRoute, route, "expected route mismatch") + next.ServeHTTP(w, r) + }) + }) + for method, route := range tc.registerRoutes { + r.MethodFunc(method, route, func(w http.ResponseWriter, r *http.Request) {}) + } + req := tc.reqFn() + r.ServeHTTP(httptest.NewRecorder(), req) + _ = testutil.TryReceive(ctx, t, done) + }) + } +} diff --git a/coderd/httpmw/prometheus.go b/coderd/httpmw/prometheus.go index 1823edde11..246d314e13 100644 --- a/coderd/httpmw/prometheus.go +++ b/coderd/httpmw/prometheus.go @@ -3,10 +3,8 @@ package httpmw import ( "net/http" "strconv" - "strings" "time" - "github.com/go-chi/chi/v5" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -71,7 +69,7 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler var ( dist *prometheus.HistogramVec distOpts []string - path = getRoutePattern(r) + path = ExtractHTTPRoute(r.Context()) ) // We want to count WebSockets separately. @@ -98,29 +96,3 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler }) } } - -func getRoutePattern(r *http.Request) string { - rctx := chi.RouteContext(r.Context()) - if rctx == nil { - return "" - } - - routePath := r.URL.Path - if r.URL.RawPath != "" { - routePath = r.URL.RawPath - } - - tctx := chi.NewRouteContext() - routes := rctx.Routes - if routes != nil && !routes.Match(tctx, r.Method, routePath) { - // No matching pattern. /api/* requests will be matched as "UNKNOWN" - // All other ones will be matched as "STATIC". - if strings.HasPrefix(routePath, "/api/") { - return "UNKNOWN" - } - return "STATIC" - } - - // tctx has the updated pattern, since Match mutates it - return tctx.RoutePattern() -} diff --git a/coderd/httpmw/prometheus_test.go b/coderd/httpmw/prometheus_test.go index 87928259e9..5446e9bad8 100644 --- a/coderd/httpmw/prometheus_test.go +++ b/coderd/httpmw/prometheus_test.go @@ -29,9 +29,9 @@ func TestPrometheus(t *testing.T) { req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext())) res := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()} reg := prometheus.NewRegistry() - httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpmw.HTTPRoute(httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - })).ServeHTTP(res, req) + }))).ServeHTTP(res, req) metrics, err := reg.Gather() require.NoError(t, err) require.Greater(t, len(metrics), 0) @@ -57,7 +57,7 @@ func TestPrometheus(t *testing.T) { wrappedHandler := promMW(testHandler) r := chi.NewRouter() - r.Use(tracing.StatusWriterMiddleware, promMW) + r.Use(tracing.StatusWriterMiddleware, httpmw.HTTPRoute, promMW) r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) { wrappedHandler.ServeHTTP(rw, r) }) @@ -85,7 +85,7 @@ func TestPrometheus(t *testing.T) { promMW := httpmw.Prometheus(reg) r := chi.NewRouter() - r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {}) + r.With(httpmw.HTTPRoute).With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {}) req := httptest.NewRequest("GET", "/api/v2/users/john", nil) @@ -115,6 +115,7 @@ func TestPrometheus(t *testing.T) { promMW := httpmw.Prometheus(reg) r := chi.NewRouter() + r.Use(httpmw.HTTPRoute) r.Use(promMW) r.NotFound(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) @@ -145,6 +146,7 @@ func TestPrometheus(t *testing.T) { promMW := httpmw.Prometheus(reg) r := chi.NewRouter() + r.Use(httpmw.HTTPRoute) r.Use(promMW) r.NotFound(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) @@ -173,6 +175,7 @@ func TestPrometheus(t *testing.T) { promMW := httpmw.Prometheus(reg) r := chi.NewRouter() + r.Use(httpmw.HTTPRoute) r.Use(promMW) r.Get("/api/v2/workspaceagents/{workspaceagent}/pty", func(w http.ResponseWriter, r *http.Request) {})