diff --git a/cli/testdata/coder_server_--help.golden b/cli/testdata/coder_server_--help.golden index aa318a5f85..9abbcf9789 100644 --- a/cli/testdata/coder_server_--help.golden +++ b/cli/testdata/coder_server_--help.golden @@ -125,12 +125,20 @@ AI BRIDGE OPTIONS: requests (requires the "oauth2" and "mcp-server-http" experiments to be enabled). + --aibridge-max-concurrency int, $CODER_AIBRIDGE_MAX_CONCURRENCY (default: 0) + Maximum number of concurrent AI Bridge requests per replica. Set to 0 + to disable (unlimited). + --aibridge-openai-base-url string, $CODER_AIBRIDGE_OPENAI_BASE_URL (default: https://api.openai.com/v1/) The base URL of the OpenAI API. --aibridge-openai-key string, $CODER_AIBRIDGE_OPENAI_KEY The key to authenticate against the OpenAI API. + --aibridge-rate-limit int, $CODER_AIBRIDGE_RATE_LIMIT (default: 0) + Maximum number of AI Bridge requests per second per replica. Set to 0 + to disable (unlimited). + CLIENT OPTIONS: These options change the behavior of how clients interact with the Coder. Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI. diff --git a/cli/testdata/server-config.yaml.golden b/cli/testdata/server-config.yaml.golden index a9e6058a3e..2b072b68a5 100644 --- a/cli/testdata/server-config.yaml.golden +++ b/cli/testdata/server-config.yaml.golden @@ -748,6 +748,14 @@ aibridge: # (token, prompt, tool use). # (default: 60d, type: duration) retention: 1440h0m0s + # Maximum number of concurrent AI Bridge requests per replica. Set to 0 to disable + # (unlimited). + # (default: 0, type: int) + maxConcurrency: 0 + # Maximum number of AI Bridge requests per second per replica. Set to 0 to disable + # (unlimited). + # (default: 0, type: int) + rateLimit: 0 # Configure data retention policies for various database tables. Retention # policies automatically purge old data to reduce database size and improve # performance. Setting a retention duration to 0 disables automatic purging for diff --git a/coderd/aibridge/aibridge.go b/coderd/aibridge/aibridge.go new file mode 100644 index 0000000000..cb656b5ec5 --- /dev/null +++ b/coderd/aibridge/aibridge.go @@ -0,0 +1,24 @@ +// Package aibridge provides utilities for the AI Bridge feature. +package aibridge + +import ( + "net/http" + "strings" +) + +// ExtractAuthToken extracts an authorization token from HTTP headers. +// It checks the Authorization header (Bearer token) and X-Api-Key header, +// which represent the different ways clients authenticate against AI providers. +// If neither are present, an empty string is returned. +func ExtractAuthToken(header http.Header) string { + if auth := strings.TrimSpace(header.Get("Authorization")); auth != "" { + fields := strings.Fields(auth) + if len(fields) == 2 && strings.EqualFold(fields[0], "Bearer") { + return fields[1] + } + } + if apiKey := strings.TrimSpace(header.Get("X-Api-Key")); apiKey != "" { + return apiKey + } + return "" +} diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index b8e3331ecd..800058b759 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -11883,9 +11883,15 @@ const docTemplate = `{ "inject_coder_mcp_tools": { "type": "boolean" }, + "max_concurrency": { + "type": "integer" + }, "openai": { "$ref": "#/definitions/codersdk.AIBridgeOpenAIConfig" }, + "rate_limit": { + "type": "integer" + }, "retention": { "type": "integer" } diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 396a704a06..ea54c1b219 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -10549,9 +10549,15 @@ "inject_coder_mcp_tools": { "type": "boolean" }, + "max_concurrency": { + "type": "integer" + }, "openai": { "$ref": "#/definitions/codersdk.AIBridgeOpenAIConfig" }, + "rate_limit": { + "type": "integer" + }, "retention": { "type": "integer" } diff --git a/coderd/httpmw/ratelimit.go b/coderd/httpmw/ratelimit.go index ad1ecf3d6b..51fdcfd74c 100644 --- a/coderd/httpmw/ratelimit.go +++ b/coderd/httpmw/ratelimit.go @@ -4,11 +4,13 @@ import ( "fmt" "net/http" "strconv" + "sync/atomic" "time" "github.com/go-chi/httprate" "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/rbac" @@ -70,3 +72,72 @@ func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler }), ) } + +// RateLimitByAuthToken returns a handler that limits requests based on the +// authentication token in the request. +// +// This differs from [RateLimit] in several ways: +// - It extracts the token directly from request headers (Authorization Bearer +// or X-Api-Key) rather than from the request context, making it suitable for +// endpoints that handle authentication internally (like AI Bridge) rather than +// via [ExtractAPIKeyMW] middleware. +// - It does not support the bypass header for Owners. +// - It does not key by endpoint, so the limit applies across all endpoints using +// this middleware. +// - It includes a Retry-After header in 429 responses for backpressure signaling. +// +// If no token is found in the headers, it falls back to rate limiting by IP address. +func RateLimitByAuthToken(count int, window time.Duration) func(http.Handler) http.Handler { + if count <= 0 { + return func(handler http.Handler) http.Handler { + return handler + } + } + + return httprate.Limit( + count, + window, + httprate.WithKeyFuncs(func(r *http.Request) (string, error) { + // Try to extract auth token for per-user rate limiting using + // AI provider authentication headers (Authorization Bearer or X-Api-Key). + if token := aibridge.ExtractAuthToken(r.Header); token != "" { + return token, nil + } + // Fall back to IP-based rate limiting if no token present. + return httprate.KeyByIP(r) + }), + httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { + // Add Retry-After header for backpressure signaling. + w.Header().Set("Retry-After", fmt.Sprintf("%d", int(window.Seconds()))) + httpapi.Write(r.Context(), w, http.StatusTooManyRequests, codersdk.Response{ + Message: "You've been rate limited. Please try again later.", + }) + }), + ) +} + +// ConcurrencyLimit returns a handler that limits the number of concurrent +// requests. When the limit is exceeded, it returns HTTP 503 Service Unavailable. +func ConcurrencyLimit(maxConcurrent int64, resourceName string) func(http.Handler) http.Handler { + if maxConcurrent <= 0 { + return func(handler http.Handler) http.Handler { + return handler + } + } + + var current atomic.Int64 + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := current.Add(1) + defer current.Add(-1) + + if c > maxConcurrent { + httpapi.Write(r.Context(), w, http.StatusServiceUnavailable, codersdk.Response{ + Message: fmt.Sprintf("%s is currently at capacity. Please try again later.", resourceName), + }) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/coderd/httpmw/ratelimit_test.go b/coderd/httpmw/ratelimit_test.go index 51a05940fc..49e46ccf46 100644 --- a/coderd/httpmw/ratelimit_test.go +++ b/coderd/httpmw/ratelimit_test.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -17,6 +18,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" ) func randRemoteAddr() string { @@ -145,3 +147,211 @@ func TestRateLimit(t *testing.T) { } }) } + +func TestRateLimitByAuthToken(t *testing.T) { + t.Parallel() + + t.Run("LimitsByAuthHeader", func(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + headerName string + headerVal string + }{ + { + name: "BearerToken", + headerName: "Authorization", + headerVal: "Bearer test-token-123", + }, + { + name: "XApiKey", + headerName: "X-Api-Key", + headerVal: "test-api-key-456", + }, + { + name: "NoToken", + headerName: "", + headerVal: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + rtr := chi.NewRouter() + rtr.Use(httpmw.RateLimitByAuthToken(2, time.Hour)) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + + // Same token (or IP if no token) should be rate limited after 2 requests. + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/", nil) + if tt.headerName != "" { + req.Header.Set(tt.headerName, tt.headerVal) + } + rec := httptest.NewRecorder() + rtr.ServeHTTP(rec, req) + resp := rec.Result() + _ = resp.Body.Close() + if i < 2 { + require.Equal(t, http.StatusOK, resp.StatusCode, "request %d should succeed", i) + } else { + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "request %d should be rate limited", i) + // Verify Retry-After header is set. + require.NotEmpty(t, resp.Header.Get("Retry-After"), "Retry-After header should be set") + } + } + }) + } + }) + + t.Run("DifferentTokensNotLimited", func(t *testing.T) { + t.Parallel() + rtr := chi.NewRouter() + rtr.Use(httpmw.RateLimitByAuthToken(1, time.Hour)) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + + // Different tokens should not be rate limited against each other. + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer token-%d", i)) + rec := httptest.NewRecorder() + rtr.ServeHTTP(rec, req) + resp := rec.Result() + _ = resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode, "request %d should succeed", i) + } + }) + + t.Run("DisabledWhenZero", func(t *testing.T) { + t.Parallel() + rtr := chi.NewRouter() + rtr.Use(httpmw.RateLimitByAuthToken(0, time.Hour)) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + + // Should not be rate limited when limit is 0. + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer same-token") + rec := httptest.NewRecorder() + rtr.ServeHTTP(rec, req) + resp := rec.Result() + _ = resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + } + }) +} + +func TestConcurrencyLimit(t *testing.T) { + t.Parallel() + + t.Run("LimitsConcurrentRequests", func(t *testing.T) { + t.Parallel() + + const maxConcurrency = 2 + rtr := chi.NewRouter() + rtr.Use(httpmw.ConcurrencyLimit(maxConcurrency, "Test")) + + // Use a WaitGroup as a barrier to ensure all requests are in the handler + // before any of them proceed. + var handlersReady sync.WaitGroup + handlersReady.Add(maxConcurrency) + releaseHandler := make(chan struct{}) + + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + handlersReady.Done() + // Wait until released. + <-releaseHandler + rw.WriteHeader(http.StatusOK) + }) + + server := httptest.NewServer(rtr) + defer server.Close() + + ctx := testutil.Context(t, testutil.WaitShort) + + // Start maxConcurrency requests that will block. + // We use channels to collect errors instead of require in goroutines. + type result struct { + statusCode int + err error + } + results := make(chan result, maxConcurrency) + + var wg sync.WaitGroup + for i := 0; i < maxConcurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL+"/", nil) + if err != nil { + results <- result{err: err} + return + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + results <- result{err: err} + return + } + defer resp.Body.Close() + results <- result{statusCode: resp.StatusCode} + }() + } + + // Wait for all requests to enter the handler with a timeout. + handlersReadyCh := make(chan struct{}) + go func() { + handlersReady.Wait() + close(handlersReadyCh) + }() + select { + case <-handlersReadyCh: + case <-ctx.Done(): + t.Fatal("timed out waiting for handlers to be ready") + } + + // Next request should be rejected since we're at capacity. + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL+"/", nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + + // Release all blocked requests. + close(releaseHandler) + wg.Wait() + close(results) + + // Check all goroutine results. + for res := range results { + require.NoError(t, res.err) + require.Equal(t, http.StatusOK, res.statusCode) + } + }) + + t.Run("DisabledWhenZero", func(t *testing.T) { + t.Parallel() + rtr := chi.NewRouter() + rtr.Use(httpmw.ConcurrencyLimit(0, "Test")) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + + // Should not be limited when maxConcurrency is 0. + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + rtr.ServeHTTP(rec, req) + resp := rec.Result() + _ = resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + } + }) +} diff --git a/codersdk/deployment.go b/codersdk/deployment.go index 0dd082ab5e..4b42896ddd 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -3401,6 +3401,26 @@ Write out the current server config as YAML to stdout.`, YAML: "retention", Annotations: serpent.Annotations{}.Mark(annotationFormatDuration, "true"), }, + { + Name: "AI Bridge Max Concurrency", + Description: "Maximum number of concurrent AI Bridge requests per replica. Set to 0 to disable (unlimited).", + Flag: "aibridge-max-concurrency", + Env: "CODER_AIBRIDGE_MAX_CONCURRENCY", + Value: &c.AI.BridgeConfig.MaxConcurrency, + Default: "0", + Group: &deploymentGroupAIBridge, + YAML: "maxConcurrency", + }, + { + Name: "AI Bridge Rate Limit", + Description: "Maximum number of AI Bridge requests per second per replica. Set to 0 to disable (unlimited).", + Flag: "aibridge-rate-limit", + Env: "CODER_AIBRIDGE_RATE_LIMIT", + Value: &c.AI.BridgeConfig.RateLimit, + Default: "0", + Group: &deploymentGroupAIBridge, + YAML: "rateLimit", + }, // Retention settings { Name: "Audit Logs Retention", @@ -3471,6 +3491,8 @@ type AIBridgeConfig struct { Bedrock AIBridgeBedrockConfig `json:"bedrock" typescript:",notnull"` InjectCoderMCPTools serpent.Bool `json:"inject_coder_mcp_tools" typescript:",notnull"` Retention serpent.Duration `json:"retention" typescript:",notnull"` + MaxConcurrency serpent.Int64 `json:"max_concurrency" typescript:",notnull"` + RateLimit serpent.Int64 `json:"rate_limit" typescript:",notnull"` } type AIBridgeOpenAIConfig struct { diff --git a/docs/reference/api/general.md b/docs/reference/api/general.md index 3ea0180ae1..3566670715 100644 --- a/docs/reference/api/general.md +++ b/docs/reference/api/general.md @@ -176,10 +176,12 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \ }, "enabled": true, "inject_coder_mcp_tools": true, + "max_concurrency": 0, "openai": { "base_url": "string", "key": "string" }, + "rate_limit": 0, "retention": 0 } }, diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index bd00d79c4b..892e52c4e2 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -390,10 +390,12 @@ }, "enabled": true, "inject_coder_mcp_tools": true, + "max_concurrency": 0, "openai": { "base_url": "string", "key": "string" }, + "rate_limit": 0, "retention": 0 } ``` @@ -406,7 +408,9 @@ | `bedrock` | [codersdk.AIBridgeBedrockConfig](#codersdkaibridgebedrockconfig) | false | | | | `enabled` | boolean | false | | | | `inject_coder_mcp_tools` | boolean | false | | | +| `max_concurrency` | integer | false | | | | `openai` | [codersdk.AIBridgeOpenAIConfig](#codersdkaibridgeopenaiconfig) | false | | | +| `rate_limit` | integer | false | | | | `retention` | integer | false | | | ## codersdk.AIBridgeInterception @@ -700,10 +704,12 @@ }, "enabled": true, "inject_coder_mcp_tools": true, + "max_concurrency": 0, "openai": { "base_url": "string", "key": "string" }, + "rate_limit": 0, "retention": 0 } } @@ -2860,10 +2866,12 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o }, "enabled": true, "inject_coder_mcp_tools": true, + "max_concurrency": 0, "openai": { "base_url": "string", "key": "string" }, + "rate_limit": 0, "retention": 0 } }, @@ -3383,10 +3391,12 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o }, "enabled": true, "inject_coder_mcp_tools": true, + "max_concurrency": 0, "openai": { "base_url": "string", "key": "string" }, + "rate_limit": 0, "retention": 0 } }, diff --git a/docs/reference/cli/server.md b/docs/reference/cli/server.md index 4ba8c026fb..454d11bff6 100644 --- a/docs/reference/cli/server.md +++ b/docs/reference/cli/server.md @@ -1781,6 +1781,28 @@ Whether to inject Coder's MCP tools into intercepted AI Bridge requests (require Length of time to retain data such as interceptions and all related records (token, prompt, tool use). +### --aibridge-max-concurrency + +| | | +|-------------|----------------------------------------------| +| Type | int | +| Environment | $CODER_AIBRIDGE_MAX_CONCURRENCY | +| YAML | aibridge.maxConcurrency | +| Default | 0 | + +Maximum number of concurrent AI Bridge requests per replica. Set to 0 to disable (unlimited). + +### --aibridge-rate-limit + +| | | +|-------------|-----------------------------------------| +| Type | int | +| Environment | $CODER_AIBRIDGE_RATE_LIMIT | +| YAML | aibridge.rateLimit | +| Default | 0 | + +Maximum number of AI Bridge requests per second per replica. Set to 0 to disable (unlimited). + ### --audit-logs-retention | | | diff --git a/enterprise/aibridged/aibridged_test.go b/enterprise/aibridged/aibridged_test.go index 6a84fb3841..2d74054196 100644 --- a/enterprise/aibridged/aibridged_test.go +++ b/enterprise/aibridged/aibridged_test.go @@ -17,6 +17,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/aibridge" + agplaibridge "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/aibridged" mock "github.com/coder/coder/v2/enterprise/aibridged/aibridgedmock" @@ -220,7 +221,7 @@ func TestExtractAuthToken(t *testing.T) { for k, v := range tc.headers { headers.Add(k, v) } - key := aibridged.ExtractAuthToken(headers) + key := agplaibridge.ExtractAuthToken(headers) require.Equal(t, tc.expectedKey, key) }) } diff --git a/enterprise/aibridged/http.go b/enterprise/aibridged/http.go index 7e41f0c007..567ff44e38 100644 --- a/enterprise/aibridged/http.go +++ b/enterprise/aibridged/http.go @@ -9,6 +9,7 @@ import ( "cdr.dev/slog" "github.com/coder/aibridge" + agplaibridge "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/enterprise/aibridged/proto" ) @@ -35,7 +36,7 @@ func (s *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) { logger := s.logger.With(slog.F("path", r.URL.Path)) - key := strings.TrimSpace(ExtractAuthToken(r.Header)) + key := strings.TrimSpace(agplaibridge.ExtractAuthToken(r.Header)) if key == "" { logger.Warn(ctx, "no auth key provided") http.Error(rw, ErrNoAuthKey.Error(), http.StatusBadRequest) @@ -79,20 +80,3 @@ func (s *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) { handler.ServeHTTP(rw, r) } - -// ExtractAuthToken extracts authorization token from HTTP request using multiple sources. -// These sources represent the different ways clients authenticate against AI providers. -// It checks the Authorization header (Bearer token) and X-Api-Key header. -// If neither are present, an empty string is returned. -func ExtractAuthToken(header http.Header) string { - if auth := strings.TrimSpace(header.Get("Authorization")); auth != "" { - fields := strings.Fields(auth) - if len(fields) == 2 && strings.EqualFold(fields[0], "Bearer") { - return fields[1] - } - } - if apiKey := strings.TrimSpace(header.Get("X-Api-Key")); apiKey != "" { - return apiKey - } - return "" -} diff --git a/enterprise/cli/testdata/coder_server_--help.golden b/enterprise/cli/testdata/coder_server_--help.golden index 32db725d93..2a63b333a5 100644 --- a/enterprise/cli/testdata/coder_server_--help.golden +++ b/enterprise/cli/testdata/coder_server_--help.golden @@ -126,12 +126,20 @@ AI BRIDGE OPTIONS: requests (requires the "oauth2" and "mcp-server-http" experiments to be enabled). + --aibridge-max-concurrency int, $CODER_AIBRIDGE_MAX_CONCURRENCY (default: 0) + Maximum number of concurrent AI Bridge requests per replica. Set to 0 + to disable (unlimited). + --aibridge-openai-base-url string, $CODER_AIBRIDGE_OPENAI_BASE_URL (default: https://api.openai.com/v1/) The base URL of the OpenAI API. --aibridge-openai-key string, $CODER_AIBRIDGE_OPENAI_KEY The key to authenticate against the OpenAI API. + --aibridge-rate-limit int, $CODER_AIBRIDGE_RATE_LIMIT (default: 0) + Maximum number of AI Bridge requests per second per replica. Set to 0 + to disable (unlimited). + CLIENT OPTIONS: These options change the behavior of how clients interact with the Coder. Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI. diff --git a/enterprise/coderd/aibridge.go b/enterprise/coderd/aibridge.go index 96bbe1d205..d1d12d7b02 100644 --- a/enterprise/coderd/aibridge.go +++ b/enterprise/coderd/aibridge.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/go-chi/chi/v5" "github.com/google/uuid" @@ -23,10 +24,19 @@ import ( const ( maxListInterceptionsLimit = 1000 defaultListInterceptionsLimit = 100 + // aiBridgeRateLimitWindow is the fixed duration for rate limiting AI Bridge + // requests. This is hardcoded to keep configuration simple. + aiBridgeRateLimitWindow = time.Second ) // aibridgeHandler handles all aibridged-related endpoints. func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) func(r chi.Router) { + // Build the overload protection middleware chain for the aibridged handler. + // These limits are applied per-replica. + bridgeCfg := api.DeploymentValues.AI.BridgeConfig + concurrencyLimiter := httpmw.ConcurrencyLimit(bridgeCfg.MaxConcurrency.Value(), "AI Bridge") + rateLimiter := httpmw.RateLimitByAuthToken(int(bridgeCfg.RateLimit.Value()), aiBridgeRateLimitWindow) + return func(r chi.Router) { r.Use(api.RequireFeatureMW(codersdk.FeatureAIBridge)) r.Group(func(r chi.Router) { @@ -34,25 +44,30 @@ func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) f r.Get("/interceptions", api.aiBridgeListInterceptions) }) - // This is a bit funky but since aibridge only exposes a HTTP - // handler, this is how it has to be. - r.HandleFunc("/*", func(rw http.ResponseWriter, r *http.Request) { - if api.aibridgedHandler == nil { - httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{ - Message: "aibridged handler not mounted", - }) - return - } - - // Strip either the experimental or stable prefix. - // TODO: experimental route is deprecated and must be removed with Beta. - prefixes := []string{"/api/experimental/aibridge", "/api/v2/aibridge"} - for _, prefix := range prefixes { - if strings.Contains(r.URL.String(), prefix) { - http.StripPrefix(prefix, api.aibridgedHandler).ServeHTTP(rw, r) - break + // Apply overload protection middleware to the aibridged handler. + // Concurrency limit is checked first for faster rejection under load. + r.Group(func(r chi.Router) { + r.Use(concurrencyLimiter, rateLimiter) + // This is a bit funky but since aibridge only exposes a HTTP + // handler, this is how it has to be. + r.HandleFunc("/*", func(rw http.ResponseWriter, r *http.Request) { + if api.aibridgedHandler == nil { + httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{ + Message: "aibridged handler not mounted", + }) + return } - } + + // Strip either the experimental or stable prefix. + // TODO: experimental route is deprecated and must be removed with Beta. + prefixes := []string{"/api/experimental/aibridge", "/api/v2/aibridge"} + for _, prefix := range prefixes { + if strings.Contains(r.URL.String(), prefix) { + http.StripPrefix(prefix, api.aibridgedHandler).ServeHTTP(rw, r) + break + } + } + }) }) } } diff --git a/enterprise/coderd/aibridge_test.go b/enterprise/coderd/aibridge_test.go index 2913fe516a..d35b166402 100644 --- a/enterprise/coderd/aibridge_test.go +++ b/enterprise/coderd/aibridge_test.go @@ -703,3 +703,135 @@ func TestAIBridgeRouting(t *testing.T) { }) } } + +func TestAIBridgeRateLimiting(t *testing.T) { + t.Parallel() + + dv := coderdtest.DeploymentValues(t) + // Set a low rate limit for testing. + dv.AI.BridgeConfig.RateLimit = 2 + + client, closer, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + DeploymentValues: dv, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureAIBridge: 1, + }, + }, + }) + t.Cleanup(func() { + _ = closer.Close() + }) + + // Register a simple test handler. + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + api.RegisterInMemoryAIBridgedHTTPHandler(testHandler) + + ctx := testutil.Context(t, testutil.WaitLong) + httpClient := &http.Client{} + url := client.URL.String() + "/api/v2/aibridge/test" + + // Make requests up to the limit - should succeed. + for range 2 { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + require.NoError(t, err) + req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) + + resp, err := httpClient.Do(req) + require.NoError(t, err) + _ = resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + } + + // Next request should be rate limited. + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + require.NoError(t, err) + req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) + + resp, err := httpClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + require.NotEmpty(t, resp.Header.Get("Retry-After")) +} + +func TestAIBridgeConcurrencyLimiting(t *testing.T) { + t.Parallel() + + dv := coderdtest.DeploymentValues(t) + // Set a low concurrency limit for testing. + dv.AI.BridgeConfig.MaxConcurrency = 1 + + client, closer, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + DeploymentValues: dv, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureAIBridge: 1, + }, + }, + }) + t.Cleanup(func() { + _ = closer.Close() + }) + + // Register a handler that blocks until signaled. + started := make(chan struct{}) + unblock := make(chan struct{}) + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + started <- struct{}{} + <-unblock + rw.WriteHeader(http.StatusOK) + }) + api.RegisterInMemoryAIBridgedHTTPHandler(testHandler) + + ctx := testutil.Context(t, testutil.WaitLong) + httpClient := &http.Client{} + url := client.URL.String() + "/api/v2/aibridge/test" + + // Start a request that will block. + done := make(chan struct{}) + go func() { + defer close(done) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + return + } + req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) + + resp, err := httpClient.Do(req) + if err == nil { + _ = resp.Body.Close() + } + }() + + // Wait for the first request to start processing. + select { + case <-started: + case <-ctx.Done(): + t.Fatal("timed out waiting for first request to start") + } + + // Second request should be rejected with 503. + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + require.NoError(t, err) + req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) + + resp, err := httpClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + + // Unblock the first request and wait for it to complete. + close(unblock) + select { + case <-done: + case <-ctx.Done(): + t.Fatal("timed out waiting for first request to complete") + } +} diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 6cb1474403..7311032a40 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -33,6 +33,8 @@ export interface AIBridgeConfig { readonly bedrock: AIBridgeBedrockConfig; readonly inject_coder_mcp_tools: boolean; readonly retention: number; + readonly max_concurrency: number; + readonly rate_limit: number; } // From codersdk/aibridge.go