mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
7b903cad73
## Problem

Centralized requests recorded *the first available key from the pool at `CreateInterceptor` time* as `credential_hint`, so the interception could be persisted in the database with a hint that didn't match the key that actually served the request.

The fix is to store, at end-of-interception, the hint of the key that succeeded, or of the last attempted key if all keys are unavailable.

## Changes

- Add `Key.Hint()` and update `credential_hint` on every failover attempt so it reflects the actually-used key.
- Stop pre-populating `credential_hint` at `CreateInterceptor`. Centralized starts empty and is updated by the key failover loop.
- Persist the final hint via `RecordInterceptionEnded`; SQL updates `credential_hint` only when `credential_kind = 'centralized'` so BYOK keeps its start-time value.
- Log the actually-used hint on interception end/failure; the start log uses a `<keypool-pending>` placeholder for centralized.

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
160 lines
5.0 KiB
Go
160 lines
5.0 KiB
Go
package aibridged
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"golang.org/x/xerrors"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/coder/coder/v2/aibridge"
|
|
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
|
"github.com/coder/coder/v2/coderd/util/ptr"
|
|
)
|
|
|
|
// Compile-time check that *recorderTranslation implements aibridge.Recorder.
var _ aibridge.Recorder = (*recorderTranslation)(nil)
|
|
|
|
// recorderTranslation satisfies the aibridge.Recorder interface and translates calls into dRPC calls to aibridgedserver.
|
|
type recorderTranslation struct {
|
|
apiKeyID string
|
|
client proto.DRPCRecorderClient
|
|
}
|
|
|
|
func (t *recorderTranslation) RecordInterception(ctx context.Context, req *aibridge.InterceptionRecord) error {
|
|
_, err := t.client.RecordInterception(ctx, &proto.RecordInterceptionRequest{
|
|
Id: req.ID,
|
|
ApiKeyId: t.apiKeyID,
|
|
InitiatorId: req.InitiatorID,
|
|
Provider: req.Provider,
|
|
ProviderName: req.ProviderName,
|
|
Model: req.Model,
|
|
UserAgent: req.UserAgent,
|
|
Client: req.Client,
|
|
ClientSessionId: req.ClientSessionID,
|
|
Metadata: marshalForProto(req.Metadata),
|
|
StartedAt: timestamppb.New(req.StartedAt),
|
|
CorrelatingToolCallId: req.CorrelatingToolCallID,
|
|
CredentialKind: req.CredentialKind,
|
|
CredentialHint: req.CredentialHint,
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (t *recorderTranslation) RecordInterceptionEnded(ctx context.Context, req *aibridge.InterceptionRecordEnded) error {
|
|
_, err := t.client.RecordInterceptionEnded(ctx, &proto.RecordInterceptionEndedRequest{
|
|
Id: req.ID,
|
|
EndedAt: timestamppb.New(req.EndedAt),
|
|
CredentialHint: req.CredentialHint,
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (t *recorderTranslation) RecordPromptUsage(ctx context.Context, req *aibridge.PromptUsageRecord) error {
|
|
_, err := t.client.RecordPromptUsage(ctx, &proto.RecordPromptUsageRequest{
|
|
InterceptionId: req.InterceptionID,
|
|
MsgId: req.MsgID,
|
|
Prompt: req.Prompt,
|
|
Metadata: marshalForProto(req.Metadata),
|
|
CreatedAt: timestamppb.New(req.CreatedAt),
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (t *recorderTranslation) RecordTokenUsage(ctx context.Context, req *aibridge.TokenUsageRecord) error {
|
|
merged := req.Metadata
|
|
if merged == nil {
|
|
merged = aibridge.Metadata{}
|
|
}
|
|
|
|
// Merge remaining extra token types into metadata.
|
|
for k, v := range req.ExtraTokenTypes {
|
|
merged[k] = v
|
|
}
|
|
|
|
_, err := t.client.RecordTokenUsage(ctx, &proto.RecordTokenUsageRequest{
|
|
InterceptionId: req.InterceptionID,
|
|
MsgId: req.MsgID,
|
|
InputTokens: req.Input,
|
|
OutputTokens: req.Output,
|
|
CacheReadInputTokens: req.CacheReadInputTokens,
|
|
CacheWriteInputTokens: req.CacheWriteInputTokens,
|
|
Metadata: marshalForProto(merged),
|
|
CreatedAt: timestamppb.New(req.CreatedAt),
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (t *recorderTranslation) RecordToolUsage(ctx context.Context, req *aibridge.ToolUsageRecord) error {
|
|
serialized, err := json.Marshal(req.Args)
|
|
if err != nil {
|
|
return xerrors.Errorf("serialize tool %q args: %w", req.Tool, err)
|
|
}
|
|
|
|
var invErr *string
|
|
if req.InvocationError != nil {
|
|
invErr = ptr.Ref(req.InvocationError.Error())
|
|
}
|
|
|
|
_, err = t.client.RecordToolUsage(ctx, &proto.RecordToolUsageRequest{
|
|
InterceptionId: req.InterceptionID,
|
|
MsgId: req.MsgID,
|
|
ToolCallId: req.ToolCallID,
|
|
ServerUrl: req.ServerURL,
|
|
Tool: req.Tool,
|
|
Input: string(serialized),
|
|
Injected: req.Injected,
|
|
InvocationError: invErr,
|
|
Metadata: marshalForProto(req.Metadata),
|
|
CreatedAt: timestamppb.New(req.CreatedAt),
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (t *recorderTranslation) RecordModelThought(ctx context.Context, req *aibridge.ModelThoughtRecord) error {
|
|
_, err := t.client.RecordModelThought(ctx, &proto.RecordModelThoughtRequest{
|
|
InterceptionId: req.InterceptionID,
|
|
Content: req.Content,
|
|
Metadata: marshalForProto(req.Metadata),
|
|
CreatedAt: timestamppb.New(req.CreatedAt),
|
|
})
|
|
return err
|
|
}
|
|
|
|
// marshalForProto will attempt to convert from aibridge.Metadata into a proto-friendly map[string]*anypb.Any.
|
|
// If any marshaling fails, rather return a map with the error details since we don't want to fail Record* funcs if metadata can't encode,
|
|
// since it's, well, metadata.
|
|
func marshalForProto(in aibridge.Metadata) map[string]*anypb.Any {
|
|
out := make(map[string]*anypb.Any, len(in))
|
|
if len(in) == 0 {
|
|
return out
|
|
}
|
|
|
|
// Instead of returning error, just encode error into metadata.
|
|
encodeErr := func(err error) map[string]*anypb.Any {
|
|
errVal, _ := anypb.New(structpb.NewStringValue(err.Error()))
|
|
mdVal, _ := anypb.New(structpb.NewStringValue(fmt.Sprintf("%+v", in)))
|
|
return map[string]*anypb.Any{
|
|
"error": errVal,
|
|
"metadata": mdVal,
|
|
}
|
|
}
|
|
|
|
for k, v := range in {
|
|
sv, err := structpb.NewValue(v)
|
|
if err != nil {
|
|
return encodeErr(err)
|
|
}
|
|
|
|
av, err := anypb.New(sv)
|
|
if err != nil {
|
|
return encodeErr(err)
|
|
}
|
|
|
|
out[k] = av
|
|
}
|
|
return out
|
|
}
|