feat: filter interceptions and sessions by provider name (#25640)

Allows filtering sessions & interceptions by provider name, and adds a test to vaidate that provider name is immutable (at least until #25606 lands).
This commit is contained in:
Danny Kopping
2026-05-25 16:31:48 +02:00
committed by GitHub
parent c8359d8598
commit 4ddda3a9db
16 changed files with 217 additions and 73 deletions
+44
View File
@@ -327,6 +327,50 @@ func TestAIProvidersCRUD(t *testing.T) {
require.Contains(t, sdkErr.Message, "At least one field must be provided")
})
t.Run("UpdateCannotMutateName", func(t *testing.T) {
t.Parallel()
// ai_providers.name is the stable key that aibridge_interceptions
// snapshots into provider_name. Renames would silently desync
// historical interceptions from their live row and break the
// future FK backfill, so the PATCH endpoint must ignore any "name"
// field in the payload. The SDK type intentionally has no Name
// field; this test sends raw JSON to defend against a future
// regression where someone adds one without thinking.
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
//nolint:gocritic // Owner role is the audience for this endpoint.
created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "stable-name",
Enabled: true,
BaseURL: "https://api.openai.com/v1",
})
require.NoError(t, err)
res, err := client.Request(ctx, http.MethodPatch,
"/api/v2/ai/providers/"+created.Name,
json.RawMessage(`{"name":"renamed","display_name":"New Display"}`),
)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
got, err := client.AIProvider(ctx, created.Name)
require.NoError(t, err)
require.Equal(t, "stable-name", got.Name, "name must not be mutable via PATCH")
require.Equal(t, "New Display", got.DisplayName, "display_name should still update")
// Confirm the original name still resolves and the attempted new
// name does not exist as a separate row.
_, err = client.AIProvider(ctx, "renamed")
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
t.Run("UpdateSettingsEmptyObjectRejected", func(t *testing.T) {
t.Parallel()
// "settings": {} cannot decode because the _type discriminator
+2 -2
View File
@@ -1436,7 +1436,7 @@ const docTemplate = `{
"parameters": [
{
"type": "string",
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, started_after, started_before.",
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, provider_name, model, started_after, started_before.",
"name": "q",
"in": "query"
},
@@ -1515,7 +1515,7 @@ const docTemplate = `{
"parameters": [
{
"type": "string",
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, provider_name, model, client, session_id, started_after, started_before.",
"name": "q",
"in": "query"
},
+2 -2
View File
@@ -1265,7 +1265,7 @@
"parameters": [
{
"type": "string",
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, started_after, started_before.",
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, provider_name, model, started_after, started_before.",
"name": "q",
"in": "query"
},
@@ -1336,7 +1336,7 @@
"parameters": [
{
"type": "string",
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, provider_name, model, client, session_id, started_after, started_before.",
"name": "q",
"in": "query"
},
+4
View File
@@ -932,6 +932,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.ProviderName,
arg.Model,
arg.Client,
arg.AfterID,
@@ -998,6 +999,7 @@ func (q *sqlQuerier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, a
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.ProviderName,
arg.Model,
arg.Client,
)
@@ -1097,6 +1099,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg Lis
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.ProviderName,
arg.Model,
arg.Client,
arg.SessionID,
@@ -1161,6 +1164,7 @@ func (q *sqlQuerier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg Co
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.ProviderName,
arg.Model,
arg.Client,
arg.SessionID,
+45 -17
View File
@@ -892,14 +892,19 @@ WHERE
WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text
ELSE true
END
-- Filter provider_name
AND CASE
WHEN $5::text != '' THEN aibridge_interceptions.provider_name = $5::text
ELSE true
END
-- Filter model
AND CASE
WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text
WHEN $6::text != '' THEN aibridge_interceptions.model = $6::text
ELSE true
END
-- Filter client
AND CASE
WHEN $6::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $6::text
WHEN $7::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $7::text
ELSE true
END
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeInterceptions
@@ -911,6 +916,7 @@ type CountAIBridgeInterceptionsParams struct {
StartedBefore time.Time `db:"started_before" json:"started_before"`
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
Provider string `db:"provider" json:"provider"`
ProviderName string `db:"provider_name" json:"provider_name"`
Model string `db:"model" json:"model"`
Client string `db:"client" json:"client"`
}
@@ -921,6 +927,7 @@ func (q *sqlQuerier) CountAIBridgeInterceptions(ctx context.Context, arg CountAI
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.ProviderName,
arg.Model,
arg.Client,
)
@@ -956,19 +963,24 @@ WHERE
WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text
ELSE true
END
-- Filter provider_name
AND CASE
WHEN $5::text != '' THEN aibridge_interceptions.provider_name = $5::text
ELSE true
END
-- Filter model
AND CASE
WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text
WHEN $6::text != '' THEN aibridge_interceptions.model = $6::text
ELSE true
END
-- Filter client
AND CASE
WHEN $6::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $6::text
WHEN $7::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $7::text
ELSE true
END
-- Filter session_id
AND CASE
WHEN $7::text != '' THEN aibridge_interceptions.session_id = $7::text
WHEN $8::text != '' THEN aibridge_interceptions.session_id = $8::text
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions
@@ -980,6 +992,7 @@ type CountAIBridgeSessionsParams struct {
StartedBefore time.Time `db:"started_before" json:"started_before"`
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
Provider string `db:"provider" json:"provider"`
ProviderName string `db:"provider_name" json:"provider_name"`
Model string `db:"model" json:"model"`
Client string `db:"client" json:"client"`
SessionID string `db:"session_id" json:"session_id"`
@@ -991,6 +1004,7 @@ func (q *sqlQuerier) CountAIBridgeSessions(ctx context.Context, arg CountAIBridg
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.ProviderName,
arg.Model,
arg.Client,
arg.SessionID,
@@ -1611,19 +1625,24 @@ WHERE
WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text
ELSE true
END
-- Filter provider_name
AND CASE
WHEN $5::text != '' THEN aibridge_interceptions.provider_name = $5::text
ELSE true
END
-- Filter model
AND CASE
WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text
WHEN $6::text != '' THEN aibridge_interceptions.model = $6::text
ELSE true
END
-- Filter client
AND CASE
WHEN $6::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $6::text
WHEN $7::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $7::text
ELSE true
END
-- Cursor pagination
AND CASE
WHEN $7::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
WHEN $8::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
-- The pagination cursor is the last ID of the previous page.
-- The query is ordered by the started_at field, so select all
-- rows before the cursor and before the after_id UUID.
@@ -1631,8 +1650,8 @@ WHERE
-- "after_id" terminology comes from our pagination parser in
-- coderd.
(aibridge_interceptions.started_at, aibridge_interceptions.id) < (
(SELECT started_at FROM aibridge_interceptions WHERE id = $7),
$7::uuid
(SELECT started_at FROM aibridge_interceptions WHERE id = $8),
$8::uuid
)
)
ELSE true
@@ -1642,8 +1661,8 @@ WHERE
ORDER BY
aibridge_interceptions.started_at DESC,
aibridge_interceptions.id DESC
LIMIT COALESCE(NULLIF($9::integer, 0), 100)
OFFSET $8
LIMIT COALESCE(NULLIF($10::integer, 0), 100)
OFFSET $9
`
type ListAIBridgeInterceptionsParams struct {
@@ -1651,6 +1670,7 @@ type ListAIBridgeInterceptionsParams struct {
StartedBefore time.Time `db:"started_before" json:"started_before"`
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
Provider string `db:"provider" json:"provider"`
ProviderName string `db:"provider_name" json:"provider_name"`
Model string `db:"model" json:"model"`
Client string `db:"client" json:"client"`
AfterID uuid.UUID `db:"after_id" json:"after_id"`
@@ -1669,6 +1689,7 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.ProviderName,
arg.Model,
arg.Client,
arg.AfterID,
@@ -2033,19 +2054,24 @@ session_page AS (
WHEN $5::text != '' THEN ai.provider = $5::text
ELSE true
END
-- Filter provider_name
AND CASE
WHEN $6::text != '' THEN ai.provider_name = $6::text
ELSE true
END
-- Filter model
AND CASE
WHEN $6::text != '' THEN ai.model = $6::text
WHEN $7::text != '' THEN ai.model = $7::text
ELSE true
END
-- Filter client
AND CASE
WHEN $7::text != '' THEN COALESCE(ai.client, 'Unknown') = $7::text
WHEN $8::text != '' THEN COALESCE(ai.client, 'Unknown') = $8::text
ELSE true
END
-- Filter session_id
AND CASE
WHEN $8::text != '' THEN ai.session_id = $8::text
WHEN $9::text != '' THEN ai.session_id = $9::text
ELSE true
END
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions
@@ -2069,8 +2095,8 @@ session_page AS (
ORDER BY
last_active_at DESC,
ai.session_id DESC
LIMIT COALESCE(NULLIF($10::integer, 0), 100)
OFFSET $9
LIMIT COALESCE(NULLIF($11::integer, 0), 100)
OFFSET $10
)
SELECT
sp.session_id,
@@ -2137,6 +2163,7 @@ type ListAIBridgeSessionsParams struct {
StartedBefore time.Time `db:"started_before" json:"started_before"`
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
Provider string `db:"provider" json:"provider"`
ProviderName string `db:"provider_name" json:"provider_name"`
Model string `db:"model" json:"model"`
Client string `db:"client" json:"client"`
SessionID string `db:"session_id" json:"session_id"`
@@ -2179,6 +2206,7 @@ func (q *sqlQuerier) ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeS
arg.StartedBefore,
arg.InitiatorID,
arg.Provider,
arg.ProviderName,
arg.Model,
arg.Client,
arg.SessionID,
+20
View File
@@ -133,6 +133,11 @@ WHERE
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
ELSE true
END
-- Filter provider_name
AND CASE
WHEN @provider_name::text != '' THEN aibridge_interceptions.provider_name = @provider_name::text
ELSE true
END
-- Filter model
AND CASE
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
@@ -177,6 +182,11 @@ WHERE
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
ELSE true
END
-- Filter provider_name
AND CASE
WHEN @provider_name::text != '' THEN aibridge_interceptions.provider_name = @provider_name::text
ELSE true
END
-- Filter model
AND CASE
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
@@ -418,6 +428,11 @@ WHERE
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
ELSE true
END
-- Filter provider_name
AND CASE
WHEN @provider_name::text != '' THEN aibridge_interceptions.provider_name = @provider_name::text
ELSE true
END
-- Filter model
AND CASE
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
@@ -505,6 +520,11 @@ session_page AS (
WHEN @provider::text != '' THEN ai.provider = @provider::text
ELSE true
END
-- Filter provider_name
AND CASE
WHEN @provider_name::text != '' THEN ai.provider_name = @provider_name::text
ELSE true
END
-- Filter model
AND CASE
WHEN @model::text != '' THEN ai.model = @model::text
+20
View File
@@ -387,6 +387,7 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string,
parser := httpapi.NewQueryParamParser()
filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID)
filter.Provider = parser.String(values, "", "provider")
filter.ProviderName = parseAIProviderName(ctx, db, parser, values)
filter.Model = parser.String(values, "", "model")
filter.Client = parser.String(values, "", "client")
@@ -429,6 +430,7 @@ func AIBridgeSessions(ctx context.Context, db database.Store, query string, page
parser := httpapi.NewQueryParamParser()
filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID)
filter.Provider = parser.String(values, "", "provider")
filter.ProviderName = parseAIProviderName(ctx, db, parser, values)
filter.Model = parser.String(values, "", "model")
filter.Client = parser.String(values, "", "client")
filter.SessionID = parser.String(values, "", "session_id")
@@ -700,6 +702,24 @@ func parseOrganization(ctx context.Context, db database.Store, parser *httpapi.Q
})
}
// parseAIProviderName resolves a "provider_name" filter param against
// ai_providers.name. Unknown names produce a validation error so typos
// surface immediately rather than returning a silently-empty result set.
func parseAIProviderName(ctx context.Context, db database.Store, parser *httpapi.QueryParamParser, vals url.Values) string {
name := parser.String(vals, "", "provider_name")
if name == "" {
return ""
}
if _, err := db.GetAIProviderByName(ctx, name); err != nil {
parser.Errors = append(parser.Errors, codersdk.ValidationError{
Field: "provider_name",
Detail: `Query param "provider_name" has invalid value: provider not found or unauthorized`,
})
return ""
}
return name
}
func parseUser(ctx context.Context, db database.Store, parser *httpapi.QueryParamParser, vals url.Values, queryParam string, actorID uuid.UUID) uuid.UUID {
return httpapi.ParseCustom(parser, vals, uuid.Nil, queryParam, func(v string) (uuid.UUID, error) {
if v == "" {