mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: track credential hint across key failover attempts in aibridge (#25735)
## Problem Centralized requests recorded *the first available key from the pool at `CreateInterceptor` time* as `credential_hint`, so the interception could be persisted in the database with a hint that didn't match the key that actually served the request. The fix consists in storing, at end-of-interception, the hint of the key that succeeded, or the last attempted key if all keys are unavailable. ## Changes - Add `Key.Hint()` and update `credential_hint` on every failover attempt so it reflects the actually-used key. - Stop pre-populating `credential_hint` at `CreateInterceptor`. Centralized starts empty and is updated by the key failover loop. - Persist the final hint via `RecordInterceptionEnded`; SQL updates `credential_hint` only when `credential_kind = 'centralized'` so BYOK keeps its start-time value. - Log the actually-used hint on interception end/failure; start log uses a `<keypool-pending>` placeholder for centralized. > [!NOTE] > Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
This commit is contained in:
@@ -2002,8 +2002,9 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
})
|
||||
if endedAt != nil {
|
||||
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: interception.ID,
|
||||
EndedAt: *endedAt,
|
||||
ID: interception.ID,
|
||||
EndedAt: *endedAt,
|
||||
CredentialHint: takeFirst(seed.CredentialHint, ""),
|
||||
})
|
||||
require.NoError(t, err, "insert aibridge interception")
|
||||
}
|
||||
|
||||
@@ -9921,8 +9921,9 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
got, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: uuid.New(),
|
||||
EndedAt: time.Now(),
|
||||
ID: uuid.New(),
|
||||
EndedAt: time.Now(),
|
||||
CredentialHint: "sk-a...efgh",
|
||||
})
|
||||
require.ErrorContains(t, err, "no rows in result set")
|
||||
require.EqualValues(t, database.AIBridgeInterception{}, got)
|
||||
@@ -9957,18 +9958,21 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
||||
endedAt := time.Now()
|
||||
// Mark first interception as done
|
||||
updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: intc0.ID,
|
||||
EndedAt: endedAt,
|
||||
ID: intc0.ID,
|
||||
EndedAt: endedAt,
|
||||
CredentialHint: "sk-a...efgh",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, updated.ID, intc0.ID)
|
||||
require.True(t, updated.EndedAt.Valid)
|
||||
require.WithinDuration(t, endedAt, updated.EndedAt.Time, 5*time.Second)
|
||||
require.Equal(t, "sk-a...efgh", updated.CredentialHint)
|
||||
|
||||
// Updating first interception again should fail
|
||||
updated, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: intc0.ID,
|
||||
EndedAt: endedAt.Add(time.Hour),
|
||||
ID: intc0.ID,
|
||||
EndedAt: endedAt.Add(time.Hour),
|
||||
CredentialHint: "sk-a...efgh",
|
||||
})
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
@@ -9979,6 +9983,52 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
||||
require.False(t, got.EndedAt.Valid)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CentralizedHintUpdated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
intc, err := db.InsertAIBridgeInterception(ctx, database.InsertAIBridgeInterceptionParams{
|
||||
ID: uuid.New(),
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
CredentialKind: database.CredentialKindCentralized,
|
||||
CredentialHint: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: intc.ID,
|
||||
EndedAt: time.Now(),
|
||||
CredentialHint: "sk-a...efgh",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "sk-a...efgh", updated.CredentialHint)
|
||||
})
|
||||
|
||||
t.Run("BYOKHintPreserved", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
intc, err := db.InsertAIBridgeInterception(ctx, database.InsertAIBridgeInterceptionParams{
|
||||
ID: uuid.New(),
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
CredentialKind: database.CredentialKindByok,
|
||||
CredentialHint: "sk-u...byok",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: intc.ID,
|
||||
EndedAt: time.Now(),
|
||||
CredentialHint: "sk-a...efgh",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "sk-u...byok", updated.CredentialHint)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
|
||||
Generated
+13
-5
@@ -2389,20 +2389,28 @@ func (q *sqlQuerier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Contex
|
||||
|
||||
const updateAIBridgeInterceptionEnded = `-- name: UpdateAIBridgeInterceptionEnded :one
|
||||
UPDATE aibridge_interceptions
|
||||
SET ended_at = $1::timestamptz
|
||||
SET ended_at = $1::timestamptz,
|
||||
-- BYOK records its hint at the start of the interception.
|
||||
-- Centralized uses key failover, so its hint is only known
|
||||
-- at end-of-interception.
|
||||
credential_hint = CASE
|
||||
WHEN credential_kind = 'centralized' THEN $2::text
|
||||
ELSE credential_hint
|
||||
END
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
id = $3::uuid
|
||||
AND ended_at IS NULL
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id, provider_name, credential_kind, credential_hint
|
||||
`
|
||||
|
||||
type UpdateAIBridgeInterceptionEndedParams struct {
|
||||
EndedAt time.Time `db:"ended_at" json:"ended_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
EndedAt time.Time `db:"ended_at" json:"ended_at"`
|
||||
CredentialHint string `db:"credential_hint" json:"credential_hint"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateAIBridgeInterceptionEnded, arg.EndedAt, arg.ID)
|
||||
row := q.db.QueryRowContext(ctx, updateAIBridgeInterceptionEnded, arg.EndedAt, arg.CredentialHint, arg.ID)
|
||||
var i AIBridgeInterception
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
|
||||
@@ -8,7 +8,14 @@ RETURNING *;
|
||||
|
||||
-- name: UpdateAIBridgeInterceptionEnded :one
|
||||
UPDATE aibridge_interceptions
|
||||
SET ended_at = @ended_at::timestamptz
|
||||
SET ended_at = @ended_at::timestamptz,
|
||||
-- BYOK records its hint at the start of the interception.
|
||||
-- Centralized uses key failover, so its hint is only known
|
||||
-- at end-of-interception.
|
||||
credential_hint = CASE
|
||||
WHEN credential_kind = 'centralized' THEN @credential_hint::text
|
||||
ELSE credential_hint
|
||||
END
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
AND ended_at IS NULL
|
||||
|
||||
Reference in New Issue
Block a user