fix: rate limit by user instead of IP for authenticated requests (#22049)

## Problem

Rate limiting by user is broken (#20857). The rate limit middleware runs
before API key extraction, so user ID is never in the request context.
This causes:
- Rate limiting falls back to IP address for all requests
- `X-Coder-Bypass-Ratelimit` header for Owners is ignored (can't verify
role without identity)

## Solution

Adds `PrecheckAPIKey`, a **root-level middleware** that fully validates
the API key on every request (expiry, OIDC refresh, DB updates, role
lookup) and stores the result in context. Added **once** at the root
router — not duplicated per route group.

### Architecture

```
Request → Root middleware stack:
  → ExtractRealIP, Logger, ...
  → PrecheckAPIKey(...)              ← validates key, stores result, never rejects
  → HandleSubdomain(apiRateLimiter)  ← workspace apps now also benefit
  → CORS, CSRF

→ /api/v2 or /api/experimental:
  → apiRateLimiter                   ← reads prechecked result from context
  → route handlers:
    → ExtractAPIKeyMW                ← reuses prechecked data, adds route-specific logic
    → handler
```

### Key design decisions

| Decision | Rationale |
|---|---|
| **Full validation, not lightweight** | Spike's review: "the whole idea
of a 'lightweight' extraction that skips security checks is
fundamentally flawed." Only fully validated keys are used for rate
limiting — expired/invalid keys fall back to IP. |
| **Structured error results** | `ValidateAPIKeyError` has a `Hard` flag
that maps to `write` vs `optionalWrite`. Hard errors (5xx, OAuth refresh
failures) surface even on optional-auth routes. Soft errors
(missing/expired token) are swallowed on optional routes. |
| **Added once at the root** | Spike's review: "Why can't we add it once
at the root?" Root placement means workspace app rate limiters also
benefit. |
| **Skip prechecked when `SessionTokenFunc != nil`** |
`workspaceapps/db.go` uses a custom `SessionTokenFunc` that extracts
from `issueReq.SessionToken`. The prechecked result may have validated a
different token. Falls back to `ValidateAPIKey` with the custom func. |
| **User status check stays in `ExtractAPIKey`** | Dormant activation is
route-specific — `ValidateAPIKey` stores status but doesn't enforce it.
|
| **Audience validation stays in `ExtractAPIKey`** | Depends on
`cfg.AccessURL` and request path, uses `optionalWrite(403)` which
depends on route config. |

### Changes

- **`coderd/httpmw/apikey.go`**:
- New `ValidateAPIKey` function — extracted core validation logic,
returns structured errors instead of writing HTTP responses
- New `PrecheckAPIKey` middleware — calls `ValidateAPIKey`, stores
result in `apiKeyPrecheckedContextKey`, never rejects
- New types: `ValidateAPIKeyConfig`, `ValidateAPIKeyResult`,
`ValidateAPIKeyError`, `APIKeyPrechecked`
- Refactored `ExtractAPIKey` — consumes prechecked result from context
(skipping redundant validation), falls back to `ValidateAPIKey` when no
precheck available
  - Removed `ExtractAPIKeyForRateLimit` and `preExtractedAPIKey`
- **`coderd/httpmw/ratelimit.go`**: Rate limiter checks
`apiKeyPrecheckedContextKey` first, then `apiKeyContextKey` fallback
(for unit tests / workspace apps), then IP
- **`coderd/coderd.go`**: Added `PrecheckAPIKey` once at root
`r.Use(...)` block, removed `ExtractAPIKeyForRateLimit` from `/api/v2`
and `/api/experimental`
- **`coderd/coderd_test.go`**: `TestRateLimitByUser` regression test
with `BypassOwner` subtest

Fixes #20857
This commit is contained in:
Kacper Sawicki
2026-03-09 13:54:31 +01:00
committed by GitHub
parent 715486465b
commit 49006685b0
4 changed files with 525 additions and 213 deletions
+10 -4
View File
@@ -926,6 +926,16 @@ func New(options *Options) *API {
loggermw.Logger(api.Logger),
singleSlashMW,
rolestore.CustomRoleMW,
// Validate API key on every request (if present) and store
// the result in context. The rate limiter reads this to key
// by user ID, and downstream ExtractAPIKeyMW reuses it to
// avoid redundant DB lookups. Never rejects requests.
httpmw.PrecheckAPIKey(httpmw.ValidateAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
DisableSessionExpiryRefresh: options.DeploymentValues.Sessions.DisableExpiryRefresh.Value(),
Logger: options.Logger,
}),
httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware.
prometheusMW,
// Build-Version is helpful for debugging.
@@ -1074,8 +1084,6 @@ func New(options *Options) *API {
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
r.Use(
// Specific routes can specify different limits, but every rate
// limit must be configurable by the admin.
apiRateLimiter,
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
)
@@ -1168,8 +1176,6 @@ func New(options *Options) *API {
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
r.Use(
// Specific routes can specify different limits, but every rate
// limit must be configurable by the admin.
apiRateLimiter,
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
)
+88
View File
@@ -416,3 +416,91 @@ func TestDERPMetrics(t *testing.T) {
assert.Contains(t, names, "coder_derp_server_packets_dropped_reason_total",
"expected coder_derp_server_packets_dropped_reason_total to be registered")
}
// TestRateLimitByUser verifies that rate limiting keys by user ID when
// an authenticated session is present, rather than falling back to IP.
// This is a regression test for https://github.com/coder/coder/issues/20857
func TestRateLimitByUser(t *testing.T) {
t.Parallel()
const rateLimit = 5
ownerClient := coderdtest.New(t, &coderdtest.Options{
APIRateLimit: rateLimit,
})
firstUser := coderdtest.CreateFirstUser(t, ownerClient)
t.Run("HitsLimit", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
// Make rateLimit requests — they should all succeed.
for i := 0; i < rateLimit; i++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
require.NoError(t, err)
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
resp, err := ownerClient.HTTPClient.Do(req)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode,
"request %d should succeed", i+1)
}
// The next request should be rate-limited.
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
require.NoError(t, err)
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
resp, err := ownerClient.HTTPClient.Do(req)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusTooManyRequests, resp.StatusCode,
"request should be rate limited")
})
t.Run("BypassOwner", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
// Owner with bypass header should not be rate-limited.
for i := 0; i < rateLimit+5; i++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
require.NoError(t, err)
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
resp, err := ownerClient.HTTPClient.Do(req)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode,
"owner bypass request %d should succeed", i+1)
}
})
t.Run("MemberCannotBypass", func(t *testing.T) {
t.Parallel()
memberClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID)
ctx := testutil.Context(t, testutil.WaitLong)
// A member requesting the bypass header should be rejected
// with 428 Precondition Required — only owners may bypass.
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
memberClient.URL.String()+"/api/v2/buildinfo", nil)
require.NoError(t, err)
req.Header.Set(codersdk.SessionTokenHeader, memberClient.SessionToken())
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
resp, err := memberClient.HTTPClient.Do(req)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode,
"member should not be able to bypass rate limit")
})
}
+391 -194
View File
@@ -30,7 +30,57 @@ import (
"github.com/coder/coder/v2/codersdk"
)
type apiKeyContextKey struct{}
type (
apiKeyContextKey struct{}
apiKeyPrecheckedContextKey struct{}
)
// ValidateAPIKeyConfig holds the settings needed for API key
// validation at the top of the request lifecycle. Unlike
// ExtractAPIKeyConfig it omits route-specific fields
// (RedirectToLogin, Optional, ActivateDormantUser, etc.).
type ValidateAPIKeyConfig struct {
DB database.Store
OAuth2Configs *OAuth2Configs
DisableSessionExpiryRefresh bool
// SessionTokenFunc overrides how the API token is extracted
// from the request. Nil uses the default (cookie/header).
SessionTokenFunc func(*http.Request) string
Logger slog.Logger
}
// ValidateAPIKeyResult is the outcome of successful validation.
type ValidateAPIKeyResult struct {
Key database.APIKey
Subject rbac.Subject
UserStatus database.UserStatus
}
// ValidateAPIKeyError represents a validation failure with enough
// context for downstream middlewares to decide how to respond.
type ValidateAPIKeyError struct {
Code int
Response codersdk.Response
// Hard is true for server errors and active failures (5xx,
// OAuth refresh failures) that must be surfaced even on
// optional-auth routes. Soft errors (missing/expired token)
// may be swallowed on optional routes.
Hard bool
}
func (e *ValidateAPIKeyError) Error() string {
return e.Response.Message
}
// APIKeyPrechecked stores the result of top-level API key
// validation performed by PrecheckAPIKey. It distinguishes
// two states:
// - Validation failed (including no token): Result == nil && Err != nil
// - Validation passed: Result != nil && Err == nil
type APIKeyPrechecked struct {
Result *ValidateAPIKeyResult
Err *ValidateAPIKeyError
}
// APIKeyOptional may return an API key from the ExtractAPIKey handler.
func APIKeyOptional(r *http.Request) (database.APIKey, bool) {
@@ -149,6 +199,298 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
}
}
// PrecheckAPIKey extracts and fully validates the API key on every
// request (if present) and stores the result in context. It never
// writes error responses and always calls next.
//
// The rate limiter reads the stored result to key by user ID and
// check the Owner bypass header. Downstream ExtractAPIKeyMW reads
// it to avoid redundant DB lookups and validation.
func PrecheckAPIKey(cfg ValidateAPIKeyConfig) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Already prechecked (shouldn't happen, but guard).
if _, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok {
next.ServeHTTP(rw, r)
return
}
result, valErr := ValidateAPIKey(ctx, cfg, r)
prechecked := APIKeyPrechecked{
Result: result,
Err: valErr,
}
ctx = context.WithValue(ctx, apiKeyPrecheckedContextKey{}, prechecked)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}
// ValidateAPIKey extracts and validates the API key from the
// request. It performs all security-critical checks:
// - Token extraction and parsing
// - Database lookup + secret hash validation
// - Expiry check
// - OIDC/OAuth token refresh (if applicable)
// - API key LastUsed / ExpiresAt DB updates
// - User role lookup (UserRBACSubject)
//
// It does NOT:
// - Write HTTP error responses
// - Activate dormant users (route-specific)
// - Redirect to login (route-specific)
// - Check OAuth2 audience (route-specific, depends on AccessURL)
// - Set PostAuth headers (route-specific)
// - Check user active status (route-specific, depends on dormant activation)
//
// Returns (result, nil) on success or (nil, error) on failure.
func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Request) (*ValidateAPIKeyResult, *ValidateAPIKeyError) {
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
if !ok {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: resp,
}
}
// Log the API key ID for all requests that have a valid key
// format and secret, regardless of whether subsequent validation
// (expiry, user status, etc.) succeeds.
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
rl.WithFields(slog.F("api_key_id", key.ID))
}
now := dbtime.Now()
if key.ExpiresAt.Before(now) {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
},
}
}
// Refresh OIDC/GitHub tokens if applicable.
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
if errors.Is(err, sql.ErrNoRows) {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "You must re-authenticate with the login provider.",
},
}
}
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: "A database error occurred",
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
},
Hard: true,
}
}
// Check if the OAuth token is expired.
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
if cfg.OAuth2Configs.IsZero() {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
"No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType),
},
Hard: true,
}
}
var friendlyName string
var oauthConfig promoauth.OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
oauthConfig = cfg.OAuth2Configs.Github
friendlyName = "GitHub"
case database.LoginTypeOIDC:
oauthConfig = cfg.OAuth2Configs.OIDC
friendlyName = "OpenID Connect"
default:
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
},
Hard: true,
}
}
if oauthConfig == nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
"OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType),
},
Hard: true,
}
}
// Soft error: session expired naturally with no
// refresh token. Optional-auth routes treat this as
// unauthenticated.
if link.OAuthRefreshToken == "" {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
},
}
}
// We have a refresh token, so let's try it.
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
Expiry: link.OAuthExpiry,
}).Token()
// Hard error: we actively tried to refresh and the
// provider rejected it — surface even on optional-auth
// routes.
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: fmt.Sprintf(
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
friendlyName),
Detail: err.Error(),
},
Hard: true,
}
}
link.OAuthAccessToken = token.AccessToken
link.OAuthRefreshToken = token.RefreshToken
link.OAuthExpiry = token.Expiry
//nolint:gocritic // system needs to update user link
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
Claims: link.Claims,
})
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
},
Hard: true,
}
}
}
}
// Update LastUsed and session expiry.
changed := false
if now.Sub(key.LastUsed) > time.Hour {
key.LastUsed = now
remoteIP := net.ParseIP(r.RemoteAddr)
if remoteIP == nil {
remoteIP = net.IPv4(0, 0, 0, 0)
}
bitlen := len(remoteIP) * 8
key.IPAddress = pqtype.Inet{
IPNet: net.IPNet{
IP: remoteIP,
Mask: net.CIDRMask(bitlen, bitlen),
},
Valid: true,
}
changed = true
}
if !cfg.DisableSessionExpiryRefresh {
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
key.ExpiresAt = now.Add(apiKeyLifetime)
changed = true
}
}
if changed {
//nolint:gocritic // System needs to update API Key LastUsed
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
IPAddress: key.IPAddress,
})
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
},
Hard: true,
}
}
//nolint:gocritic // system needs to update user last seen at
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
ID: key.UserID,
LastSeenAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
})
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusInternalServerError,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()),
},
Hard: true,
}
}
}
// Fetch user roles.
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
if err != nil {
return nil, &ValidateAPIKeyError{
Code: http.StatusUnauthorized,
Response: codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
},
Hard: true,
}
}
return &ValidateAPIKeyResult{
Key: *key,
Subject: actor,
UserStatus: userStatus,
}, nil
}
func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
tokenFunc := APITokenFromRequest
if sessionTokenFunc != nil {
@@ -240,29 +582,60 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return nil, nil, false
}
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
if !ok {
return optionalWrite(http.StatusUnauthorized, resp)
// --- Consume prechecked result if available ---
// Skip prechecked data when cfg has a custom SessionTokenFunc,
// because the precheck used the default token extraction and may
// have validated a different token (e.g. workspace app token
// issuance in workspaceapps/db.go).
var key *database.APIKey
var actor rbac.Subject
var userStatus database.UserStatus
var skipValidation bool
if cfg.SessionTokenFunc == nil {
if pc, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok {
if pc.Err != nil {
// Validation failed at the top level (includes
// "no token provided").
if pc.Err.Hard {
return write(pc.Err.Code, pc.Err.Response)
}
return optionalWrite(pc.Err.Code, pc.Err.Response)
}
// Valid — use prechecked data, skip to route-specific logic.
key = &pc.Result.Key
actor = pc.Result.Subject
userStatus = pc.Result.UserStatus
skipValidation = true
}
}
// Log the API key ID for all requests that have a valid key format and secret,
// regardless of whether subsequent validation (expiry, user status, etc.) succeeds.
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
rl.WithFields(slog.F("api_key_id", key.ID))
if !skipValidation {
// Full validation path (no prechecked result or custom token func).
result, valErr := ValidateAPIKey(ctx, ValidateAPIKeyConfig{
DB: cfg.DB,
OAuth2Configs: cfg.OAuth2Configs,
DisableSessionExpiryRefresh: cfg.DisableSessionExpiryRefresh,
SessionTokenFunc: cfg.SessionTokenFunc,
Logger: cfg.Logger,
}, r)
if valErr != nil {
if valErr.Hard {
return write(valErr.Code, valErr.Response)
}
return optionalWrite(valErr.Code, valErr.Response)
}
key = &result.Key
actor = result.Subject
userStatus = result.UserStatus
}
now := dbtime.Now()
if key.ExpiresAt.Before(now) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
})
}
// --- Route-specific logic (always runs) ---
// Validate OAuth2 provider app token audience (RFC 8707) if applicable
// Validate OAuth2 provider app token audience (RFC 8707) if applicable.
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil {
// Log the detailed error for debugging but don't expose it to the client
// Log the detailed error for debugging but don't expose it to the client.
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
return optionalWrite(http.StatusForbidden, codersdk.Response{
Message: "Token audience validation failed",
@@ -270,183 +643,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
}
}
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
// refreshing the OIDC token.
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
var err error
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
if errors.Is(err, sql.ErrNoRows) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "You must re-authenticate with the login provider.",
})
}
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: "A database error occurred",
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
})
}
// Check if the OAuth token is expired
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
if cfg.OAuth2Configs.IsZero() {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
"No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType),
})
}
var friendlyName string
var oauthConfig promoauth.OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
oauthConfig = cfg.OAuth2Configs.Github
friendlyName = "GitHub"
case database.LoginTypeOIDC:
oauthConfig = cfg.OAuth2Configs.OIDC
friendlyName = "OpenID Connect"
default:
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
})
}
// It's possible for cfg.OAuth2Configs to be non-nil, but still
// missing this type. For example, if a user logged in with GitHub,
// but the administrator later removed GitHub and replaced it with
// OIDC.
if oauthConfig == nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
"OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType),
})
}
if link.OAuthRefreshToken == "" {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
})
}
// We have a refresh token, so let's try it
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
Expiry: link.OAuthExpiry,
}).Token()
if err != nil {
return write(http.StatusUnauthorized, codersdk.Response{
Message: fmt.Sprintf(
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
friendlyName),
Detail: err.Error(),
})
}
link.OAuthAccessToken = token.AccessToken
link.OAuthRefreshToken = token.RefreshToken
link.OAuthExpiry = token.Expiry
//nolint:gocritic // system needs to update user link
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
Claims: link.Claims,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
})
}
}
}
// Tracks if the API key has properties updated
changed := false
// Only update LastUsed once an hour to prevent database spam.
if now.Sub(key.LastUsed) > time.Hour {
key.LastUsed = now
remoteIP := net.ParseIP(r.RemoteAddr)
if remoteIP == nil {
remoteIP = net.IPv4(0, 0, 0, 0)
}
bitlen := len(remoteIP) * 8
key.IPAddress = pqtype.Inet{
IPNet: net.IPNet{
IP: remoteIP,
Mask: net.CIDRMask(bitlen, bitlen),
},
Valid: true,
}
changed = true
}
// Only update the ExpiresAt once an hour to prevent database spam.
// We extend the ExpiresAt to reduce re-authentication.
if !cfg.DisableSessionExpiryRefresh {
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
key.ExpiresAt = now.Add(apiKeyLifetime)
changed = true
}
}
if changed {
//nolint:gocritic // System needs to update API Key LastUsed
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
IPAddress: key.IPAddress,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
})
}
// We only want to update this occasionally to reduce DB write
// load. We update alongside the UserLink and APIKey since it's
// easier on the DB to colocate writes.
//nolint:gocritic // system needs to update user last seen at
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
ID: key.UserID,
LastSeenAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()),
})
}
}
// If the key is valid, we also fetch the user roles and status.
// The roles are used for RBAC authorize checks, and the status
// is to block 'suspended' users from accessing the platform.
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
if err != nil {
return write(http.StatusUnauthorized, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
})
}
// Dormant activation (config-dependent).
if userStatus == database.UserStatusDormant && cfg.ActivateDormantUser != nil {
id, _ := uuid.Parse(actor.ID)
user, err := cfg.ActivateDormantUser(ctx, database.User{
+36 -15
View File
@@ -32,35 +32,56 @@ func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler
count,
window,
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
// Prioritize by user, but fallback to IP.
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
if !ok {
// Identify the caller. We check two sources:
//
// 1. apiKeyPrecheckedContextKey — set by PrecheckAPIKey
// at the root of the router. Only fully validated
// keys are used.
// 2. apiKeyContextKey — set by ExtractAPIKeyMW if it
// has already run (e.g. unit tests, workspace-app
// routes that don't go through PrecheckAPIKey).
//
// If neither is present, fall back to IP.
var userID string
var subject *rbac.Subject
if pc, ok := r.Context().Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok && pc.Result != nil {
userID = pc.Result.Key.UserID.String()
subject = &pc.Result.Subject
} else if ak, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey); ok {
userID = ak.UserID.String()
if auth, ok := UserAuthorizationOptional(r.Context()); ok {
subject = &auth
}
} else {
return httprate.KeyByIP(r)
}
if ok, _ := strconv.ParseBool(r.Header.Get(codersdk.BypassRatelimitHeader)); !ok {
// No bypass attempt, just ratelimit.
return apiKey.UserID.String(), nil
// No bypass attempt, just rate limit by user.
return userID, nil
}
// Allow Owner to bypass rate limiting for load tests
// and automation.
auth := UserAuthorization(r.Context())
// We avoid using rbac.Authorizer since rego is CPU-intensive
// and undermines the DoS-prevention goal of the rate limiter.
for _, role := range auth.SafeRoleNames() {
// and automation. We avoid using rbac.Authorizer since
// rego is CPU-intensive and undermines the
// DoS-prevention goal of the rate limiter.
if subject == nil {
// Can't verify roles — rate limit normally.
return userID, nil
}
for _, role := range subject.SafeRoleNames() {
if role == rbac.RoleOwner() {
// HACK: use a random key each time to
// de facto disable rate limiting. The
// `httprate` package has no
// support for selectively changing the limit
// for particular keys.
// httprate package has no support for
// selectively changing the limit for
// particular keys.
return cryptorand.String(16)
}
}
return apiKey.UserID.String(), xerrors.Errorf(
return userID, xerrors.Errorf(
"%q provided but user is not %v",
codersdk.BypassRatelimitHeader, rbac.RoleOwner(),
)