mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore(coderd/httpmw): extract HTTPRoute middleware (#21498)
Extracts part of the prometheus middleware that stores the route information in the request context into its own middleware. Also adds request method information to context. Relates to https://github.com/coder/internal/issues/1214
This commit is contained in:
@@ -881,6 +881,7 @@ func New(options *Options) *API {
|
|||||||
loggermw.Logger(api.Logger),
|
loggermw.Logger(api.Logger),
|
||||||
singleSlashMW,
|
singleSlashMW,
|
||||||
rolestore.CustomRoleMW,
|
rolestore.CustomRoleMW,
|
||||||
|
httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware.
|
||||||
prometheusMW,
|
prometheusMW,
|
||||||
// Build-Version is helpful for debugging.
|
// Build-Version is helpful for debugging.
|
||||||
func(next http.Handler) http.Handler {
|
func(next http.Handler) http.Handler {
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package httpmw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
httpRouteInfoKey struct{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type httpRouteInfo struct {
|
||||||
|
Route string
|
||||||
|
Method string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractHTTPRoute retrieves just the HTTP route pattern from context.
|
||||||
|
// Returns empty string if not set.
|
||||||
|
func ExtractHTTPRoute(ctx context.Context) string {
|
||||||
|
ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo)
|
||||||
|
return ri.Route
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractHTTPMethod retrieves just the HTTP method from context.
|
||||||
|
// Returns empty string if not set.
|
||||||
|
func ExtractHTTPMethod(ctx context.Context) string {
|
||||||
|
ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo)
|
||||||
|
return ri.Method
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPRoute is middleware that stores the HTTP route pattern and method in
|
||||||
|
// context for use by downstream handlers and services (e.g. prometheus).
|
||||||
|
func HTTPRoute(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
route := getRoutePattern(r)
|
||||||
|
ctx := context.WithValue(r.Context(), httpRouteInfoKey{}, httpRouteInfo{
|
||||||
|
Route: route,
|
||||||
|
Method: r.Method,
|
||||||
|
})
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRoutePattern(r *http.Request) string {
|
||||||
|
rctx := chi.RouteContext(r.Context())
|
||||||
|
if rctx == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
routePath := r.URL.Path
|
||||||
|
if r.URL.RawPath != "" {
|
||||||
|
routePath = r.URL.RawPath
|
||||||
|
}
|
||||||
|
|
||||||
|
tctx := chi.NewRouteContext()
|
||||||
|
routes := rctx.Routes
|
||||||
|
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
|
||||||
|
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
|
||||||
|
// All other ones will be matched as "STATIC".
|
||||||
|
if strings.HasPrefix(routePath, "/api/") {
|
||||||
|
return "UNKNOWN"
|
||||||
|
}
|
||||||
|
return "STATIC"
|
||||||
|
}
|
||||||
|
|
||||||
|
// tctx has the updated pattern, since Match mutates it
|
||||||
|
return tctx.RoutePattern()
|
||||||
|
}
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package httpmw_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd/httpmw"
|
||||||
|
"github.com/coder/coder/v2/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHTTPRoute(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
reqFn func() *http.Request
|
||||||
|
registerRoutes map[string]string
|
||||||
|
mws []func(http.Handler) http.Handler
|
||||||
|
expectedRoute string
|
||||||
|
expectedMethod string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "without middleware",
|
||||||
|
reqFn: func() *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
},
|
||||||
|
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||||
|
mws: []func(http.Handler) http.Handler{},
|
||||||
|
expectedRoute: "",
|
||||||
|
expectedMethod: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root",
|
||||||
|
reqFn: func() *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
},
|
||||||
|
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||||
|
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||||
|
expectedRoute: "/",
|
||||||
|
expectedMethod: http.MethodGet,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parameterized route",
|
||||||
|
reqFn: func() *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodPut, "/users/123", nil)
|
||||||
|
},
|
||||||
|
registerRoutes: map[string]string{http.MethodPut: "/users/{id}"},
|
||||||
|
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||||
|
expectedRoute: "/users/{id}",
|
||||||
|
expectedMethod: http.MethodPut,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown",
|
||||||
|
reqFn: func() *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, "/api/a", nil)
|
||||||
|
},
|
||||||
|
registerRoutes: map[string]string{http.MethodGet: "/api/b"},
|
||||||
|
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||||
|
expectedRoute: "UNKNOWN",
|
||||||
|
expectedMethod: http.MethodGet,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "static",
|
||||||
|
reqFn: func() *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, "/some/static/file.png", nil)
|
||||||
|
},
|
||||||
|
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||||
|
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||||
|
expectedRoute: "STATIC",
|
||||||
|
expectedMethod: http.MethodGet,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
r := chi.NewRouter()
|
||||||
|
done := make(chan string)
|
||||||
|
for _, mw := range tc.mws {
|
||||||
|
r.Use(mw)
|
||||||
|
}
|
||||||
|
r.Use(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer close(done)
|
||||||
|
method := httpmw.ExtractHTTPMethod(r.Context())
|
||||||
|
route := httpmw.ExtractHTTPRoute(r.Context())
|
||||||
|
assert.Equal(t, tc.expectedMethod, method, "expected method mismatch")
|
||||||
|
assert.Equal(t, tc.expectedRoute, route, "expected route mismatch")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
for method, route := range tc.registerRoutes {
|
||||||
|
r.MethodFunc(method, route, func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
}
|
||||||
|
req := tc.reqFn()
|
||||||
|
r.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
_ = testutil.TryReceive(ctx, t, done)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,10 +3,8 @@ package httpmw
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
|
||||||
@@ -71,7 +69,7 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
|||||||
var (
|
var (
|
||||||
dist *prometheus.HistogramVec
|
dist *prometheus.HistogramVec
|
||||||
distOpts []string
|
distOpts []string
|
||||||
path = getRoutePattern(r)
|
path = ExtractHTTPRoute(r.Context())
|
||||||
)
|
)
|
||||||
|
|
||||||
// We want to count WebSockets separately.
|
// We want to count WebSockets separately.
|
||||||
@@ -98,29 +96,3 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRoutePattern(r *http.Request) string {
|
|
||||||
rctx := chi.RouteContext(r.Context())
|
|
||||||
if rctx == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
routePath := r.URL.Path
|
|
||||||
if r.URL.RawPath != "" {
|
|
||||||
routePath = r.URL.RawPath
|
|
||||||
}
|
|
||||||
|
|
||||||
tctx := chi.NewRouteContext()
|
|
||||||
routes := rctx.Routes
|
|
||||||
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
|
|
||||||
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
|
|
||||||
// All other ones will be matched as "STATIC".
|
|
||||||
if strings.HasPrefix(routePath, "/api/") {
|
|
||||||
return "UNKNOWN"
|
|
||||||
}
|
|
||||||
return "STATIC"
|
|
||||||
}
|
|
||||||
|
|
||||||
// tctx has the updated pattern, since Match mutates it
|
|
||||||
return tctx.RoutePattern()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -29,9 +29,9 @@ func TestPrometheus(t *testing.T) {
|
|||||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
||||||
res := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
res := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
||||||
reg := prometheus.NewRegistry()
|
reg := prometheus.NewRegistry()
|
||||||
httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httpmw.HTTPRoute(httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})).ServeHTTP(res, req)
|
}))).ServeHTTP(res, req)
|
||||||
metrics, err := reg.Gather()
|
metrics, err := reg.Gather()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Greater(t, len(metrics), 0)
|
require.Greater(t, len(metrics), 0)
|
||||||
@@ -57,7 +57,7 @@ func TestPrometheus(t *testing.T) {
|
|||||||
wrappedHandler := promMW(testHandler)
|
wrappedHandler := promMW(testHandler)
|
||||||
|
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
r.Use(tracing.StatusWriterMiddleware, promMW)
|
r.Use(tracing.StatusWriterMiddleware, httpmw.HTTPRoute, promMW)
|
||||||
r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) {
|
r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) {
|
||||||
wrappedHandler.ServeHTTP(rw, r)
|
wrappedHandler.ServeHTTP(rw, r)
|
||||||
})
|
})
|
||||||
@@ -85,7 +85,7 @@ func TestPrometheus(t *testing.T) {
|
|||||||
promMW := httpmw.Prometheus(reg)
|
promMW := httpmw.Prometheus(reg)
|
||||||
|
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
|
r.With(httpmw.HTTPRoute).With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/api/v2/users/john", nil)
|
req := httptest.NewRequest("GET", "/api/v2/users/john", nil)
|
||||||
|
|
||||||
@@ -115,6 +115,7 @@ func TestPrometheus(t *testing.T) {
|
|||||||
promMW := httpmw.Prometheus(reg)
|
promMW := httpmw.Prometheus(reg)
|
||||||
|
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
r.Use(httpmw.HTTPRoute)
|
||||||
r.Use(promMW)
|
r.Use(promMW)
|
||||||
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
@@ -145,6 +146,7 @@ func TestPrometheus(t *testing.T) {
|
|||||||
promMW := httpmw.Prometheus(reg)
|
promMW := httpmw.Prometheus(reg)
|
||||||
|
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
r.Use(httpmw.HTTPRoute)
|
||||||
r.Use(promMW)
|
r.Use(promMW)
|
||||||
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
@@ -173,6 +175,7 @@ func TestPrometheus(t *testing.T) {
|
|||||||
promMW := httpmw.Prometheus(reg)
|
promMW := httpmw.Prometheus(reg)
|
||||||
|
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
r.Use(httpmw.HTTPRoute)
|
||||||
r.Use(promMW)
|
r.Use(promMW)
|
||||||
r.Get("/api/v2/workspaceagents/{workspaceagent}/pty", func(w http.ResponseWriter, r *http.Request) {})
|
r.Get("/api/v2/workspaceagents/{workspaceagent}/pty", func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user