Files
coder/enterprise/coderd/aibridge.go
T
Paweł Banaszewski 90c11f3386 feat: add client column to aibridge_interceptions table (#21839)
Adds `client` column to `aibridge_interceptions` table. It is set accordingly to what is passed from AI Bridge in `RecordInterception`.
Adds interception filtering by `client` value.

Depends on: https://github.com/coder/aibridge/pull/158
Updates aibridge library to include this change.

Fixes: https://github.com/coder/aibridge/issues/31
2026-02-17 15:43:02 +01:00

225 lines
8.3 KiB
Go

package coderd
import (
"context"
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/searchquery"
"github.com/coder/coder/v2/codersdk"
)
const (
maxListInterceptionsLimit = 1000
defaultListInterceptionsLimit = 100
// aiBridgeRateLimitWindow is the fixed duration for rate limiting AI Bridge
// requests. This is hardcoded to keep configuration simple.
aiBridgeRateLimitWindow = time.Second
)
// aibridgeHandler handles all aibridged-related endpoints.
func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) func(r chi.Router) {
// Build the overload protection middleware chain for the aibridged handler.
// These limits are applied per-replica.
bridgeCfg := api.DeploymentValues.AI.BridgeConfig
concurrencyLimiter := httpmw.ConcurrencyLimit(bridgeCfg.MaxConcurrency.Value(), "AI Bridge")
rateLimiter := httpmw.RateLimitByAuthToken(int(bridgeCfg.RateLimit.Value()), aiBridgeRateLimitWindow)
return func(r chi.Router) {
r.Use(api.RequireFeatureMW(codersdk.FeatureAIBridge))
r.Group(func(r chi.Router) {
r.Use(middlewares...)
r.Get("/interceptions", api.aiBridgeListInterceptions)
})
// Apply overload protection middleware to the aibridged handler.
// Concurrency limit is checked first for faster rejection under load.
r.Group(func(r chi.Router) {
r.Use(concurrencyLimiter, rateLimiter)
// This is a bit funky but since aibridge only exposes a HTTP
// handler, this is how it has to be.
r.HandleFunc("/*", func(rw http.ResponseWriter, r *http.Request) {
if api.aibridgedHandler == nil {
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
Message: "aibridged handler not mounted",
})
return
}
http.StripPrefix("/api/v2/aibridge", api.aibridgedHandler).ServeHTTP(rw, r)
})
})
}
}
// aiBridgeListInterceptions returns all AI Bridge interceptions a user can read.
// Optional filters with query params
//
// @Summary List AI Bridge interceptions
// @ID list-ai-bridge-interceptions
// @Security CoderSessionToken
// @Produce json
// @Tags AI Bridge
// @Param q query string false "Search query in the format `key:value`. Available keys are: initiator, provider, model, started_after, started_before."
// @Param limit query int false "Page limit"
// @Param after_id query string false "Cursor pagination after ID (cannot be used with offset)"
// @Param offset query int false "Offset pagination (cannot be used with after_id)"
// @Success 200 {object} codersdk.AIBridgeListInterceptionsResponse
// @Router /aibridge/interceptions [get]
func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
page, ok := coderd.ParsePagination(rw, r)
if !ok {
return
}
if page.AfterID != uuid.Nil && page.Offset != 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Detail: "Cannot use both after_id and offset pagination in the same request.",
})
return
}
if page.Limit == 0 {
page.Limit = defaultListInterceptionsLimit
}
if page.Limit > maxListInterceptionsLimit || page.Limit < 1 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid pagination limit value.",
Detail: fmt.Sprintf("Pagination limit must be in range (0, %d]", maxListInterceptionsLimit),
})
return
}
queryStr := r.URL.Query().Get("q")
filter, errs := searchquery.AIBridgeInterceptions(ctx, api.Database, queryStr, page, apiKey.UserID)
if len(errs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid workspace search query.",
Validations: errs,
})
return
}
var (
count int64
rows []database.ListAIBridgeInterceptionsRow
)
err := api.Database.InTx(func(db database.Store) error {
// Ensure the after_id interception exists and is visible to the user.
if page.AfterID != uuid.Nil {
_, err := db.GetAIBridgeInterceptionByID(ctx, page.AfterID)
if err != nil {
return xerrors.Errorf("get aibridge interception by id %s for cursor pagination: %w", page.AfterID, err)
}
}
var err error
// Get the full count of authorized interceptions matching the filter
// for pagination purposes.
count, err = db.CountAIBridgeInterceptions(ctx, database.CountAIBridgeInterceptionsParams{
StartedAfter: filter.StartedAfter,
StartedBefore: filter.StartedBefore,
InitiatorID: filter.InitiatorID,
Provider: filter.Provider,
Model: filter.Model,
Client: filter.Client,
})
if err != nil {
return xerrors.Errorf("count authorized aibridge interceptions: %w", err)
}
// This only returns authorized interceptions (when using dbauthz).
rows, err = db.ListAIBridgeInterceptions(ctx, filter)
if err != nil {
return xerrors.Errorf("list aibridge interceptions: %w", err)
}
return nil
}, nil)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error getting AI Bridge interceptions.",
Detail: err.Error(),
})
return
}
// This fetches the other rows associated with the interceptions.
items, err := populatedAndConvertAIBridgeInterceptions(ctx, api.Database, rows)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error converting database rows to API response.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AIBridgeListInterceptionsResponse{
Count: count,
Results: items,
})
}
func populatedAndConvertAIBridgeInterceptions(ctx context.Context, db database.Store, dbInterceptions []database.ListAIBridgeInterceptionsRow) ([]codersdk.AIBridgeInterception, error) {
ids := make([]uuid.UUID, len(dbInterceptions))
for i, row := range dbInterceptions {
ids[i] = row.AIBridgeInterception.ID
}
//nolint:gocritic // This is a system function until we implement a join for aibridge interceptions. AI Bridge interception subresources use the same authorization call as their parent.
tokenUsagesRows, err := db.ListAIBridgeTokenUsagesByInterceptionIDs(dbauthz.AsSystemRestricted(ctx), ids)
if err != nil {
return nil, xerrors.Errorf("get linked aibridge token usages from database: %w", err)
}
tokenUsagesMap := make(map[uuid.UUID][]database.AIBridgeTokenUsage, len(dbInterceptions))
for _, row := range tokenUsagesRows {
tokenUsagesMap[row.InterceptionID] = append(tokenUsagesMap[row.InterceptionID], row)
}
//nolint:gocritic // This is a system function until we implement a join for aibridge interceptions. AI Bridge interception subresources use the same authorization call as their parent.
userPromptRows, err := db.ListAIBridgeUserPromptsByInterceptionIDs(dbauthz.AsSystemRestricted(ctx), ids)
if err != nil {
return nil, xerrors.Errorf("get linked aibridge user prompts from database: %w", err)
}
userPromptsMap := make(map[uuid.UUID][]database.AIBridgeUserPrompt, len(dbInterceptions))
for _, row := range userPromptRows {
userPromptsMap[row.InterceptionID] = append(userPromptsMap[row.InterceptionID], row)
}
//nolint:gocritic // This is a system function until we implement a join for aibridge interceptions. AI Bridge interception subresources use the same authorization call as their parent.
toolUsagesRows, err := db.ListAIBridgeToolUsagesByInterceptionIDs(dbauthz.AsSystemRestricted(ctx), ids)
if err != nil {
return nil, xerrors.Errorf("get linked aibridge tool usages from database: %w", err)
}
toolUsagesMap := make(map[uuid.UUID][]database.AIBridgeToolUsage, len(dbInterceptions))
for _, row := range toolUsagesRows {
toolUsagesMap[row.InterceptionID] = append(toolUsagesMap[row.InterceptionID], row)
}
items := make([]codersdk.AIBridgeInterception, len(dbInterceptions))
for i, row := range dbInterceptions {
items[i] = db2sdk.AIBridgeInterception(
row.AIBridgeInterception,
row.VisibleUser,
tokenUsagesMap[row.AIBridgeInterception.ID],
userPromptsMap[row.AIBridgeInterception.ID],
toolUsagesMap[row.AIBridgeInterception.ID],
)
}
return items, nil
}