chore(coderd/httpmw): extract HTTPRoute middleware (#21498)

Extracts part of the prometheus middleware that stores the route
information in the request context into its own middleware. Also adds
request method information to context.

Relates to https://github.com/coder/internal/issues/1214
This commit is contained in:
Cian Johnston
2026-01-15 10:26:50 +00:00
committed by GitHub
parent 6683d807ac
commit 32354261d3
5 changed files with 184 additions and 33 deletions
+1
View File
@@ -881,6 +881,7 @@ func New(options *Options) *API {
loggermw.Logger(api.Logger), loggermw.Logger(api.Logger),
singleSlashMW, singleSlashMW,
rolestore.CustomRoleMW, rolestore.CustomRoleMW,
httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware.
prometheusMW, prometheusMW,
// Build-Version is helpful for debugging. // Build-Version is helpful for debugging.
func(next http.Handler) http.Handler { func(next http.Handler) http.Handler {
+71
View File
@@ -0,0 +1,71 @@
package httpmw
import (
"context"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
)
type (
httpRouteInfoKey struct{}
)
type httpRouteInfo struct {
Route string
Method string
}
// ExtractHTTPRoute retrieves just the HTTP route pattern from context.
// Returns empty string if not set.
func ExtractHTTPRoute(ctx context.Context) string {
ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo)
return ri.Route
}
// ExtractHTTPMethod retrieves just the HTTP method from context.
// Returns empty string if not set.
func ExtractHTTPMethod(ctx context.Context) string {
ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo)
return ri.Method
}
// HTTPRoute is middleware that stores the HTTP route pattern and method in
// context for use by downstream handlers and services (e.g. prometheus).
func HTTPRoute(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
route := getRoutePattern(r)
ctx := context.WithValue(r.Context(), httpRouteInfoKey{}, httpRouteInfo{
Route: route,
Method: r.Method,
})
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func getRoutePattern(r *http.Request) string {
rctx := chi.RouteContext(r.Context())
if rctx == nil {
return ""
}
routePath := r.URL.Path
if r.URL.RawPath != "" {
routePath = r.URL.RawPath
}
tctx := chi.NewRouteContext()
routes := rctx.Routes
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
// All other ones will be matched as "STATIC".
if strings.HasPrefix(routePath, "/api/") {
return "UNKNOWN"
}
return "STATIC"
}
// tctx has the updated pattern, since Match mutates it
return tctx.RoutePattern()
}
+104
View File
@@ -0,0 +1,104 @@
package httpmw_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/testutil"
)
func TestHTTPRoute(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
reqFn func() *http.Request
registerRoutes map[string]string
mws []func(http.Handler) http.Handler
expectedRoute string
expectedMethod string
}{
{
name: "without middleware",
reqFn: func() *http.Request {
return httptest.NewRequest(http.MethodGet, "/", nil)
},
registerRoutes: map[string]string{http.MethodGet: "/"},
mws: []func(http.Handler) http.Handler{},
expectedRoute: "",
expectedMethod: "",
},
{
name: "root",
reqFn: func() *http.Request {
return httptest.NewRequest(http.MethodGet, "/", nil)
},
registerRoutes: map[string]string{http.MethodGet: "/"},
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
expectedRoute: "/",
expectedMethod: http.MethodGet,
},
{
name: "parameterized route",
reqFn: func() *http.Request {
return httptest.NewRequest(http.MethodPut, "/users/123", nil)
},
registerRoutes: map[string]string{http.MethodPut: "/users/{id}"},
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
expectedRoute: "/users/{id}",
expectedMethod: http.MethodPut,
},
{
name: "unknown",
reqFn: func() *http.Request {
return httptest.NewRequest(http.MethodGet, "/api/a", nil)
},
registerRoutes: map[string]string{http.MethodGet: "/api/b"},
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
expectedRoute: "UNKNOWN",
expectedMethod: http.MethodGet,
},
{
name: "static",
reqFn: func() *http.Request {
return httptest.NewRequest(http.MethodGet, "/some/static/file.png", nil)
},
registerRoutes: map[string]string{http.MethodGet: "/"},
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
expectedRoute: "STATIC",
expectedMethod: http.MethodGet,
},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
r := chi.NewRouter()
done := make(chan string)
for _, mw := range tc.mws {
r.Use(mw)
}
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer close(done)
method := httpmw.ExtractHTTPMethod(r.Context())
route := httpmw.ExtractHTTPRoute(r.Context())
assert.Equal(t, tc.expectedMethod, method, "expected method mismatch")
assert.Equal(t, tc.expectedRoute, route, "expected route mismatch")
next.ServeHTTP(w, r)
})
})
for method, route := range tc.registerRoutes {
r.MethodFunc(method, route, func(w http.ResponseWriter, r *http.Request) {})
}
req := tc.reqFn()
r.ServeHTTP(httptest.NewRecorder(), req)
_ = testutil.TryReceive(ctx, t, done)
})
}
}
+1 -29
View File
@@ -3,10 +3,8 @@ package httpmw
import ( import (
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/go-chi/chi/v5"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
@@ -71,7 +69,7 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
var ( var (
dist *prometheus.HistogramVec dist *prometheus.HistogramVec
distOpts []string distOpts []string
path = getRoutePattern(r) path = ExtractHTTPRoute(r.Context())
) )
// We want to count WebSockets separately. // We want to count WebSockets separately.
@@ -98,29 +96,3 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
}) })
} }
} }
func getRoutePattern(r *http.Request) string {
rctx := chi.RouteContext(r.Context())
if rctx == nil {
return ""
}
routePath := r.URL.Path
if r.URL.RawPath != "" {
routePath = r.URL.RawPath
}
tctx := chi.NewRouteContext()
routes := rctx.Routes
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
// All other ones will be matched as "STATIC".
if strings.HasPrefix(routePath, "/api/") {
return "UNKNOWN"
}
return "STATIC"
}
// tctx has the updated pattern, since Match mutates it
return tctx.RoutePattern()
}
+7 -4
View File
@@ -29,9 +29,9 @@ func TestPrometheus(t *testing.T) {
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext())) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
res := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()} res := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
reg := prometheus.NewRegistry() reg := prometheus.NewRegistry()
httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httpmw.HTTPRoute(httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
})).ServeHTTP(res, req) }))).ServeHTTP(res, req)
metrics, err := reg.Gather() metrics, err := reg.Gather()
require.NoError(t, err) require.NoError(t, err)
require.Greater(t, len(metrics), 0) require.Greater(t, len(metrics), 0)
@@ -57,7 +57,7 @@ func TestPrometheus(t *testing.T) {
wrappedHandler := promMW(testHandler) wrappedHandler := promMW(testHandler)
r := chi.NewRouter() r := chi.NewRouter()
r.Use(tracing.StatusWriterMiddleware, promMW) r.Use(tracing.StatusWriterMiddleware, httpmw.HTTPRoute, promMW)
r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) { r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) {
wrappedHandler.ServeHTTP(rw, r) wrappedHandler.ServeHTTP(rw, r)
}) })
@@ -85,7 +85,7 @@ func TestPrometheus(t *testing.T) {
promMW := httpmw.Prometheus(reg) promMW := httpmw.Prometheus(reg)
r := chi.NewRouter() r := chi.NewRouter()
r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {}) r.With(httpmw.HTTPRoute).With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
req := httptest.NewRequest("GET", "/api/v2/users/john", nil) req := httptest.NewRequest("GET", "/api/v2/users/john", nil)
@@ -115,6 +115,7 @@ func TestPrometheus(t *testing.T) {
promMW := httpmw.Prometheus(reg) promMW := httpmw.Prometheus(reg)
r := chi.NewRouter() r := chi.NewRouter()
r.Use(httpmw.HTTPRoute)
r.Use(promMW) r.Use(promMW)
r.NotFound(func(w http.ResponseWriter, r *http.Request) { r.NotFound(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
@@ -145,6 +146,7 @@ func TestPrometheus(t *testing.T) {
promMW := httpmw.Prometheus(reg) promMW := httpmw.Prometheus(reg)
r := chi.NewRouter() r := chi.NewRouter()
r.Use(httpmw.HTTPRoute)
r.Use(promMW) r.Use(promMW)
r.NotFound(func(w http.ResponseWriter, r *http.Request) { r.NotFound(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
@@ -173,6 +175,7 @@ func TestPrometheus(t *testing.T) {
promMW := httpmw.Prometheus(reg) promMW := httpmw.Prometheus(reg)
r := chi.NewRouter() r := chi.NewRouter()
r.Use(httpmw.HTTPRoute)
r.Use(promMW) r.Use(promMW)
r.Get("/api/v2/workspaceagents/{workspaceagent}/pty", func(w http.ResponseWriter, r *http.Request) {}) r.Get("/api/v2/workspaceagents/{workspaceagent}/pty", func(w http.ResponseWriter, r *http.Request) {})