Files
coder/coderd/coderd_test.go
T
Kacper Sawicki 49006685b0 fix: rate limit by user instead of IP for authenticated requests (#22049)
## Problem

Rate limiting by user is broken (#20857). The rate limit middleware runs
before API key extraction, so user ID is never in the request context.
This causes:
- Rate limiting falls back to IP address for all requests
- `X-Coder-Bypass-Ratelimit` header for Owners is ignored (can't verify
role without identity)

## Solution

Adds `PrecheckAPIKey`, a **root-level middleware** that fully validates
the API key on every request (expiry, OIDC refresh, DB updates, role
lookup) and stores the result in context. Added **once** at the root
router — not duplicated per route group.

### Architecture

```
Request → Root middleware stack:
  → ExtractRealIP, Logger, ...
  → PrecheckAPIKey(...)              ← validates key, stores result, never rejects
  → HandleSubdomain(apiRateLimiter)  ← workspace apps now also benefit
  → CORS, CSRF

→ /api/v2 or /api/experimental:
  → apiRateLimiter                   ← reads prechecked result from context
  → route handlers:
    → ExtractAPIKeyMW                ← reuses prechecked data, adds route-specific logic
    → handler
```

### Key design decisions

| Decision | Rationale |
|---|---|
| **Full validation, not lightweight** | Spike's review: "the whole idea of a 'lightweight' extraction that skips security checks is fundamentally flawed." Only fully validated keys are used for rate limiting — expired/invalid keys fall back to IP. |
| **Structured error results** | `ValidateAPIKeyError` has a `Hard` flag that maps to `write` vs `optionalWrite`. Hard errors (5xx, OAuth refresh failures) surface even on optional-auth routes. Soft errors (missing/expired token) are swallowed on optional routes. |
| **Added once at the root** | Spike's review: "Why can't we add it once at the root?" Root placement means workspace app rate limiters also benefit. |
| **Skip prechecked when `SessionTokenFunc != nil`** | `workspaceapps/db.go` uses a custom `SessionTokenFunc` that extracts from `issueReq.SessionToken`. The prechecked result may have validated a different token. Falls back to `ValidateAPIKey` with the custom func. |
| **User status check stays in `ExtractAPIKey`** | Dormant activation is route-specific — `ValidateAPIKey` stores status but doesn't enforce it. |
| **Audience validation stays in `ExtractAPIKey`** | Depends on `cfg.AccessURL` and request path, uses `optionalWrite(403)` which depends on route config. |

### Changes

- **`coderd/httpmw/apikey.go`**:
  - New `ValidateAPIKey` function — extracted core validation logic, returns structured errors instead of writing HTTP responses
  - New `PrecheckAPIKey` middleware — calls `ValidateAPIKey`, stores result in `apiKeyPrecheckedContextKey`, never rejects
  - New types: `ValidateAPIKeyConfig`, `ValidateAPIKeyResult`, `ValidateAPIKeyError`, `APIKeyPrechecked`
  - Refactored `ExtractAPIKey` — consumes prechecked result from context (skipping redundant validation), falls back to `ValidateAPIKey` when no precheck available
  - Removed `ExtractAPIKeyForRateLimit` and `preExtractedAPIKey`
- **`coderd/httpmw/ratelimit.go`**: Rate limiter checks `apiKeyPrecheckedContextKey` first, then `apiKeyContextKey` fallback (for unit tests / workspace apps), then IP
- **`coderd/coderd.go`**: Added `PrecheckAPIKey` once at root `r.Use(...)` block, removed `ExtractAPIKeyForRateLimit` from `/api/v2` and `/api/experimental`
- **`coderd/coderd_test.go`**: `TestRateLimitByUser` regression test with `BypassOwner` subtest
Fixes #20857
2026-03-09 13:54:31 +01:00

507 lines
15 KiB
Go

package coderd_test
import (
"context"
"flag"
"fmt"
"io"
"net/http"
"net/netip"
"strconv"
"strings"
"sync/atomic"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"tailscale.com/tailcfg"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
// updateGoldenFiles is a flag that can be set to update golden files.
// It is registered as the boolean "-update" flag and defaults to false.
var updateGoldenFiles = flag.Bool("update", false, "Update golden files")
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
}
// TestBuildInfo checks that the build info returned by the API carries the
// same external URL and version that the buildinfo package reports.
func TestBuildInfo(t *testing.T) {
	t.Parallel()

	api := coderdtest.New(t, nil)

	timeoutCtx, stop := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer stop()

	got, fetchErr := api.BuildInfo(timeoutCtx)
	require.NoError(t, fetchErr)

	wantURL := buildinfo.ExternalURL()
	require.Equal(t, wantURL, got.ExternalURL, "external URL")

	wantVersion := buildinfo.Version()
	require.Equal(t, wantVersion, got.Version, "version")
}
// TestDERP creates two tailnet connections that share a DERP map whose only
// node points at the coderdtest server's host and port, wires their node
// updates to each other, and then checks that one can dial a TCP listener
// opened on the other.
func TestDERP(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
client := coderdtest.New(t, nil)
logger := testutil.Logger(t)
// The DERP node reuses the test server's own port.
derpPort, err := strconv.Atoi(client.URL.Port())
require.NoError(t, err)
// Single-region DERP map with one node addressed at the test server.
derpMap := &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
RegionID: 1,
RegionCode: "cdr",
RegionName: "Coder",
Nodes: []*tailcfg.DERPNode{{
Name: "1a",
RegionID: 1,
HostName: client.URL.Hostname(),
DERPPort: derpPort,
// NOTE(review): STUNPort -1 presumably disables STUN for this
// node — confirm against the tailcfg documentation.
STUNPort: -1,
ForceHTTP: true,
}},
},
},
}
// w1 gets an explicit random address so the test can dial it later.
w1IP := tailnet.TailscaleServicePrefix.RandomAddr()
w1, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
Logger: logger.Named("w1"),
DERPMap: derpMap,
})
require.NoError(t, err)
w2, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()},
Logger: logger.Named("w2"),
DERPMap: derpMap,
})
require.NoError(t, err)
// Each connection forwards its own node updates to the other one as a
// peer update, standing in for a coordinator.
w1ID := uuid.New()
w1.SetNodeCallback(func(node *tailnet.Node) {
pn, err := tailnet.NodeToProto(node)
if !assert.NoError(t, err) {
return
}
w2.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{
Id: w1ID[:],
Node: pn,
Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE,
}})
})
w2ID := uuid.New()
w2.SetNodeCallback(func(node *tailnet.Node) {
pn, err := tailnet.NodeToProto(node)
if !assert.NoError(t, err) {
return
}
w1.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{
Id: w2ID[:],
Node: pn,
Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE,
}})
})
// conn is signalled twice by the goroutine below: once when the listener
// is ready and once after a connection has been accepted and closed.
conn := make(chan struct{})
go func() {
listener, err := w1.Listen("tcp", ":35565")
// NOTE(review): assert does not stop the goroutine, so if Listen fails
// the deferred Close below would be called on a nil listener — confirm
// whether a guard is wanted here.
assert.NoError(t, err)
defer listener.Close()
conn <- struct{}{}
nc, err := listener.Accept()
assert.NoError(t, err)
_ = nc.Close()
conn <- struct{}{}
}()
// Wait until w1 is listening before dialing from w2.
<-conn
// The result of AwaitReachable is not checked; the dial below fails the
// test if w1 cannot be reached.
w2.AwaitReachable(ctx, w1IP)
nc, err := w2.DialContextTCP(ctx, netip.AddrPortFrom(w1IP, 35565))
require.NoError(t, err)
_ = nc.Close()
// Wait for the listener goroutine to finish handling the connection.
<-conn
w1.Close()
w2.Close()
}
// TestDERPForceWebSockets starts a server with DERP ForceWebSockets and
// BlockDirect enabled, wraps its root handler to inspect the Upgrade header
// of every /derp request, then dials a workspace agent and asserts that at
// least one /derp request was observed.
func TestDERPForceWebSockets(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.DERP.Config.ForceWebSockets = true
dv.DERP.Config.BlockDirect = true // to ensure the test always uses DERP
// Manually create a server so we can influence the HTTP handler.
options := &coderdtest.Options{
DeploymentValues: dv,
}
setHandler, cancelFunc, serverURL, newOptions := coderdtest.NewOptions(t, options)
coderAPI := coderd.New(newOptions)
t.Cleanup(func() {
cancelFunc()
_ = coderAPI.Close()
})
// Set the HTTP handler to a custom one that ensures all /derp calls are
// WebSockets and not `Upgrade: derp`.
// upgradeCount is updated from the HTTP handler and read by the test
// body, so it is only accessed through sync/atomic.
var upgradeCount int64
setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/derp") {
up := r.Header.Get("Upgrade")
// A non-empty Upgrade header other than "websocket" is an error;
// every other /derp request (including one with no Upgrade header)
// is counted.
if up != "" && up != "websocket" {
t.Errorf("expected Upgrade: websocket, got %q", up)
} else {
atomic.AddInt64(&upgradeCount, 1)
}
}
coderAPI.RootHandler.ServeHTTP(rw, r)
}))
// Start a provisioner daemon.
provisionerCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
t.Cleanup(func() {
_ = provisionerCloser.Close()
})
client := codersdk.New(serverURL)
t.Cleanup(func() {
client.HTTPClient.CloseIdleConnections()
})
wsclient := workspacesdk.New(client)
user := coderdtest.CreateFirstUser(t, client)
// Dump the generic agent connection info into the test log for debugging.
gen, err := wsclient.AgentConnectionInfoGeneric(context.Background())
require.NoError(t, err)
t.Log(spew.Sdump(gen))
// Provision a template and workspace whose agent authenticates with
// authToken, then start that agent.
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: echo.PlanComplete,
ProvisionGraph: echo.ProvisionGraphWithAgent(authToken),
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
_ = agenttest.New(t, client.URL, authToken)
// NOTE(review): AwaitWorkspaceAgents is called here with its result
// discarded and again below to obtain resources — confirm whether the
// first call is still needed.
_ = coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
conn, err := wsclient.DialAgent(ctx, resources[0].Agents[0].ID,
&workspacesdk.DialAgentOptions{
Logger: testutil.Logger(t).Named("client"),
},
)
require.NoError(t, err)
defer func() {
_ = conn.Close()
}()
// The result of AwaitReachable is not checked; the assertion below is
// what fails the test if no /derp request was seen.
conn.AwaitReachable(ctx)
require.GreaterOrEqual(t, atomic.LoadInt64(&upgradeCount), int64(1), "expected at least one /derp call")
}
// TestDERPLatencyCheck asserts that a GET to /derp/latency-check on a fresh
// test server responds with 200 OK.
func TestDERPLatencyCheck(t *testing.T) {
	t.Parallel()

	const endpoint = "/derp/latency-check"

	api := coderdtest.New(t, nil)
	resp, reqErr := api.Request(context.Background(), http.MethodGet, endpoint, nil)
	require.NoError(t, reqErr)
	defer resp.Body.Close()

	require.Equal(t, http.StatusOK, resp.StatusCode)
}
// TestFastLatencyCheck asserts that a GET to /latency-check on a fresh test
// server responds with 200 OK.
func TestFastLatencyCheck(t *testing.T) {
	t.Parallel()

	const endpoint = "/latency-check"

	api := coderdtest.New(t, nil)
	resp, reqErr := api.Request(context.Background(), http.MethodGet, endpoint, nil)
	require.NoError(t, reqErr)
	defer resp.Body.Close()

	require.Equal(t, http.StatusOK, resp.StatusCode)
}
// TestHealthz asserts that GET /healthz responds with 200 OK and a body of
// exactly "OK".
func TestHealthz(t *testing.T) {
	t.Parallel()

	api := coderdtest.New(t, nil)

	resp, reqErr := api.Request(context.Background(), http.MethodGet, "/healthz", nil)
	require.NoError(t, reqErr)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)

	raw, readErr := io.ReadAll(resp.Body)
	require.NoError(t, readErr)

	payload := string(raw)
	assert.Equal(t, "OK", payload)
}
// TestSwagger verifies that the Swagger UI and its doc.json are served when
// the SwaggerEndpoint option is set, and that both paths return 404 when the
// server is created with default options.
func TestSwagger(t *testing.T) {
	t.Parallel()

	const swaggerEndpoint = "/swagger"
	t.Run("endpoint enabled", func(t *testing.T) {
		t.Parallel()

		client := coderdtest.New(t, &coderdtest.Options{
			SwaggerEndpoint: true,
		})

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
		defer cancel()

		resp, err := requestWithRetries(ctx, t, client, http.MethodGet, swaggerEndpoint, nil)
		require.NoError(t, err)
		// Register the close before reading so the body is released even
		// when a later require aborts the subtest.
		defer resp.Body.Close()

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)

		require.Contains(t, string(body), "Swagger UI")
	})
	t.Run("doc.json exposed", func(t *testing.T) {
		t.Parallel()

		client := coderdtest.New(t, &coderdtest.Options{
			SwaggerEndpoint: true,
		})

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
		defer cancel()

		resp, err := requestWithRetries(ctx, t, client, http.MethodGet, swaggerEndpoint+"/doc.json", nil)
		require.NoError(t, err)
		// Register the close before reading so the body is released even
		// when a later require aborts the subtest.
		defer resp.Body.Close()

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)

		require.Contains(t, string(body), `"swagger": "2.0"`)
	})
	t.Run("endpoint disabled by default", func(t *testing.T) {
		t.Parallel()

		client := coderdtest.New(t, nil)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
		defer cancel()

		resp, err := requestWithRetries(ctx, t, client, http.MethodGet, swaggerEndpoint, nil)
		require.NoError(t, err)
		defer resp.Body.Close()

		require.Equal(t, http.StatusNotFound, resp.StatusCode)
	})
	t.Run("doc.json disabled by default", func(t *testing.T) {
		t.Parallel()

		client := coderdtest.New(t, nil)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
		defer cancel()

		resp, err := requestWithRetries(ctx, t, client, http.MethodGet, swaggerEndpoint+"/doc.json", nil)
		require.NoError(t, err)
		defer resp.Body.Close()

		require.Equal(t, http.StatusNotFound, resp.StatusCode)
	})
}
// TestCSRFExempt verifies that POST requests to path-based workspace apps
// are not rejected by CSRF protection.
func TestCSRFExempt(t *testing.T) {
	t.Parallel()

	// This test builds a workspace with an agent and an app. The app is not
	// a real http server, so it will fail to serve requests. We just want
	// to make sure the failure is not a CSRF failure, as path based
	// apps should be exempt.
	t.Run("PathBasedApp", func(t *testing.T) {
		t.Parallel()

		client, _, api := coderdtest.NewWithAPI(t, nil)
		first := coderdtest.CreateFirstUser(t, client)
		owner, err := client.User(context.Background(), "me")
		require.NoError(t, err)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
		defer cancel()

		// Create a workspace.
		const agentSlug = "james"
		const appSlug = "web"
		wrk := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
			OwnerID:        owner.ID,
			OrganizationID: first.OrganizationID,
		}).
			WithAgent(func(agents []*proto.Agent) []*proto.Agent {
				agents[0].Name = agentSlug
				agents[0].Apps = []*proto.App{{
					Slug:        appSlug,
					DisplayName: appSlug,
					Subdomain:   false,
					Url:         "/",
				}}
				return agents
			}).
			Do()

		u := client.URL.JoinPath(fmt.Sprintf("/@%s/%s.%s/apps/%s", owner.Username, wrk.Workspace.Name, agentSlug, appSlug)).String()
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, nil)
		// Check the error before touching req: on failure req is nil and
		// AddCookie would panic instead of reporting a clean test failure.
		require.NoError(t, err)
		req.AddCookie(&http.Cookie{
			Name:   codersdk.SessionTokenCookie,
			Value:  client.SessionToken(),
			Path:   "/",
			Domain: client.URL.String(),
		})

		resp, err := client.HTTPClient.Do(req)
		require.NoError(t, err)
		// The read error is deliberately ignored: the body is only inspected
		// for the absence of a CSRF message, and the status code assertion
		// below is the primary check.
		data, _ := io.ReadAll(resp.Body)
		_ = resp.Body.Close()

		// A StatusBadGateway means Coderd tried to proxy to the agent and failed because the agent
		// was not there. This means CSRF did not block the app request, which is what we want.
		require.Equal(t, http.StatusBadGateway, resp.StatusCode, "status code 500 is CSRF failure")
		require.NotContains(t, string(data), "CSRF")
	})
}
// TestDERPMetrics checks that the API's Prometheus registry, once gathered,
// contains the expected coder_derp_server_* metric families.
func TestDERPMetrics(t *testing.T) {
	t.Parallel()

	_, _, api := coderdtest.NewWithAPI(t, nil)
	require.NotNil(t, api.Options.DERPServer, "DERP server should be configured")
	require.NotNil(t, api.Options.PrometheusRegistry, "Prometheus registry should be configured")

	// The registry is created internally by coderd. Gather from it
	// to verify DERP metrics were registered during startup.
	families, err := api.Options.PrometheusRegistry.Gather()
	require.NoError(t, err)

	// Index the gathered families by name for membership checks.
	registered := make(map[string]struct{}, len(families))
	for i := range families {
		registered[families[i].GetName()] = struct{}{}
	}

	wantMetrics := []string{
		"coder_derp_server_connections",
		"coder_derp_server_bytes_received_total",
		"coder_derp_server_packets_dropped_reason_total",
	}
	for _, metric := range wantMetrics {
		assert.Contains(t, registered, metric,
			"expected "+metric+" to be registered")
	}
}
// TestRateLimitByUser is the regression test for
// https://github.com/coder/coder/issues/20857. It checks that an
// authenticated session is rate limited per user rather than by IP, that an
// owner can skip the limit with the bypass header, and that a member who
// sends the same header is rejected.
func TestRateLimitByUser(t *testing.T) {
	t.Parallel()

	const rateLimit = 5

	ownerClient := coderdtest.New(t, &coderdtest.Options{
		APIRateLimit: rateLimit,
	})
	firstUser := coderdtest.CreateFirstUser(t, ownerClient)

	// buildInfoStatus sends an authenticated GET to /api/v2/buildinfo using
	// the given client's session token, optionally adding the rate limit
	// bypass header, and returns the response status code.
	buildInfoStatus := func(ctx context.Context, t *testing.T, client *codersdk.Client, bypass bool) int {
		t.Helper()

		req, err := http.NewRequestWithContext(ctx, http.MethodGet,
			client.URL.String()+"/api/v2/buildinfo", nil)
		require.NoError(t, err)
		req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
		if bypass {
			req.Header.Set(codersdk.BypassRatelimitHeader, "true")
		}

		resp, err := client.HTTPClient.Do(req)
		require.NoError(t, err)
		resp.Body.Close()
		return resp.StatusCode
	}

	t.Run("HitsLimit", func(t *testing.T) {
		t.Parallel()

		ctx := testutil.Context(t, testutil.WaitLong)

		// The first rateLimit requests fit in the budget and must succeed.
		for attempt := 1; attempt <= rateLimit; attempt++ {
			status := buildInfoStatus(ctx, t, ownerClient, false)
			require.Equal(t, http.StatusOK, status,
				"request %d should succeed", attempt)
		}

		// One more request exceeds the budget and must be rejected.
		status := buildInfoStatus(ctx, t, ownerClient, false)
		require.Equal(t, http.StatusTooManyRequests, status,
			"request should be rate limited")
	})

	t.Run("BypassOwner", func(t *testing.T) {
		t.Parallel()

		ctx := testutil.Context(t, testutil.WaitLong)

		// With the bypass header an owner may go past the limit.
		for attempt := 1; attempt <= rateLimit+5; attempt++ {
			status := buildInfoStatus(ctx, t, ownerClient, true)
			require.Equal(t, http.StatusOK, status,
				"owner bypass request %d should succeed", attempt)
		}
	})

	t.Run("MemberCannotBypass", func(t *testing.T) {
		t.Parallel()

		memberClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID)
		ctx := testutil.Context(t, testutil.WaitLong)

		// Only owners may bypass: a member sending the header is expected
		// to receive 428 Precondition Required.
		status := buildInfoStatus(ctx, t, memberClient, true)
		require.Equal(t, http.StatusPreconditionRequired, status,
			"member should not be able to bypass rate limit")
	})
}