mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: rate limit by user instead of IP for authenticated requests (#22049)
## Problem Rate limiting by user is broken (#20857). The rate limit middleware runs before API key extraction, so user ID is never in the request context. This causes: - Rate limiting falls back to IP address for all requests - `X-Coder-Bypass-Ratelimit` header for Owners is ignored (can't verify role without identity) ## Solution Adds `PrecheckAPIKey`, a **root-level middleware** that fully validates the API key on every request (expiry, OIDC refresh, DB updates, role lookup) and stores the result in context. Added **once** at the root router — not duplicated per route group. ### Architecture ``` Request → Root middleware stack: → ExtractRealIP, Logger, ... → PrecheckAPIKey(...) ← validates key, stores result, never rejects → HandleSubdomain(apiRateLimiter) ← workspace apps now also benefit → CORS, CSRF → /api/v2 or /api/experimental: → apiRateLimiter ← reads prechecked result from context → route handlers: → ExtractAPIKeyMW ← reuses prechecked data, adds route-specific logic → handler ``` ### Key design decisions | Decision | Rationale | |---|---| | **Full validation, not lightweight** | Spike's review: "the whole idea of a 'lightweight' extraction that skips security checks is fundamentally flawed." Only fully validated keys are used for rate limiting — expired/invalid keys fall back to IP. | | **Structured error results** | `ValidateAPIKeyError` has a `Hard` flag that maps to `write` vs `optionalWrite`. Hard errors (5xx, OAuth refresh failures) surface even on optional-auth routes. Soft errors (missing/expired token) are swallowed on optional routes. | | **Added once at the root** | Spike's review: "Why can't we add it once at the root?" Root placement means workspace app rate limiters also benefit. | | **Skip prechecked when `SessionTokenFunc != nil`** | `workspaceapps/db.go` uses a custom `SessionTokenFunc` that extracts from `issueReq.SessionToken`. The prechecked result may have validated a different token. Falls back to `ValidateAPIKey` with the custom func. | | **User status check stays in `ExtractAPIKey`** | Dormant activation is route-specific — `ValidateAPIKey` stores status but doesn't enforce it. | | **Audience validation stays in `ExtractAPIKey`** | Depends on `cfg.AccessURL` and request path, uses `optionalWrite(403)` which depends on route config. | ### Changes - **`coderd/httpmw/apikey.go`**: - New `ValidateAPIKey` function — extracted core validation logic, returns structured errors instead of writing HTTP responses - New `PrecheckAPIKey` middleware — calls `ValidateAPIKey`, stores result in `apiKeyPrecheckedContextKey`, never rejects - New types: `ValidateAPIKeyConfig`, `ValidateAPIKeyResult`, `ValidateAPIKeyError`, `APIKeyPrechecked` - Refactored `ExtractAPIKey` — consumes prechecked result from context (skipping redundant validation), falls back to `ValidateAPIKey` when no precheck available - Removed `ExtractAPIKeyForRateLimit` and `preExtractedAPIKey` - **`coderd/httpmw/ratelimit.go`**: Rate limiter checks `apiKeyPrecheckedContextKey` first, then `apiKeyContextKey` fallback (for unit tests / workspace apps), then IP - **`coderd/coderd.go`**: Added `PrecheckAPIKey` once at root `r.Use(...)` block, removed `ExtractAPIKeyForRateLimit` from `/api/v2` and `/api/experimental` - **`coderd/coderd_test.go`**: `TestRateLimitByUser` regression test with `BypassOwner` subtest Fixes #20857
This commit is contained in:
+10
-4
@@ -926,6 +926,16 @@ func New(options *Options) *API {
|
||||
loggermw.Logger(api.Logger),
|
||||
singleSlashMW,
|
||||
rolestore.CustomRoleMW,
|
||||
// Validate API key on every request (if present) and store
|
||||
// the result in context. The rate limiter reads this to key
|
||||
// by user ID, and downstream ExtractAPIKeyMW reuses it to
|
||||
// avoid redundant DB lookups. Never rejects requests.
|
||||
httpmw.PrecheckAPIKey(httpmw.ValidateAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.Sessions.DisableExpiryRefresh.Value(),
|
||||
Logger: options.Logger,
|
||||
}),
|
||||
httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware.
|
||||
prometheusMW,
|
||||
// Build-Version is helpful for debugging.
|
||||
@@ -1074,8 +1084,6 @@ func New(options *Options) *API {
|
||||
|
||||
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
|
||||
r.Use(
|
||||
// Specific routes can specify different limits, but every rate
|
||||
// limit must be configurable by the admin.
|
||||
apiRateLimiter,
|
||||
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
|
||||
)
|
||||
@@ -1168,8 +1176,6 @@ func New(options *Options) *API {
|
||||
|
||||
r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) })
|
||||
r.Use(
|
||||
// Specific routes can specify different limits, but every rate
|
||||
// limit must be configurable by the admin.
|
||||
apiRateLimiter,
|
||||
httpmw.ReportCLITelemetry(api.Logger, options.Telemetry),
|
||||
)
|
||||
|
||||
@@ -416,3 +416,91 @@ func TestDERPMetrics(t *testing.T) {
|
||||
assert.Contains(t, names, "coder_derp_server_packets_dropped_reason_total",
|
||||
"expected coder_derp_server_packets_dropped_reason_total to be registered")
|
||||
}
|
||||
|
||||
// TestRateLimitByUser verifies that rate limiting keys by user ID when
|
||||
// an authenticated session is present, rather than falling back to IP.
|
||||
// This is a regression test for https://github.com/coder/coder/issues/20857
|
||||
func TestRateLimitByUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const rateLimit = 5
|
||||
|
||||
ownerClient := coderdtest.New(t, &coderdtest.Options{
|
||||
APIRateLimit: rateLimit,
|
||||
})
|
||||
firstUser := coderdtest.CreateFirstUser(t, ownerClient)
|
||||
|
||||
t.Run("HitsLimit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Make rateLimit requests — they should all succeed.
|
||||
for i := 0; i < rateLimit; i++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
|
||||
|
||||
resp, err := ownerClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode,
|
||||
"request %d should succeed", i+1)
|
||||
}
|
||||
|
||||
// The next request should be rate-limited.
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
|
||||
|
||||
resp, err := ownerClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusTooManyRequests, resp.StatusCode,
|
||||
"request should be rate limited")
|
||||
})
|
||||
|
||||
t.Run("BypassOwner", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Owner with bypass header should not be rate-limited.
|
||||
for i := 0; i < rateLimit+5; i++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
ownerClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken())
|
||||
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
|
||||
|
||||
resp, err := ownerClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode,
|
||||
"owner bypass request %d should succeed", i+1)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MemberCannotBypass", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// A member requesting the bypass header should be rejected
|
||||
// with 428 Precondition Required — only owners may bypass.
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
memberClient.URL.String()+"/api/v2/buildinfo", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.SessionTokenHeader, memberClient.SessionToken())
|
||||
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
|
||||
|
||||
resp, err := memberClient.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode,
|
||||
"member should not be able to bypass rate limit")
|
||||
})
|
||||
}
|
||||
|
||||
+391
-194
@@ -30,7 +30,57 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type apiKeyContextKey struct{}
|
||||
type (
|
||||
apiKeyContextKey struct{}
|
||||
apiKeyPrecheckedContextKey struct{}
|
||||
)
|
||||
|
||||
// ValidateAPIKeyConfig holds the settings needed for API key
|
||||
// validation at the top of the request lifecycle. Unlike
|
||||
// ExtractAPIKeyConfig it omits route-specific fields
|
||||
// (RedirectToLogin, Optional, ActivateDormantUser, etc.).
|
||||
type ValidateAPIKeyConfig struct {
|
||||
DB database.Store
|
||||
OAuth2Configs *OAuth2Configs
|
||||
DisableSessionExpiryRefresh bool
|
||||
// SessionTokenFunc overrides how the API token is extracted
|
||||
// from the request. Nil uses the default (cookie/header).
|
||||
SessionTokenFunc func(*http.Request) string
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
// ValidateAPIKeyResult is the outcome of successful validation.
|
||||
type ValidateAPIKeyResult struct {
|
||||
Key database.APIKey
|
||||
Subject rbac.Subject
|
||||
UserStatus database.UserStatus
|
||||
}
|
||||
|
||||
// ValidateAPIKeyError represents a validation failure with enough
|
||||
// context for downstream middlewares to decide how to respond.
|
||||
type ValidateAPIKeyError struct {
|
||||
Code int
|
||||
Response codersdk.Response
|
||||
// Hard is true for server errors and active failures (5xx,
|
||||
// OAuth refresh failures) that must be surfaced even on
|
||||
// optional-auth routes. Soft errors (missing/expired token)
|
||||
// may be swallowed on optional routes.
|
||||
Hard bool
|
||||
}
|
||||
|
||||
func (e *ValidateAPIKeyError) Error() string {
|
||||
return e.Response.Message
|
||||
}
|
||||
|
||||
// APIKeyPrechecked stores the result of top-level API key
|
||||
// validation performed by PrecheckAPIKey. It distinguishes
|
||||
// two states:
|
||||
// - Validation failed (including no token): Result == nil && Err != nil
|
||||
// - Validation passed: Result != nil && Err == nil
|
||||
type APIKeyPrechecked struct {
|
||||
Result *ValidateAPIKeyResult
|
||||
Err *ValidateAPIKeyError
|
||||
}
|
||||
|
||||
// APIKeyOptional may return an API key from the ExtractAPIKey handler.
|
||||
func APIKeyOptional(r *http.Request) (database.APIKey, bool) {
|
||||
@@ -149,6 +199,298 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// PrecheckAPIKey extracts and fully validates the API key on every
|
||||
// request (if present) and stores the result in context. It never
|
||||
// writes error responses and always calls next.
|
||||
//
|
||||
// The rate limiter reads the stored result to key by user ID and
|
||||
// check the Owner bypass header. Downstream ExtractAPIKeyMW reads
|
||||
// it to avoid redundant DB lookups and validation.
|
||||
func PrecheckAPIKey(cfg ValidateAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Already prechecked (shouldn't happen, but guard).
|
||||
if _, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok {
|
||||
next.ServeHTTP(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
result, valErr := ValidateAPIKey(ctx, cfg, r)
|
||||
|
||||
prechecked := APIKeyPrechecked{
|
||||
Result: result,
|
||||
Err: valErr,
|
||||
}
|
||||
ctx = context.WithValue(ctx, apiKeyPrecheckedContextKey{}, prechecked)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateAPIKey extracts and validates the API key from the
|
||||
// request. It performs all security-critical checks:
|
||||
// - Token extraction and parsing
|
||||
// - Database lookup + secret hash validation
|
||||
// - Expiry check
|
||||
// - OIDC/OAuth token refresh (if applicable)
|
||||
// - API key LastUsed / ExpiresAt DB updates
|
||||
// - User role lookup (UserRBACSubject)
|
||||
//
|
||||
// It does NOT:
|
||||
// - Write HTTP error responses
|
||||
// - Activate dormant users (route-specific)
|
||||
// - Redirect to login (route-specific)
|
||||
// - Check OAuth2 audience (route-specific, depends on AccessURL)
|
||||
// - Set PostAuth headers (route-specific)
|
||||
// - Check user active status (route-specific, depends on dormant activation)
|
||||
//
|
||||
// Returns (result, nil) on success or (nil, error) on failure.
|
||||
func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Request) (*ValidateAPIKeyResult, *ValidateAPIKeyError) {
|
||||
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
|
||||
if !ok {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: resp,
|
||||
}
|
||||
}
|
||||
|
||||
// Log the API key ID for all requests that have a valid key
|
||||
// format and secret, regardless of whether subsequent validation
|
||||
// (expiry, user status, etc.) succeeds.
|
||||
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
|
||||
rl.WithFields(slog.F("api_key_id", key.ID))
|
||||
}
|
||||
|
||||
now := dbtime.Now()
|
||||
if key.ExpiresAt.Before(now) {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh OIDC/GitHub tokens if applicable.
|
||||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
|
||||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
|
||||
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "You must re-authenticate with the login provider.",
|
||||
},
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: "A database error occurred",
|
||||
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
// Check if the OAuth token is expired.
|
||||
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
|
||||
if cfg.OAuth2Configs.IsZero() {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
var friendlyName string
|
||||
var oauthConfig promoauth.OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = cfg.OAuth2Configs.Github
|
||||
friendlyName = "GitHub"
|
||||
case database.LoginTypeOIDC:
|
||||
oauthConfig = cfg.OAuth2Configs.OIDC
|
||||
friendlyName = "OpenID Connect"
|
||||
default:
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
if oauthConfig == nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Soft error: session expired naturally with no
|
||||
// refresh token. Optional-auth routes treat this as
|
||||
// unauthenticated.
|
||||
if link.OAuthRefreshToken == "" {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// We have a refresh token, so let's try it.
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
Expiry: link.OAuthExpiry,
|
||||
}).Token()
|
||||
// Hard error: we actively tried to refresh and the
|
||||
// provider rejected it — surface even on optional-auth
|
||||
// routes.
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: fmt.Sprintf(
|
||||
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
|
||||
friendlyName),
|
||||
Detail: err.Error(),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
link.OAuthAccessToken = token.AccessToken
|
||||
link.OAuthRefreshToken = token.RefreshToken
|
||||
link.OAuthExpiry = token.Expiry
|
||||
//nolint:gocritic // system needs to update user link
|
||||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
// Refresh should keep the same debug context because we use
|
||||
// the original claims for the group/role sync.
|
||||
Claims: link.Claims,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update LastUsed and session expiry.
|
||||
changed := false
|
||||
if now.Sub(key.LastUsed) > time.Hour {
|
||||
key.LastUsed = now
|
||||
remoteIP := net.ParseIP(r.RemoteAddr)
|
||||
if remoteIP == nil {
|
||||
remoteIP = net.IPv4(0, 0, 0, 0)
|
||||
}
|
||||
bitlen := len(remoteIP) * 8
|
||||
key.IPAddress = pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: remoteIP,
|
||||
Mask: net.CIDRMask(bitlen, bitlen),
|
||||
},
|
||||
Valid: true,
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
if !cfg.DisableSessionExpiryRefresh {
|
||||
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second
|
||||
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
|
||||
key.ExpiresAt = now.Add(apiKeyLifetime)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
//nolint:gocritic // System needs to update API Key LastUsed
|
||||
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
IPAddress: key.IPAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocritic // system needs to update user last seen at
|
||||
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
|
||||
ID: key.UserID,
|
||||
LastSeenAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch user roles.
|
||||
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
|
||||
if err != nil {
|
||||
return nil, &ValidateAPIKeyError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Response: codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
|
||||
},
|
||||
Hard: true,
|
||||
}
|
||||
}
|
||||
|
||||
return &ValidateAPIKeyResult{
|
||||
Key: *key,
|
||||
Subject: actor,
|
||||
UserStatus: userStatus,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
|
||||
tokenFunc := APITokenFromRequest
|
||||
if sessionTokenFunc != nil {
|
||||
@@ -240,29 +582,60 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
|
||||
if !ok {
|
||||
return optionalWrite(http.StatusUnauthorized, resp)
|
||||
// --- Consume prechecked result if available ---
|
||||
// Skip prechecked data when cfg has a custom SessionTokenFunc,
|
||||
// because the precheck used the default token extraction and may
|
||||
// have validated a different token (e.g. workspace app token
|
||||
// issuance in workspaceapps/db.go).
|
||||
var key *database.APIKey
|
||||
var actor rbac.Subject
|
||||
var userStatus database.UserStatus
|
||||
var skipValidation bool
|
||||
|
||||
if cfg.SessionTokenFunc == nil {
|
||||
if pc, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok {
|
||||
if pc.Err != nil {
|
||||
// Validation failed at the top level (includes
|
||||
// "no token provided").
|
||||
if pc.Err.Hard {
|
||||
return write(pc.Err.Code, pc.Err.Response)
|
||||
}
|
||||
return optionalWrite(pc.Err.Code, pc.Err.Response)
|
||||
}
|
||||
// Valid — use prechecked data, skip to route-specific logic.
|
||||
key = &pc.Result.Key
|
||||
actor = pc.Result.Subject
|
||||
userStatus = pc.Result.UserStatus
|
||||
skipValidation = true
|
||||
}
|
||||
}
|
||||
|
||||
// Log the API key ID for all requests that have a valid key format and secret,
|
||||
// regardless of whether subsequent validation (expiry, user status, etc.) succeeds.
|
||||
if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil {
|
||||
rl.WithFields(slog.F("api_key_id", key.ID))
|
||||
if !skipValidation {
|
||||
// Full validation path (no prechecked result or custom token func).
|
||||
result, valErr := ValidateAPIKey(ctx, ValidateAPIKeyConfig{
|
||||
DB: cfg.DB,
|
||||
OAuth2Configs: cfg.OAuth2Configs,
|
||||
DisableSessionExpiryRefresh: cfg.DisableSessionExpiryRefresh,
|
||||
SessionTokenFunc: cfg.SessionTokenFunc,
|
||||
Logger: cfg.Logger,
|
||||
}, r)
|
||||
if valErr != nil {
|
||||
if valErr.Hard {
|
||||
return write(valErr.Code, valErr.Response)
|
||||
}
|
||||
return optionalWrite(valErr.Code, valErr.Response)
|
||||
}
|
||||
key = &result.Key
|
||||
actor = result.Subject
|
||||
userStatus = result.UserStatus
|
||||
}
|
||||
|
||||
now := dbtime.Now()
|
||||
if key.ExpiresAt.Before(now) {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
|
||||
})
|
||||
}
|
||||
// --- Route-specific logic (always runs) ---
|
||||
|
||||
// Validate OAuth2 provider app token audience (RFC 8707) if applicable
|
||||
// Validate OAuth2 provider app token audience (RFC 8707) if applicable.
|
||||
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
|
||||
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil {
|
||||
// Log the detailed error for debugging but don't expose it to the client
|
||||
// Log the detailed error for debugging but don't expose it to the client.
|
||||
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
|
||||
return optionalWrite(http.StatusForbidden, codersdk.Response{
|
||||
Message: "Token audience validation failed",
|
||||
@@ -270,183 +643,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
}
|
||||
}
|
||||
|
||||
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
|
||||
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
|
||||
// refreshing the OIDC token.
|
||||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
|
||||
var err error
|
||||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
|
||||
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "You must re-authenticate with the login provider.",
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "A database error occurred",
|
||||
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
|
||||
})
|
||||
}
|
||||
// Check if the OAuth token is expired
|
||||
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
|
||||
if cfg.OAuth2Configs.IsZero() {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
})
|
||||
}
|
||||
|
||||
var friendlyName string
|
||||
var oauthConfig promoauth.OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = cfg.OAuth2Configs.Github
|
||||
friendlyName = "GitHub"
|
||||
case database.LoginTypeOIDC:
|
||||
oauthConfig = cfg.OAuth2Configs.OIDC
|
||||
friendlyName = "OpenID Connect"
|
||||
default:
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
|
||||
})
|
||||
}
|
||||
|
||||
// It's possible for cfg.OAuth2Configs to be non-nil, but still
|
||||
// missing this type. For example, if a user logged in with GitHub,
|
||||
// but the administrator later removed GitHub and replaced it with
|
||||
// OIDC.
|
||||
if oauthConfig == nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+
|
||||
"OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType),
|
||||
})
|
||||
}
|
||||
|
||||
if link.OAuthRefreshToken == "" {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
|
||||
})
|
||||
}
|
||||
// We have a refresh token, so let's try it
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
Expiry: link.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
return write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: fmt.Sprintf(
|
||||
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
|
||||
friendlyName),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
link.OAuthAccessToken = token.AccessToken
|
||||
link.OAuthRefreshToken = token.RefreshToken
|
||||
link.OAuthExpiry = token.Expiry
|
||||
//nolint:gocritic // system needs to update user link
|
||||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
// Refresh should keep the same debug context because we use
|
||||
// the original claims for the group/role sync.
|
||||
Claims: link.Claims,
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tracks if the API key has properties updated
|
||||
changed := false
|
||||
|
||||
// Only update LastUsed once an hour to prevent database spam.
|
||||
if now.Sub(key.LastUsed) > time.Hour {
|
||||
key.LastUsed = now
|
||||
remoteIP := net.ParseIP(r.RemoteAddr)
|
||||
if remoteIP == nil {
|
||||
remoteIP = net.IPv4(0, 0, 0, 0)
|
||||
}
|
||||
bitlen := len(remoteIP) * 8
|
||||
key.IPAddress = pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: remoteIP,
|
||||
Mask: net.CIDRMask(bitlen, bitlen),
|
||||
},
|
||||
Valid: true,
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
// Only update the ExpiresAt once an hour to prevent database spam.
|
||||
// We extend the ExpiresAt to reduce re-authentication.
|
||||
if !cfg.DisableSessionExpiryRefresh {
|
||||
apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second
|
||||
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
|
||||
key.ExpiresAt = now.Add(apiKeyLifetime)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
//nolint:gocritic // System needs to update API Key LastUsed
|
||||
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
IPAddress: key.IPAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
|
||||
})
|
||||
}
|
||||
|
||||
// We only want to update this occasionally to reduce DB write
|
||||
// load. We update alongside the UserLink and APIKey since it's
|
||||
// easier on the DB to colocate writes.
|
||||
//nolint:gocritic // system needs to update user last seen at
|
||||
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
|
||||
ID: key.UserID,
|
||||
LastSeenAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// If the key is valid, we also fetch the user roles and status.
|
||||
// The roles are used for RBAC authorize checks, and the status
|
||||
// is to block 'suspended' users from accessing the platform.
|
||||
actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet())
|
||||
if err != nil {
|
||||
return write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()),
|
||||
})
|
||||
}
|
||||
|
||||
// Dormant activation (config-dependent).
|
||||
if userStatus == database.UserStatusDormant && cfg.ActivateDormantUser != nil {
|
||||
id, _ := uuid.Parse(actor.ID)
|
||||
user, err := cfg.ActivateDormantUser(ctx, database.User{
|
||||
|
||||
+36
-15
@@ -32,35 +32,56 @@ func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler
|
||||
count,
|
||||
window,
|
||||
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
|
||||
// Prioritize by user, but fallback to IP.
|
||||
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
|
||||
if !ok {
|
||||
// Identify the caller. We check two sources:
|
||||
//
|
||||
// 1. apiKeyPrecheckedContextKey — set by PrecheckAPIKey
|
||||
// at the root of the router. Only fully validated
|
||||
// keys are used.
|
||||
// 2. apiKeyContextKey — set by ExtractAPIKeyMW if it
|
||||
// has already run (e.g. unit tests, workspace-app
|
||||
// routes that don't go through PrecheckAPIKey).
|
||||
//
|
||||
// If neither is present, fall back to IP.
|
||||
var userID string
|
||||
var subject *rbac.Subject
|
||||
|
||||
if pc, ok := r.Context().Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok && pc.Result != nil {
|
||||
userID = pc.Result.Key.UserID.String()
|
||||
subject = &pc.Result.Subject
|
||||
} else if ak, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey); ok {
|
||||
userID = ak.UserID.String()
|
||||
if auth, ok := UserAuthorizationOptional(r.Context()); ok {
|
||||
subject = &auth
|
||||
}
|
||||
} else {
|
||||
return httprate.KeyByIP(r)
|
||||
}
|
||||
|
||||
if ok, _ := strconv.ParseBool(r.Header.Get(codersdk.BypassRatelimitHeader)); !ok {
|
||||
// No bypass attempt, just ratelimit.
|
||||
return apiKey.UserID.String(), nil
|
||||
// No bypass attempt, just rate limit by user.
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// Allow Owner to bypass rate limiting for load tests
|
||||
// and automation.
|
||||
auth := UserAuthorization(r.Context())
|
||||
|
||||
// We avoid using rbac.Authorizer since rego is CPU-intensive
|
||||
// and undermines the DoS-prevention goal of the rate limiter.
|
||||
for _, role := range auth.SafeRoleNames() {
|
||||
// and automation. We avoid using rbac.Authorizer since
|
||||
// rego is CPU-intensive and undermines the
|
||||
// DoS-prevention goal of the rate limiter.
|
||||
if subject == nil {
|
||||
// Can't verify roles — rate limit normally.
|
||||
return userID, nil
|
||||
}
|
||||
for _, role := range subject.SafeRoleNames() {
|
||||
if role == rbac.RoleOwner() {
|
||||
// HACK: use a random key each time to
|
||||
// de facto disable rate limiting. The
|
||||
// `httprate` package has no
|
||||
// support for selectively changing the limit
|
||||
// for particular keys.
|
||||
// httprate package has no support for
|
||||
// selectively changing the limit for
|
||||
// particular keys.
|
||||
return cryptorand.String(16)
|
||||
}
|
||||
}
|
||||
|
||||
return apiKey.UserID.String(), xerrors.Errorf(
|
||||
return userID, xerrors.Errorf(
|
||||
"%q provided but user is not %v",
|
||||
codersdk.BypassRatelimitHeader, rbac.RoleOwner(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user