diff --git a/coderd/coderd.go b/coderd/coderd.go index ae6b0bc159..282c13c2e9 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -926,6 +926,16 @@ func New(options *Options) *API { loggermw.Logger(api.Logger), singleSlashMW, rolestore.CustomRoleMW, + // Validate API key on every request (if present) and store + // the result in context. The rate limiter reads this to key + // by user ID, and downstream ExtractAPIKeyMW reuses it to + // avoid redundant DB lookups. Never rejects requests. + httpmw.PrecheckAPIKey(httpmw.ValidateAPIKeyConfig{ + DB: options.Database, + OAuth2Configs: oauthConfigs, + DisableSessionExpiryRefresh: options.DeploymentValues.Sessions.DisableExpiryRefresh.Value(), + Logger: options.Logger, + }), httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware. prometheusMW, // Build-Version is helpful for debugging. @@ -1074,8 +1084,6 @@ func New(options *Options) *API { r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) }) r.Use( - // Specific routes can specify different limits, but every rate - // limit must be configurable by the admin. apiRateLimiter, httpmw.ReportCLITelemetry(api.Logger, options.Telemetry), ) @@ -1168,8 +1176,6 @@ func New(options *Options) *API { r.NotFound(func(rw http.ResponseWriter, _ *http.Request) { httpapi.RouteNotFound(rw) }) r.Use( - // Specific routes can specify different limits, but every rate - // limit must be configurable by the admin. apiRateLimiter, httpmw.ReportCLITelemetry(api.Logger, options.Telemetry), ) diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index 49612b2b40..0ff2a65e2d 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -416,3 +416,91 @@ func TestDERPMetrics(t *testing.T) { assert.Contains(t, names, "coder_derp_server_packets_dropped_reason_total", "expected coder_derp_server_packets_dropped_reason_total to be registered") } + +// TestRateLimitByUser verifies that rate limiting keys by user ID when +// an authenticated session is present, rather than falling back to IP. +// This is a regression test for https://github.com/coder/coder/issues/20857 +func TestRateLimitByUser(t *testing.T) { + t.Parallel() + + const rateLimit = 5 + + ownerClient := coderdtest.New(t, &coderdtest.Options{ + APIRateLimit: rateLimit, + }) + firstUser := coderdtest.CreateFirstUser(t, ownerClient) + + t.Run("HitsLimit", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + + // Make rateLimit requests — they should all succeed. + for i := 0; i < rateLimit; i++ { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + ownerClient.URL.String()+"/api/v2/buildinfo", nil) + require.NoError(t, err) + req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken()) + + resp, err := ownerClient.HTTPClient.Do(req) + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode, + "request %d should succeed", i+1) + } + + // The next request should be rate-limited. + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + ownerClient.URL.String()+"/api/v2/buildinfo", nil) + require.NoError(t, err) + req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken()) + + resp, err := ownerClient.HTTPClient.Do(req) + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode, + "request should be rate limited") + }) + + t.Run("BypassOwner", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + + // Owner with bypass header should not be rate-limited. + for i := 0; i < rateLimit+5; i++ { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + ownerClient.URL.String()+"/api/v2/buildinfo", nil) + require.NoError(t, err) + req.Header.Set(codersdk.SessionTokenHeader, ownerClient.SessionToken()) + req.Header.Set(codersdk.BypassRatelimitHeader, "true") + + resp, err := ownerClient.HTTPClient.Do(req) + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode, + "owner bypass request %d should succeed", i+1) + } + }) + + t.Run("MemberCannotBypass", func(t *testing.T) { + t.Parallel() + + memberClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID) + ctx := testutil.Context(t, testutil.WaitLong) + + // A member requesting the bypass header should be rejected + // with 428 Precondition Required — only owners may bypass. + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + memberClient.URL.String()+"/api/v2/buildinfo", nil) + require.NoError(t, err) + req.Header.Set(codersdk.SessionTokenHeader, memberClient.SessionToken()) + req.Header.Set(codersdk.BypassRatelimitHeader, "true") + + resp, err := memberClient.HTTPClient.Do(req) + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode, + "member should not be able to bypass rate limit") + }) +} diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 4541925342..129c9c0c3d 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -30,7 +30,57 @@ import ( "github.com/coder/coder/v2/codersdk" ) -type apiKeyContextKey struct{} +type ( + apiKeyContextKey struct{} + apiKeyPrecheckedContextKey struct{} +) + +// ValidateAPIKeyConfig holds the settings needed for API key +// validation at the top of the request lifecycle. Unlike +// ExtractAPIKeyConfig it omits route-specific fields +// (RedirectToLogin, Optional, ActivateDormantUser, etc.). +type ValidateAPIKeyConfig struct { + DB database.Store + OAuth2Configs *OAuth2Configs + DisableSessionExpiryRefresh bool + // SessionTokenFunc overrides how the API token is extracted + // from the request. Nil uses the default (cookie/header). + SessionTokenFunc func(*http.Request) string + Logger slog.Logger +} + +// ValidateAPIKeyResult is the outcome of successful validation. +type ValidateAPIKeyResult struct { + Key database.APIKey + Subject rbac.Subject + UserStatus database.UserStatus +} + +// ValidateAPIKeyError represents a validation failure with enough +// context for downstream middlewares to decide how to respond. +type ValidateAPIKeyError struct { + Code int + Response codersdk.Response + // Hard is true for server errors and active failures (5xx, + // OAuth refresh failures) that must be surfaced even on + // optional-auth routes. Soft errors (missing/expired token) + // may be swallowed on optional routes. + Hard bool +} + +func (e *ValidateAPIKeyError) Error() string { + return e.Response.Message +} + +// APIKeyPrechecked stores the result of top-level API key +// validation performed by PrecheckAPIKey. It distinguishes +// two states: +// - Validation failed (including no token): Result == nil && Err != nil +// - Validation passed: Result != nil && Err == nil +type APIKeyPrechecked struct { + Result *ValidateAPIKeyResult + Err *ValidateAPIKeyError +} // APIKeyOptional may return an API key from the ExtractAPIKey handler. func APIKeyOptional(r *http.Request) (database.APIKey, bool) { @@ -149,6 +199,298 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler { } } +// PrecheckAPIKey extracts and fully validates the API key on every +// request (if present) and stores the result in context. It never +// writes error responses and always calls next. +// +// The rate limiter reads the stored result to key by user ID and +// check the Owner bypass header. Downstream ExtractAPIKeyMW reads +// it to avoid redundant DB lookups and validation. +func PrecheckAPIKey(cfg ValidateAPIKeyConfig) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Already prechecked (shouldn't happen, but guard). + if _, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok { + next.ServeHTTP(rw, r) + return + } + + result, valErr := ValidateAPIKey(ctx, cfg, r) + + prechecked := APIKeyPrechecked{ + Result: result, + Err: valErr, + } + ctx = context.WithValue(ctx, apiKeyPrecheckedContextKey{}, prechecked) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} + +// ValidateAPIKey extracts and validates the API key from the +// request. It performs all security-critical checks: +// - Token extraction and parsing +// - Database lookup + secret hash validation +// - Expiry check +// - OIDC/OAuth token refresh (if applicable) +// - API key LastUsed / ExpiresAt DB updates +// - User role lookup (UserRBACSubject) +// +// It does NOT: +// - Write HTTP error responses +// - Activate dormant users (route-specific) +// - Redirect to login (route-specific) +// - Check OAuth2 audience (route-specific, depends on AccessURL) +// - Set PostAuth headers (route-specific) +// - Check user active status (route-specific, depends on dormant activation) +// +// Returns (result, nil) on success or (nil, error) on failure. +func ValidateAPIKey(ctx context.Context, cfg ValidateAPIKeyConfig, r *http.Request) (*ValidateAPIKeyResult, *ValidateAPIKeyError) { + key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r) + if !ok { + return nil, &ValidateAPIKeyError{ + Code: http.StatusUnauthorized, + Response: resp, + } + } + + // Log the API key ID for all requests that have a valid key + // format and secret, regardless of whether subsequent validation + // (expiry, user status, etc.) succeeds. + if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil { + rl.WithFields(slog.F("api_key_id", key.ID)) + } + + now := dbtime.Now() + if key.ExpiresAt.Before(now) { + return nil, &ValidateAPIKeyError{ + Code: http.StatusUnauthorized, + Response: codersdk.Response{ + Message: SignedOutErrorMessage, + Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()), + }, + } + } + + // Refresh OIDC/GitHub tokens if applicable. + if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC { + //nolint:gocritic // System needs to fetch UserLink to check if it's valid. + link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{ + UserID: key.UserID, + LoginType: key.LoginType, + }) + if errors.Is(err, sql.ErrNoRows) { + return nil, &ValidateAPIKeyError{ + Code: http.StatusUnauthorized, + Response: codersdk.Response{ + Message: SignedOutErrorMessage, + Detail: "You must re-authenticate with the login provider.", + }, + } + } + if err != nil { + return nil, &ValidateAPIKeyError{ + Code: http.StatusInternalServerError, + Response: codersdk.Response{ + Message: "A database error occurred", + Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()), + }, + Hard: true, + } + } + // Check if the OAuth token is expired. + if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) { + if cfg.OAuth2Configs.IsZero() { + return nil, &ValidateAPIKeyError{ + Code: http.StatusInternalServerError, + Response: codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+ + "No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType), + }, + Hard: true, + } + } + + var friendlyName string + var oauthConfig promoauth.OAuth2Config + switch key.LoginType { + case database.LoginTypeGithub: + oauthConfig = cfg.OAuth2Configs.Github + friendlyName = "GitHub" + case database.LoginTypeOIDC: + oauthConfig = cfg.OAuth2Configs.OIDC + friendlyName = "OpenID Connect" + default: + return nil, &ValidateAPIKeyError{ + Code: http.StatusInternalServerError, + Response: codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType), + }, + Hard: true, + } + } + + if oauthConfig == nil { + return nil, &ValidateAPIKeyError{ + Code: http.StatusInternalServerError, + Response: codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+ + "OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType), + }, + Hard: true, + } + } + + // Soft error: session expired naturally with no + // refresh token. Optional-auth routes treat this as + // unauthenticated. + if link.OAuthRefreshToken == "" { + return nil, &ValidateAPIKeyError{ + Code: http.StatusUnauthorized, + Response: codersdk.Response{ + Message: SignedOutErrorMessage, + Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()), + }, + } + } + + // We have a refresh token, so let's try it. + token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{ + AccessToken: link.OAuthAccessToken, + RefreshToken: link.OAuthRefreshToken, + Expiry: link.OAuthExpiry, + }).Token() + // Hard error: we actively tried to refresh and the + // provider rejected it — surface even on optional-auth + // routes. + if err != nil { + return nil, &ValidateAPIKeyError{ + Code: http.StatusUnauthorized, + Response: codersdk.Response{ + Message: fmt.Sprintf( + "Could not refresh expired %s token. Try re-authenticating to resolve this issue.", + friendlyName), + Detail: err.Error(), + }, + Hard: true, + } + } + link.OAuthAccessToken = token.AccessToken + link.OAuthRefreshToken = token.RefreshToken + link.OAuthExpiry = token.Expiry + //nolint:gocritic // system needs to update user link + link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{ + UserID: link.UserID, + LoginType: link.LoginType, + OAuthAccessToken: link.OAuthAccessToken, + OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required + OAuthExpiry: link.OAuthExpiry, + // Refresh should keep the same debug context because we use + // the original claims for the group/role sync. + Claims: link.Claims, + }) + if err != nil { + return nil, &ValidateAPIKeyError{ + Code: http.StatusInternalServerError, + Response: codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("update user_link: %s.", err.Error()), + }, + Hard: true, + } + } + } + } + + // Update LastUsed and session expiry. + changed := false + if now.Sub(key.LastUsed) > time.Hour { + key.LastUsed = now + remoteIP := net.ParseIP(r.RemoteAddr) + if remoteIP == nil { + remoteIP = net.IPv4(0, 0, 0, 0) + } + bitlen := len(remoteIP) * 8 + key.IPAddress = pqtype.Inet{ + IPNet: net.IPNet{ + IP: remoteIP, + Mask: net.CIDRMask(bitlen, bitlen), + }, + Valid: true, + } + changed = true + } + if !cfg.DisableSessionExpiryRefresh { + apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second + if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour { + key.ExpiresAt = now.Add(apiKeyLifetime) + changed = true + } + } + if changed { + //nolint:gocritic // System needs to update API Key LastUsed + err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{ + ID: key.ID, + LastUsed: key.LastUsed, + ExpiresAt: key.ExpiresAt, + IPAddress: key.IPAddress, + }) + if err != nil { + return nil, &ValidateAPIKeyError{ + Code: http.StatusInternalServerError, + Response: codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()), + }, + Hard: true, + } + } + + //nolint:gocritic // system needs to update user last seen at + _, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{ + ID: key.UserID, + LastSeenAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + }) + if err != nil { + return nil, &ValidateAPIKeyError{ + Code: http.StatusInternalServerError, + Response: codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()), + }, + Hard: true, + } + } + } + + // Fetch user roles. + actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet()) + if err != nil { + return nil, &ValidateAPIKeyError{ + Code: http.StatusUnauthorized, + Response: codersdk.Response{ + Message: internalErrorMessage, + Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()), + }, + Hard: true, + } + } + + return &ValidateAPIKeyResult{ + Key: *key, + Subject: actor, + UserStatus: userStatus, + }, nil +} + func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) { tokenFunc := APITokenFromRequest if sessionTokenFunc != nil { @@ -240,29 +582,60 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon return nil, nil, false } - key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r) - if !ok { - return optionalWrite(http.StatusUnauthorized, resp) + // --- Consume prechecked result if available --- + // Skip prechecked data when cfg has a custom SessionTokenFunc, + // because the precheck used the default token extraction and may + // have validated a different token (e.g. workspace app token + // issuance in workspaceapps/db.go). + var key *database.APIKey + var actor rbac.Subject + var userStatus database.UserStatus + var skipValidation bool + + if cfg.SessionTokenFunc == nil { + if pc, ok := ctx.Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok { + if pc.Err != nil { + // Validation failed at the top level (includes + // "no token provided"). + if pc.Err.Hard { + return write(pc.Err.Code, pc.Err.Response) + } + return optionalWrite(pc.Err.Code, pc.Err.Response) + } + // Valid — use prechecked data, skip to route-specific logic. + key = &pc.Result.Key + actor = pc.Result.Subject + userStatus = pc.Result.UserStatus + skipValidation = true + } } - // Log the API key ID for all requests that have a valid key format and secret, - // regardless of whether subsequent validation (expiry, user status, etc.) succeeds. - if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil { - rl.WithFields(slog.F("api_key_id", key.ID)) + if !skipValidation { + // Full validation path (no prechecked result or custom token func). + result, valErr := ValidateAPIKey(ctx, ValidateAPIKeyConfig{ + DB: cfg.DB, + OAuth2Configs: cfg.OAuth2Configs, + DisableSessionExpiryRefresh: cfg.DisableSessionExpiryRefresh, + SessionTokenFunc: cfg.SessionTokenFunc, + Logger: cfg.Logger, + }, r) + if valErr != nil { + if valErr.Hard { + return write(valErr.Code, valErr.Response) + } + return optionalWrite(valErr.Code, valErr.Response) + } + key = &result.Key + actor = result.Subject + userStatus = result.UserStatus } - now := dbtime.Now() - if key.ExpiresAt.Before(now) { - return optionalWrite(http.StatusUnauthorized, codersdk.Response{ - Message: SignedOutErrorMessage, - Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()), - }) - } + // --- Route-specific logic (always runs) --- - // Validate OAuth2 provider app token audience (RFC 8707) if applicable + // Validate OAuth2 provider app token audience (RFC 8707) if applicable. if key.LoginType == database.LoginTypeOAuth2ProviderApp { if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil { - // Log the detailed error for debugging but don't expose it to the client + // Log the detailed error for debugging but don't expose it to the client. cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err)) return optionalWrite(http.StatusForbidden, codersdk.Response{ Message: "Token audience validation failed", @@ -270,183 +643,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon } } - // We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor - // really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly - // refreshing the OIDC token. - if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC { - var err error - //nolint:gocritic // System needs to fetch UserLink to check if it's valid. - link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{ - UserID: key.UserID, - LoginType: key.LoginType, - }) - if errors.Is(err, sql.ErrNoRows) { - return optionalWrite(http.StatusUnauthorized, codersdk.Response{ - Message: SignedOutErrorMessage, - Detail: "You must re-authenticate with the login provider.", - }) - } - if err != nil { - return write(http.StatusInternalServerError, codersdk.Response{ - Message: "A database error occurred", - Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()), - }) - } - // Check if the OAuth token is expired - if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) { - if cfg.OAuth2Configs.IsZero() { - return write(http.StatusInternalServerError, codersdk.Response{ - Message: internalErrorMessage, - Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+ - "No OAuth2Configs provided. Contact an administrator to configure this login type.", key.LoginType), - }) - } - - var friendlyName string - var oauthConfig promoauth.OAuth2Config - switch key.LoginType { - case database.LoginTypeGithub: - oauthConfig = cfg.OAuth2Configs.Github - friendlyName = "GitHub" - case database.LoginTypeOIDC: - oauthConfig = cfg.OAuth2Configs.OIDC - friendlyName = "OpenID Connect" - default: - return write(http.StatusInternalServerError, codersdk.Response{ - Message: internalErrorMessage, - Detail: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType), - }) - } - - // It's possible for cfg.OAuth2Configs to be non-nil, but still - // missing this type. For example, if a user logged in with GitHub, - // but the administrator later removed GitHub and replaced it with - // OIDC. - if oauthConfig == nil { - return write(http.StatusInternalServerError, codersdk.Response{ - Message: internalErrorMessage, - Detail: fmt.Sprintf("Unable to refresh OAuth token for login type %q. "+ - "OAuth2Config not provided. Contact an administrator to configure this login type.", key.LoginType), - }) - } - - if link.OAuthRefreshToken == "" { - return optionalWrite(http.StatusUnauthorized, codersdk.Response{ - Message: SignedOutErrorMessage, - Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()), - }) - } - // We have a refresh token, so let's try it - token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{ - AccessToken: link.OAuthAccessToken, - RefreshToken: link.OAuthRefreshToken, - Expiry: link.OAuthExpiry, - }).Token() - if err != nil { - return write(http.StatusUnauthorized, codersdk.Response{ - Message: fmt.Sprintf( - "Could not refresh expired %s token. Try re-authenticating to resolve this issue.", - friendlyName), - Detail: err.Error(), - }) - } - link.OAuthAccessToken = token.AccessToken - link.OAuthRefreshToken = token.RefreshToken - link.OAuthExpiry = token.Expiry - //nolint:gocritic // system needs to update user link - link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{ - UserID: link.UserID, - LoginType: link.LoginType, - OAuthAccessToken: link.OAuthAccessToken, - OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required - OAuthRefreshToken: link.OAuthRefreshToken, - OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required - OAuthExpiry: link.OAuthExpiry, - // Refresh should keep the same debug context because we use - // the original claims for the group/role sync. - Claims: link.Claims, - }) - if err != nil { - return write(http.StatusInternalServerError, codersdk.Response{ - Message: internalErrorMessage, - Detail: fmt.Sprintf("update user_link: %s.", err.Error()), - }) - } - } - } - - // Tracks if the API key has properties updated - changed := false - - // Only update LastUsed once an hour to prevent database spam. - if now.Sub(key.LastUsed) > time.Hour { - key.LastUsed = now - remoteIP := net.ParseIP(r.RemoteAddr) - if remoteIP == nil { - remoteIP = net.IPv4(0, 0, 0, 0) - } - bitlen := len(remoteIP) * 8 - key.IPAddress = pqtype.Inet{ - IPNet: net.IPNet{ - IP: remoteIP, - Mask: net.CIDRMask(bitlen, bitlen), - }, - Valid: true, - } - changed = true - } - // Only update the ExpiresAt once an hour to prevent database spam. - // We extend the ExpiresAt to reduce re-authentication. - if !cfg.DisableSessionExpiryRefresh { - apiKeyLifetime := time.Duration(key.LifetimeSeconds) * time.Second - if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour { - key.ExpiresAt = now.Add(apiKeyLifetime) - changed = true - } - } - if changed { - //nolint:gocritic // System needs to update API Key LastUsed - err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystemRestricted(ctx), database.UpdateAPIKeyByIDParams{ - ID: key.ID, - LastUsed: key.LastUsed, - ExpiresAt: key.ExpiresAt, - IPAddress: key.IPAddress, - }) - if err != nil { - return write(http.StatusInternalServerError, codersdk.Response{ - Message: internalErrorMessage, - Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()), - }) - } - - // We only want to update this occasionally to reduce DB write - // load. We update alongside the UserLink and APIKey since it's - // easier on the DB to colocate writes. - //nolint:gocritic // system needs to update user last seen at - _, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{ - ID: key.UserID, - LastSeenAt: dbtime.Now(), - UpdatedAt: dbtime.Now(), - }) - if err != nil { - return write(http.StatusInternalServerError, codersdk.Response{ - Message: internalErrorMessage, - Detail: fmt.Sprintf("update user last_seen_at: %s", err.Error()), - }) - } - } - - // If the key is valid, we also fetch the user roles and status. - // The roles are used for RBAC authorize checks, and the status - // is to block 'suspended' users from accessing the platform. - actor, userStatus, err := UserRBACSubject(ctx, cfg.DB, key.UserID, key.ScopeSet()) - if err != nil { - return write(http.StatusUnauthorized, codersdk.Response{ - Message: internalErrorMessage, - Detail: fmt.Sprintf("Internal error fetching user's roles. %s", err.Error()), - }) - } - + // Dormant activation (config-dependent). if userStatus == database.UserStatusDormant && cfg.ActivateDormantUser != nil { id, _ := uuid.Parse(actor.ID) user, err := cfg.ActivateDormantUser(ctx, database.User{ diff --git a/coderd/httpmw/ratelimit.go b/coderd/httpmw/ratelimit.go index 51fdcfd74c..e89a280530 100644 --- a/coderd/httpmw/ratelimit.go +++ b/coderd/httpmw/ratelimit.go @@ -32,35 +32,56 @@ func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler count, window, httprate.WithKeyFuncs(func(r *http.Request) (string, error) { - // Prioritize by user, but fallback to IP. - apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey) - if !ok { + // Identify the caller. We check two sources: + // + // 1. apiKeyPrecheckedContextKey — set by PrecheckAPIKey + // at the root of the router. Only fully validated + // keys are used. + // 2. apiKeyContextKey — set by ExtractAPIKeyMW if it + // has already run (e.g. unit tests, workspace-app + // routes that don't go through PrecheckAPIKey). + // + // If neither is present, fall back to IP. + var userID string + var subject *rbac.Subject + + if pc, ok := r.Context().Value(apiKeyPrecheckedContextKey{}).(APIKeyPrechecked); ok && pc.Result != nil { + userID = pc.Result.Key.UserID.String() + subject = &pc.Result.Subject + } else if ak, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey); ok { + userID = ak.UserID.String() + if auth, ok := UserAuthorizationOptional(r.Context()); ok { + subject = &auth + } + } else { return httprate.KeyByIP(r) } if ok, _ := strconv.ParseBool(r.Header.Get(codersdk.BypassRatelimitHeader)); !ok { - // No bypass attempt, just ratelimit. - return apiKey.UserID.String(), nil + // No bypass attempt, just rate limit by user. + return userID, nil } // Allow Owner to bypass rate limiting for load tests - // and automation. - auth := UserAuthorization(r.Context()) - - // We avoid using rbac.Authorizer since rego is CPU-intensive - // and undermines the DoS-prevention goal of the rate limiter. - for _, role := range auth.SafeRoleNames() { + // and automation. We avoid using rbac.Authorizer since + // rego is CPU-intensive and undermines the + // DoS-prevention goal of the rate limiter. + if subject == nil { + // Can't verify roles — rate limit normally. + return userID, nil + } + for _, role := range subject.SafeRoleNames() { if role == rbac.RoleOwner() { // HACK: use a random key each time to // de facto disable rate limiting. The - // `httprate` package has no - // support for selectively changing the limit - // for particular keys. + // httprate package has no support for + // selectively changing the limit for + // particular keys. return cryptorand.String(16) } } - return apiKey.UserID.String(), xerrors.Errorf( + return userID, xerrors.Errorf( "%q provided but user is not %v", codersdk.BypassRatelimitHeader, rbac.RoleOwner(), )