mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: filter interceptions and sessions by provider name (#25640)
Allows filtering sessions & interceptions by provider name, and adds a test to vaidate that provider name is immutable (at least until #25606 lands).
This commit is contained in:
@@ -327,6 +327,50 @@ func TestAIProvidersCRUD(t *testing.T) {
|
||||
require.Contains(t, sdkErr.Message, "At least one field must be provided")
|
||||
})
|
||||
|
||||
t.Run("UpdateCannotMutateName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// ai_providers.name is the stable key that aibridge_interceptions
|
||||
// snapshots into provider_name. Renames would silently desync
|
||||
// historical interceptions from their live row and break the
|
||||
// future FK backfill, so the PATCH endpoint must ignore any "name"
|
||||
// field in the payload. The SDK type intentionally has no Name
|
||||
// field; this test sends raw JSON to defend against a future
|
||||
// regression where someone adds one without thinking.
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:gocritic // Owner role is the audience for this endpoint.
|
||||
created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeOpenAI,
|
||||
Name: "stable-name",
|
||||
Enabled: true,
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := client.Request(ctx, http.MethodPatch,
|
||||
"/api/v2/ai/providers/"+created.Name,
|
||||
json.RawMessage(`{"name":"renamed","display_name":"New Display"}`),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
got, err := client.AIProvider(ctx, created.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "stable-name", got.Name, "name must not be mutable via PATCH")
|
||||
require.Equal(t, "New Display", got.DisplayName, "display_name should still update")
|
||||
|
||||
// Confirm the original name still resolves and the attempted new
|
||||
// name does not exist as a separate row.
|
||||
_, err = client.AIProvider(ctx, "renamed")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("UpdateSettingsEmptyObjectRejected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// "settings": {} cannot decode because the _type discriminator
|
||||
|
||||
Generated
+2
-2
@@ -1436,7 +1436,7 @@ const docTemplate = `{
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, started_after, started_before.",
|
||||
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, provider_name, model, started_after, started_before.",
|
||||
"name": "q",
|
||||
"in": "query"
|
||||
},
|
||||
@@ -1515,7 +1515,7 @@ const docTemplate = `{
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
|
||||
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, provider_name, model, client, session_id, started_after, started_before.",
|
||||
"name": "q",
|
||||
"in": "query"
|
||||
},
|
||||
|
||||
Generated
+2
-2
@@ -1265,7 +1265,7 @@
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, started_after, started_before.",
|
||||
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, provider_name, model, started_after, started_before.",
|
||||
"name": "q",
|
||||
"in": "query"
|
||||
},
|
||||
@@ -1336,7 +1336,7 @@
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
|
||||
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, provider_name, model, client, session_id, started_after, started_before.",
|
||||
"name": "q",
|
||||
"in": "query"
|
||||
},
|
||||
|
||||
@@ -932,6 +932,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.ProviderName,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.AfterID,
|
||||
@@ -998,6 +999,7 @@ func (q *sqlQuerier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, a
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.ProviderName,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
)
|
||||
@@ -1097,6 +1099,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg Lis
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.ProviderName,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
@@ -1161,6 +1164,7 @@ func (q *sqlQuerier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg Co
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.ProviderName,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
|
||||
@@ -892,14 +892,19 @@ WHERE
|
||||
WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider_name
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN aibridge_interceptions.provider_name = $5::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text
|
||||
WHEN $6::text != '' THEN aibridge_interceptions.model = $6::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN $6::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $6::text
|
||||
WHEN $7::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $7::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeInterceptions
|
||||
@@ -911,6 +916,7 @@ type CountAIBridgeInterceptionsParams struct {
|
||||
StartedBefore time.Time `db:"started_before" json:"started_before"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
ProviderName string `db:"provider_name" json:"provider_name"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
}
|
||||
@@ -921,6 +927,7 @@ func (q *sqlQuerier) CountAIBridgeInterceptions(ctx context.Context, arg CountAI
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.ProviderName,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
)
|
||||
@@ -956,19 +963,24 @@ WHERE
|
||||
WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider_name
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN aibridge_interceptions.provider_name = $5::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text
|
||||
WHEN $6::text != '' THEN aibridge_interceptions.model = $6::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN $6::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $6::text
|
||||
WHEN $7::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $7::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN aibridge_interceptions.session_id = $7::text
|
||||
WHEN $8::text != '' THEN aibridge_interceptions.session_id = $8::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions
|
||||
@@ -980,6 +992,7 @@ type CountAIBridgeSessionsParams struct {
|
||||
StartedBefore time.Time `db:"started_before" json:"started_before"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
ProviderName string `db:"provider_name" json:"provider_name"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
@@ -991,6 +1004,7 @@ func (q *sqlQuerier) CountAIBridgeSessions(ctx context.Context, arg CountAIBridg
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.ProviderName,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
@@ -1611,19 +1625,24 @@ WHERE
|
||||
WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider_name
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN aibridge_interceptions.provider_name = $5::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text
|
||||
WHEN $6::text != '' THEN aibridge_interceptions.model = $6::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN $6::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $6::text
|
||||
WHEN $7::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $7::text
|
||||
ELSE true
|
||||
END
|
||||
-- Cursor pagination
|
||||
AND CASE
|
||||
WHEN $7::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
|
||||
WHEN $8::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
|
||||
-- The pagination cursor is the last ID of the previous page.
|
||||
-- The query is ordered by the started_at field, so select all
|
||||
-- rows before the cursor and before the after_id UUID.
|
||||
@@ -1631,8 +1650,8 @@ WHERE
|
||||
-- "after_id" terminology comes from our pagination parser in
|
||||
-- coderd.
|
||||
(aibridge_interceptions.started_at, aibridge_interceptions.id) < (
|
||||
(SELECT started_at FROM aibridge_interceptions WHERE id = $7),
|
||||
$7::uuid
|
||||
(SELECT started_at FROM aibridge_interceptions WHERE id = $8),
|
||||
$8::uuid
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
@@ -1642,8 +1661,8 @@ WHERE
|
||||
ORDER BY
|
||||
aibridge_interceptions.started_at DESC,
|
||||
aibridge_interceptions.id DESC
|
||||
LIMIT COALESCE(NULLIF($9::integer, 0), 100)
|
||||
OFFSET $8
|
||||
LIMIT COALESCE(NULLIF($10::integer, 0), 100)
|
||||
OFFSET $9
|
||||
`
|
||||
|
||||
type ListAIBridgeInterceptionsParams struct {
|
||||
@@ -1651,6 +1670,7 @@ type ListAIBridgeInterceptionsParams struct {
|
||||
StartedBefore time.Time `db:"started_before" json:"started_before"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
ProviderName string `db:"provider_name" json:"provider_name"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
AfterID uuid.UUID `db:"after_id" json:"after_id"`
|
||||
@@ -1669,6 +1689,7 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.ProviderName,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.AfterID,
|
||||
@@ -2033,19 +2054,24 @@ session_page AS (
|
||||
WHEN $5::text != '' THEN ai.provider = $5::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider_name
|
||||
AND CASE
|
||||
WHEN $6::text != '' THEN ai.provider_name = $6::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN $6::text != '' THEN ai.model = $6::text
|
||||
WHEN $7::text != '' THEN ai.model = $7::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN COALESCE(ai.client, 'Unknown') = $7::text
|
||||
WHEN $8::text != '' THEN COALESCE(ai.client, 'Unknown') = $8::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN ai.session_id = $8::text
|
||||
WHEN $9::text != '' THEN ai.session_id = $9::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions
|
||||
@@ -2069,8 +2095,8 @@ session_page AS (
|
||||
ORDER BY
|
||||
last_active_at DESC,
|
||||
ai.session_id DESC
|
||||
LIMIT COALESCE(NULLIF($10::integer, 0), 100)
|
||||
OFFSET $9
|
||||
LIMIT COALESCE(NULLIF($11::integer, 0), 100)
|
||||
OFFSET $10
|
||||
)
|
||||
SELECT
|
||||
sp.session_id,
|
||||
@@ -2137,6 +2163,7 @@ type ListAIBridgeSessionsParams struct {
|
||||
StartedBefore time.Time `db:"started_before" json:"started_before"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
ProviderName string `db:"provider_name" json:"provider_name"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
@@ -2179,6 +2206,7 @@ func (q *sqlQuerier) ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeS
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.ProviderName,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
|
||||
@@ -133,6 +133,11 @@ WHERE
|
||||
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider_name
|
||||
AND CASE
|
||||
WHEN @provider_name::text != '' THEN aibridge_interceptions.provider_name = @provider_name::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
|
||||
@@ -177,6 +182,11 @@ WHERE
|
||||
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider_name
|
||||
AND CASE
|
||||
WHEN @provider_name::text != '' THEN aibridge_interceptions.provider_name = @provider_name::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
|
||||
@@ -418,6 +428,11 @@ WHERE
|
||||
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider_name
|
||||
AND CASE
|
||||
WHEN @provider_name::text != '' THEN aibridge_interceptions.provider_name = @provider_name::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
|
||||
@@ -505,6 +520,11 @@ session_page AS (
|
||||
WHEN @provider::text != '' THEN ai.provider = @provider::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider_name
|
||||
AND CASE
|
||||
WHEN @provider_name::text != '' THEN ai.provider_name = @provider_name::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN ai.model = @model::text
|
||||
|
||||
@@ -387,6 +387,7 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string,
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID)
|
||||
filter.Provider = parser.String(values, "", "provider")
|
||||
filter.ProviderName = parseAIProviderName(ctx, db, parser, values)
|
||||
filter.Model = parser.String(values, "", "model")
|
||||
filter.Client = parser.String(values, "", "client")
|
||||
|
||||
@@ -429,6 +430,7 @@ func AIBridgeSessions(ctx context.Context, db database.Store, query string, page
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID)
|
||||
filter.Provider = parser.String(values, "", "provider")
|
||||
filter.ProviderName = parseAIProviderName(ctx, db, parser, values)
|
||||
filter.Model = parser.String(values, "", "model")
|
||||
filter.Client = parser.String(values, "", "client")
|
||||
filter.SessionID = parser.String(values, "", "session_id")
|
||||
@@ -700,6 +702,24 @@ func parseOrganization(ctx context.Context, db database.Store, parser *httpapi.Q
|
||||
})
|
||||
}
|
||||
|
||||
// parseAIProviderName resolves a "provider_name" filter param against
|
||||
// ai_providers.name. Unknown names produce a validation error so typos
|
||||
// surface immediately rather than returning a silently-empty result set.
|
||||
func parseAIProviderName(ctx context.Context, db database.Store, parser *httpapi.QueryParamParser, vals url.Values) string {
|
||||
name := parser.String(vals, "", "provider_name")
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
if _, err := db.GetAIProviderByName(ctx, name); err != nil {
|
||||
parser.Errors = append(parser.Errors, codersdk.ValidationError{
|
||||
Field: "provider_name",
|
||||
Detail: `Query param "provider_name" has invalid value: provider not found or unauthorized`,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func parseUser(ctx context.Context, db database.Store, parser *httpapi.QueryParamParser, vals url.Values, queryParam string, actorID uuid.UUID) uuid.UUID {
|
||||
return httpapi.ParseCustom(parser, vals, uuid.Nil, queryParam, func(v string) (uuid.UUID, error) {
|
||||
if v == "" {
|
||||
|
||||
Reference in New Issue
Block a user