Files
coder/coderd/x/chatd/model_routing_internal_test.go
T

750 lines
28 KiB
Go

package chatd
import (
"database/sql"
"fmt"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
)
// aibridgeTestFactory is a test double whose TransportFor method records the
// arguments of its most recent call and hands back a preconfigured result.
type aibridgeTestFactory struct {
// providerName is the provider name passed to the latest TransportFor call.
providerName string
// source is the source passed to the latest TransportFor call.
source aibridge.Source
// err, when non-nil, is returned by TransportFor instead of rt.
err error
// rt is the round tripper returned by TransportFor when err is nil.
rt http.RoundTripper
}
// TransportFor stores providerName and source on the factory so tests can
// inspect them afterwards, then returns the configured error if one is set and
// the configured round tripper otherwise.
func (f *aibridgeTestFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) {
	f.providerName, f.source = providerName, source
	if f.err == nil {
		return f.rt, nil
	}
	return nil, f.err
}
// roundTripFunc adapts a plain function to the http.RoundTripper interface.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip calls the wrapped function with req and returns its results
// unchanged.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := f(req)
	return resp, err
}
// aibridgeTestFactoryPointer allocates a new atomic pointer, stores the address
// of factory in it, and returns the pointer.
func aibridgeTestFactoryPointer(factory aibridge.TransportFactory) *atomic.Pointer[aibridge.TransportFactory] {
	holder := new(atomic.Pointer[aibridge.TransportFactory])
	holder.Store(&factory)
	return holder
}
// aibridgeTestAIProvider builds an enabled database.AIProvider with the given
// ID, name, and type; every other field keeps its zero value.
func aibridgeTestAIProvider(providerID uuid.UUID, providerName string, providerType database.AIProviderType) database.AIProvider {
	var provider database.AIProvider
	provider.ID = providerID
	provider.Name = providerName
	provider.Type = providerType
	provider.Enabled = true
	return provider
}
// aibridgeTestRoute builds an AI gateway route for aiProvider, passing the
// provider's Type converted to a string and a zero-value provider auth.
func aibridgeTestRoute(aiProvider database.AIProvider) resolvedModelRoute {
	var noAuth aiGatewayProviderAuth
	providerType := string(aiProvider.Type)
	return newAIGatewayModelRoute(aiProvider, providerType, noAuth)
}
// aibridgeTestRequest returns a modelClientRequest for chat and model that
// uses the chatprovider user agent.
func aibridgeTestRequest(chat database.Chat, model string) modelClientRequest {
	req := modelClientRequest{UserAgent: chatprovider.UserAgent()}
	req.Chat = chat
	req.ModelName = model
	return req
}
// TestAIBridgeProviderFormatMapping verifies that fantasyConfigForAIBridge maps
// each AI provider type to the expected provider hint and base URL, and that
// the resulting keys report the placeholder API key.
func TestAIBridgeProviderFormatMapping(t *testing.T) {
	t.Parallel()

	type mappingCase struct {
		name         string
		providerType database.AIProviderType
		wantProvider string
		wantBaseURL  string
	}
	cases := []mappingCase{
		{
			name:         "OpenAI",
			providerType: database.AiProviderTypeOpenai,
			wantProvider: "openai",
			wantBaseURL:  "http://coder-aibridge/v1",
		},
		{
			name:         "Anthropic",
			providerType: database.AiProviderTypeAnthropic,
			wantProvider: "anthropic",
			wantBaseURL:  "http://coder-aibridge",
		},
		{
			name:         "Bedrock",
			providerType: database.AiProviderTypeBedrock,
			wantProvider: "anthropic",
			wantBaseURL:  "http://coder-aibridge",
		},
		{
			name:         "Google",
			providerType: database.AiProviderTypeGoogle,
			wantProvider: "openai-compat",
			wantBaseURL:  "http://coder-aibridge/v1",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			cfg := fantasyConfigForAIBridge(tc.providerType)
			require.Equal(t, tc.wantProvider, cfg.ProviderHint)

			// Key and base URL lookups are keyed by the resolved hint.
			gotBaseURL := cfg.Keys.BaseURL(cfg.ProviderHint)
			require.Equal(t, tc.wantBaseURL, gotBaseURL)
			gotAPIKey := cfg.Keys.APIKey(cfg.ProviderHint)
			require.Equal(t, aibridgePlaceholderAPIKey, gotAPIKey)
		})
	}
}
func TestResolveModelRouteForConfigPreservesBaseURL(t *testing.T) {
t.Parallel()
ctx := t.Context()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
providerID := uuid.New()
baseURL := "https://openai.example.com/v1"
db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(database.AIProvider{
ID: providerID,
Type: database.AiProviderTypeOpenai,
Name: "primary-openai",
Enabled: true,
BaseUrl: baseURL,
}, nil)
db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{
ProviderID: providerID,
APIKey: "provider-key",
}}, nil)
server := &Server{db: db}
route, err := server.resolveModelRouteForConfig(ctx, ownerID, database.ChatModelConfig{
Provider: "openai",
AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
}, chatprovider.ProviderAPIKeys{})
require.NoError(t, err)
require.Equal(t, modelRouteKindDirect, route.kind)
require.Equal(t, "openai", route.direct.ProviderHint)
require.Equal(t, "provider-key", route.direct.Keys.APIKey("openai"))
require.Equal(t, baseURL, route.direct.Keys.BaseURL("openai"))
}
// TestAIGatewayProviderAuthForUser exercises aiGatewayProviderAuthForUser for
// a stored user key in the OpenAI and Anthropic request formats, for a missing
// user key, and for a server with BYOK disabled.
func TestAIGatewayProviderAuthForUser(t *testing.T) {
	t.Parallel()

	ctx := t.Context()
	ownerID := uuid.New()
	providerID := uuid.New()
	provider := database.AIProvider{ID: providerID, Type: database.AiProviderTypeOpenai, Enabled: true}
	lookup := database.GetUserAIProviderKeyByProviderIDParams{
		UserID:       ownerID,
		AIProviderID: providerID,
	}

	// byokServer returns a BYOK-enabled server whose mock store answers the
	// user key lookup for (ownerID, providerID) with key and lookupErr.
	byokServer := func(t *testing.T, key database.UserAiProviderKey, lookupErr error) *Server {
		t.Helper()
		store := dbmock.NewMockStore(gomock.NewController(t))
		store.EXPECT().GetUserAIProviderKeyByProviderID(gomock.Any(), lookup).Return(key, lookupErr)
		return &Server{db: store, allowBYOK: true}
	}

	t.Run("OpenAIUserKey", func(t *testing.T) {
		t.Parallel()
		srv := byokServer(t, database.UserAiProviderKey{APIKey: "sk-user"}, nil)

		auth, err := srv.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI)
		require.NoError(t, err)
		require.Equal(t, "Bearer sk-user", auth.Headers["Authorization"])
		require.Empty(t, auth.Headers["X-Api-Key"])
	})

	t.Run("AnthropicUserKey", func(t *testing.T) {
		t.Parallel()
		srv := byokServer(t, database.UserAiProviderKey{APIKey: "sk-user"}, nil)

		auth, err := srv.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatAnthropic)
		require.NoError(t, err)
		require.Equal(t, "sk-user", auth.Headers["X-Api-Key"])
		require.Empty(t, auth.Headers["Authorization"])
	})

	t.Run("NoUserKey", func(t *testing.T) {
		t.Parallel()
		// A missing row is reported by the store as sql.ErrNoRows.
		srv := byokServer(t, database.UserAiProviderKey{}, sql.ErrNoRows)

		auth, err := srv.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI)
		require.NoError(t, err)
		require.Empty(t, auth.Headers)
	})

	t.Run("BYOKDisabled", func(t *testing.T) {
		t.Parallel()
		// No expectations are registered on this store.
		store := dbmock.NewMockStore(gomock.NewController(t))
		srv := &Server{db: store, allowBYOK: false}

		auth, err := srv.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI)
		require.NoError(t, err)
		require.Empty(t, auth.Headers)
	})
}
// TestAIGatewayProviderAuthRedactsFormatting formats an aiGatewayProviderAuth
// holding secret headers with fmt.Sprint, %+v, and %#v, and asserts that no
// rendering contains the secret and that each contains "redacted".
func TestAIGatewayProviderAuthRedactsFormatting(t *testing.T) {
	t.Parallel()

	secretHeaders := map[string]string{
		"Authorization": "Bearer sk-user",
		"X-Api-Key":     "sk-user",
	}
	auth := aiGatewayProviderAuth{Headers: secretHeaders}

	renderings := []string{
		fmt.Sprint(auth),
		fmt.Sprintf("%+v", auth),
		fmt.Sprintf("%#v", auth),
	}
	for i := range renderings {
		rendered := renderings[i]
		require.NotContains(t, rendered, "sk-user")
		require.NotContains(t, rendered, "Bearer sk-user")
		require.Contains(t, rendered, "redacted")
	}
}
// TestResolveModelRouteForConfigAIGatewayProviderAuth checks the provider auth
// attached to AI gateway routes by resolveModelRouteForConfig: a stored user
// key becomes a bearer Authorization header when BYOK is allowed, and no
// headers are attached when BYOK is disabled.
func TestResolveModelRouteForConfigAIGatewayProviderAuth(t *testing.T) {
	t.Parallel()

	ctx := t.Context()
	ownerID, providerID := uuid.New(), uuid.New()
	provider := database.AIProvider{
		ID:      providerID,
		Type:    database.AiProviderTypeOpenai,
		Name:    "primary-openai",
		Enabled: true,
	}
	modelConfig := database.ChatModelConfig{
		ID:           uuid.New(),
		Model:        "gpt-4",
		Provider:     "openai",
		AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
	}

	t.Run("UserKey", func(t *testing.T) {
		t.Parallel()

		store := dbmock.NewMockStore(gomock.NewController(t))
		store.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil)
		userKeyLookup := database.GetUserAIProviderKeyByProviderIDParams{
			UserID:       ownerID,
			AIProviderID: providerID,
		}
		store.EXPECT().
			GetUserAIProviderKeyByProviderID(gomock.Any(), userKeyLookup).
			Return(database.UserAiProviderKey{APIKey: "sk-user"}, nil)

		srv := &Server{db: store, aiGatewayRoutingEnabled: true, allowBYOK: true}
		route, err := srv.resolveModelRouteForConfig(ctx, ownerID, modelConfig, chatprovider.ProviderAPIKeys{})
		require.NoError(t, err)
		require.Equal(t, modelRouteKindAIGateway, route.kind)
		headers := route.aiGateway.ProviderAuth.Headers
		require.Equal(t, "Bearer sk-user", headers["Authorization"])
	})

	t.Run("CentralProviderCredentialsNotForwarded", func(t *testing.T) {
		t.Parallel()

		// Only the provider lookup is expected on this store.
		store := dbmock.NewMockStore(gomock.NewController(t))
		store.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil)

		srv := &Server{db: store, aiGatewayRoutingEnabled: true, allowBYOK: false}
		route, err := srv.resolveModelRouteForConfig(ctx, ownerID, modelConfig, chatprovider.ProviderAPIKeys{})
		require.NoError(t, err)
		require.Equal(t, modelRouteKindAIGateway, route.kind)
		require.Empty(t, route.aiGateway.ProviderAuth.Headers)
	})
}
// TestAIGatewayModelForwardsProviderAuth builds models over an AI gateway route
// with a capturing transport and asserts which auth headers, Coder token
// header, delegated API key ID, and URL path reach the transport.
func TestAIGatewayModelForwardsProviderAuth(t *testing.T) {
	t.Parallel()

	// seenRequest is the subset of an outgoing request the subtests inspect.
	type seenRequest struct {
		authorization string
		xAPIKey       string
		coderToken    string
		apiKeyID      string
		path          string
	}

	// Canned success payloads returned by the capturing transport.
	const (
		openAIBody    = `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`
		anthropicBody = `{"id":"msg_test","type":"message","role":"assistant","model":"claude-haiku-4-5","content":[{"type":"text","text":"hello"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":1,"output_tokens":1}}`
	)

	// newServer returns a gateway-enabled server whose transport records each
	// request on seen and replies with a canned 200 response, together with
	// an AI gateway route for provider carrying auth.
	newServer := func(t *testing.T, provider database.AIProvider, auth aiGatewayProviderAuth, seen chan seenRequest) (*Server, resolvedModelRoute) {
		t.Helper()
		transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
			apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
			seen <- seenRequest{
				authorization: req.Header.Get("Authorization"),
				xAPIKey:       req.Header.Get("X-Api-Key"),
				coderToken:    req.Header.Get(aibridge.HeaderCoderToken),
				apiKeyID:      apiKeyID,
				path:          req.URL.Path,
			}

			payload := openAIBody
			if provider.Type == database.AiProviderTypeAnthropic {
				payload = anthropicBody
			}
			resp := &http.Response{
				StatusCode: http.StatusOK,
				Header:     http.Header{"Content-Type": []string{"application/json"}},
				Body:       io.NopCloser(strings.NewReader(payload)),
				Request:    req,
			}
			return resp, nil
		})
		srv := &Server{
			aiGatewayRoutingEnabled:  true,
			aibridgeTransportFactory: aibridgeTestFactoryPointer(&aibridgeTestFactory{rt: transport}),
		}
		return srv, newAIGatewayModelRoute(provider, string(provider.Type), auth)
	}

	// helloCall builds a fresh single-message user prompt.
	helloCall := func() fantasy.Call {
		return fantasy.Call{Prompt: []fantasy.Message{{
			Role:    fantasy.MessageRoleUser,
			Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
		}}}
	}
	// requestFor builds a model request for a chat with random IDs.
	requestFor := func(modelName string) modelClientRequest {
		return aibridgeTestRequest(database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, modelName)
	}

	t.Run("OpenAI", func(t *testing.T) {
		t.Parallel()

		seen := make(chan seenRequest, 1)
		provider := aibridgeTestAIProvider(uuid.New(), "primary-openai", database.AiProviderTypeOpenai)
		userAuth := aiGatewayProviderAuth{
			Headers: map[string]string{"Authorization": "Bearer sk-user"},
		}
		srv, route := newServer(t, provider, userAuth, seen)
		apiKeyID := uuid.NewString()

		model, err := srv.newModel(t.Context(), requestFor("gpt-4"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID, RecordHTTP: true})
		require.NoError(t, err)
		_, err = model.Generate(t.Context(), helloCall())
		require.NoError(t, err)

		got := <-seen
		require.Equal(t, "Bearer sk-user", got.authorization)
		require.Empty(t, got.xAPIKey)
		require.Equal(t, aibridgeDelegatedBYOKMarker, got.coderToken)
		require.Equal(t, apiKeyID, got.apiKeyID)
		require.Equal(t, "/v1/responses", got.path)
	})

	t.Run("Anthropic", func(t *testing.T) {
		t.Parallel()

		seen := make(chan seenRequest, 1)
		provider := aibridgeTestAIProvider(uuid.New(), "primary-anthropic", database.AiProviderTypeAnthropic)
		userAuth := aiGatewayProviderAuth{
			Headers: map[string]string{"X-Api-Key": "sk-user"},
		}
		srv, route := newServer(t, provider, userAuth, seen)
		apiKeyID := uuid.NewString()

		model, err := srv.newModel(t.Context(), requestFor("claude-haiku-4-5"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID})
		require.NoError(t, err)
		_, err = model.Generate(t.Context(), helloCall())
		require.NoError(t, err)

		got := <-seen
		require.Equal(t, "sk-user", got.xAPIKey)
		require.Equal(t, aibridgeDelegatedBYOKMarker, got.coderToken)
		require.Equal(t, apiKeyID, got.apiKeyID)
		require.Equal(t, "/v1/messages", got.path)
	})

	t.Run("NoUserKeyLeavesPlaceholderForAIBridged", func(t *testing.T) {
		t.Parallel()

		seen := make(chan seenRequest, 1)
		provider := aibridgeTestAIProvider(uuid.New(), "primary-openai", database.AiProviderTypeOpenai)
		srv, route := newServer(t, provider, aiGatewayProviderAuth{}, seen)
		apiKeyID := uuid.NewString()

		model, err := srv.newModel(t.Context(), requestFor("gpt-4"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID})
		require.NoError(t, err)
		_, err = model.Generate(t.Context(), helloCall())
		require.NoError(t, err)

		// With no provider auth the placeholder key is sent as the bearer
		// token and the Coder token header stays empty.
		got := <-seen
		require.Equal(t, "Bearer "+aibridgePlaceholderAPIKey, got.authorization)
		require.Empty(t, got.xAPIKey)
		require.Empty(t, got.coderToken)
		require.Equal(t, apiKeyID, got.apiKeyID)
	})
}
// TestActiveTurnAPIKeyIDFromMessages is a table-driven test of
// activeTurnAPIKeyIDFromMessages over message histories that mix visible user
// messages, model-only user messages, and compressed summaries.
func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
	t.Parallel()

	oldKeyID := uuid.NewString()
	currentKeyID := uuid.NewString()

	// Fixture constructors. An empty keyID leaves APIKeyID invalid, which is
	// the same value as omitting the field.
	visibleUser := func(id int64, keyID string) database.ChatMessage {
		return database.ChatMessage{
			ID:         id,
			Role:       database.ChatMessageRoleUser,
			Visibility: database.ChatMessageVisibilityBoth,
			APIKeyID:   sqlNullString(keyID),
		}
	}
	visibleAssistant := func(id int64) database.ChatMessage {
		return database.ChatMessage{
			ID:         id,
			Role:       database.ChatMessageRoleAssistant,
			Visibility: database.ChatMessageVisibilityBoth,
		}
	}
	modelOnlyUser := func(id int64, keyID string) database.ChatMessage {
		return database.ChatMessage{
			ID:         id,
			Role:       database.ChatMessageRoleUser,
			Visibility: database.ChatMessageVisibilityModel,
			APIKeyID:   sqlNullString(keyID),
		}
	}
	compressedSummary := func(id int64, keyID string) database.ChatMessage {
		msg := modelOnlyUser(id, keyID)
		msg.Compressed = true
		return msg
	}

	type keyCase struct {
		name     string
		messages []database.ChatMessage
		wantKey  string
		wantOK   bool
	}
	cases := []keyCase{
		{
			name: "CurrentUserMessage",
			messages: []database.ChatMessage{
				visibleUser(1, oldKeyID),
				visibleAssistant(2),
				visibleUser(3, currentKeyID),
			},
			wantKey: currentKeyID,
			wantOK:  true,
		},
		{
			name: "MissingCurrentUserAPIKeyDoesNotFallBack",
			messages: []database.ChatMessage{
				visibleUser(1, oldKeyID),
				visibleUser(2, ""),
			},
		},
		{
			name: "SkipsUncompressedModelOnlyUserMessages",
			messages: []database.ChatMessage{
				visibleUser(1, oldKeyID),
				modelOnlyUser(2, currentKeyID),
			},
			wantKey: oldKeyID,
			wantOK:  true,
		},
		{
			name: "CompressedSummaryFallback",
			messages: []database.ChatMessage{
				compressedSummary(1, currentKeyID),
				visibleAssistant(2),
			},
			wantKey: currentKeyID,
			wantOK:  true,
		},
		{
			name: "LatestCompressedSummaryWins",
			messages: []database.ChatMessage{
				compressedSummary(1, oldKeyID),
				compressedSummary(2, currentKeyID),
				visibleAssistant(3),
			},
			wantKey: currentKeyID,
			wantOK:  true,
		},
		{
			name: "VisibleUserWinsOverCompressedSummary",
			messages: []database.ChatMessage{
				compressedSummary(1, oldKeyID),
				visibleUser(2, currentKeyID),
			},
			wantKey: currentKeyID,
			wantOK:  true,
		},
		{
			name: "MissingVisibleUserKeyDoesNotFallBackToCompressedSummary",
			messages: []database.ChatMessage{
				compressedSummary(1, oldKeyID),
				visibleUser(2, ""),
			},
		},
		{
			name: "UncompressedModelOnlyUserIgnored",
			messages: []database.ChatMessage{
				modelOnlyUser(1, currentKeyID),
			},
		},
		{
			name: "CompressedSummaryMissingKeyDoesNotFallBack",
			messages: []database.ChatMessage{
				visibleUser(1, oldKeyID),
				compressedSummary(2, ""),
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			key, ok := activeTurnAPIKeyIDFromMessages(tc.messages)
			require.Equal(t, tc.wantOK, ok)
			require.Equal(t, tc.wantKey, key)
		})
	}
}
// TestPromptMessagesForVisibleUserPreserveActiveAPIKeyID inserts four chat
// messages into a test database, loads them through
// GetChatMessagesForPromptByChatID, and asserts that
// activeTurnAPIKeyIDFromMessages returns the key of the second visible user
// message rather than the earlier key or the later model-only user key.
func TestPromptMessagesForVisibleUserPreserveActiveAPIKeyID(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	user := dbgen.User(t, db, database.User{})
	org := dbgen.Organization(t, db, database.Organization{})
	model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
	chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID})
	oldKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
	currentKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
	modelOnlyKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})

	// insert stamps the fields shared by every message in this test before
	// writing the row.
	insert := func(msg database.ChatMessage) {
		msg.ChatID = chat.ID
		msg.CreatedBy = uuid.NullUUID{UUID: user.ID, Valid: true}
		msg.ModelConfigID = uuid.NullUUID{UUID: model.ID, Valid: true}
		dbgen.ChatMessage(t, db, msg)
	}

	// Insertion order matters: old visible user, compressed system message,
	// current visible user, then a model-only user message.
	insert(database.ChatMessage{
		Role:       database.ChatMessageRoleUser,
		Visibility: database.ChatMessageVisibilityBoth,
		APIKeyID:   sqlNullString(oldKey.ID),
	})
	insert(database.ChatMessage{
		Role:       database.ChatMessageRoleSystem,
		Visibility: database.ChatMessageVisibilityModel,
		Compressed: true,
	})
	insert(database.ChatMessage{
		Role:       database.ChatMessageRoleUser,
		Visibility: database.ChatMessageVisibilityBoth,
		APIKeyID:   sqlNullString(currentKey.ID),
	})
	insert(database.ChatMessage{
		Role:       database.ChatMessageRoleUser,
		Visibility: database.ChatMessageVisibilityModel,
		APIKeyID:   sqlNullString(modelOnlyKey.ID),
	})

	promptMessages, err := db.GetChatMessagesForPromptByChatID(t.Context(), chat.ID)
	require.NoError(t, err)

	activeKey, found := activeTurnAPIKeyIDFromMessages(promptMessages)
	require.True(t, found)
	require.Equal(t, currentKey.ID, activeKey)
}
// TestPromptMessagesForCompactedChatPreserveActiveAPIKeyID inserts a visible
// user message, an assistant reply, a compressed model-only user summary, and
// a later assistant message. It asserts that the prompt query drops the
// original user message but keeps the summary and the later message, and that
// activeTurnAPIKeyIDFromMessages still yields the API key ID.
func TestPromptMessagesForCompactedChatPreserveActiveAPIKeyID(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	user := dbgen.User(t, db, database.User{})
	org := dbgen.Organization(t, db, database.Organization{})
	model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
	chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID})
	key, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})

	modelRef := uuid.NullUUID{UUID: model.ID, Valid: true}

	visibleUser := dbgen.ChatMessage(t, db, database.ChatMessage{
		ChatID:        chat.ID,
		CreatedBy:     uuid.NullUUID{UUID: user.ID, Valid: true},
		ModelConfigID: modelRef,
		Role:          database.ChatMessageRoleUser,
		Visibility:    database.ChatMessageVisibilityBoth,
		APIKeyID:      sqlNullString(key.ID),
	})
	dbgen.ChatMessage(t, db, database.ChatMessage{
		ChatID:        chat.ID,
		ModelConfigID: modelRef,
		Role:          database.ChatMessageRoleAssistant,
		Visibility:    database.ChatMessageVisibilityBoth,
	})
	compressedSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
		ChatID:        chat.ID,
		ModelConfigID: modelRef,
		Role:          database.ChatMessageRoleUser,
		Visibility:    database.ChatMessageVisibilityModel,
		Compressed:    true,
		APIKeyID:      sqlNullString(key.ID),
	})
	afterSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
		ChatID:        chat.ID,
		ModelConfigID: modelRef,
		Role:          database.ChatMessageRoleAssistant,
		Visibility:    database.ChatMessageVisibilityBoth,
	})

	promptMessages, err := db.GetChatMessagesForPromptByChatID(t.Context(), chat.ID)
	require.NoError(t, err)

	// inPrompt reports whether the prompt query returned the message with id.
	inPrompt := func(id int64) bool {
		for i := range promptMessages {
			if promptMessages[i].ID == id {
				return true
			}
		}
		return false
	}
	require.False(t, inPrompt(visibleUser.ID))
	require.True(t, inPrompt(compressedSummary.ID))
	require.True(t, inPrompt(afterSummary.ID))

	activeKey, found := activeTurnAPIKeyIDFromMessages(promptMessages)
	require.True(t, found)
	require.Equal(t, key.ID, activeKey)
}
// sqlNullString wraps value in a sql.NullString that is valid only when value
// is non-empty; the empty string yields the zero NullString.
func sqlNullString(value string) sql.NullString {
	if value == "" {
		return sql.NullString{}
	}
	return sql.NullString{String: value, Valid: true}
}
// TestAIBridgeRoutingFailClosed asserts that newModel returns an error, rather
// than a usable model, when an AI gateway route is missing something it needs:
// the transport factory, a working factory, the provider name, the active turn
// API key ID, or a concrete AI provider.
func TestAIBridgeRoutingFailClosed(t *testing.T) {
	t.Parallel()

	providerID := uuid.New()
	chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()}
	aiProvider := aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai)

	t.Run("NilFactory", func(t *testing.T) {
		t.Parallel()
		server := &Server{aiGatewayRoutingEnabled: true}
		_, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aiProvider), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()})
		require.ErrorContains(t, err, "transport factory")
	})

	t.Run("FactoryError", func(t *testing.T) {
		t.Parallel()
		factory := &aibridgeTestFactory{err: xerrors.New("boom")}
		server := &Server{
			aiGatewayRoutingEnabled:  true,
			aibridgeTransportFactory: aibridgeTestFactoryPointer(factory),
		}
		_, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aiProvider), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()})
		require.ErrorContains(t, err, "boom")
	})

	t.Run("MissingProviderName", func(t *testing.T) {
		t.Parallel()
		server := &Server{aiGatewayRoutingEnabled: true}
		missingNameProvider := aibridgeTestAIProvider(providerID, "", database.AiProviderTypeOpenai)
		_, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(missingNameProvider), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()})
		require.ErrorContains(t, err, "AI provider name")
	})

	t.Run("MissingAPIKeyID", func(t *testing.T) {
		t.Parallel()
		factory := &aibridgeTestFactory{rt: roundTripFunc(func(*http.Request) (*http.Response, error) {
			// Use t.Error, not t.Fatal: FailNow is only valid on the goroutine
			// running the test, and this callback is invoked by whatever
			// goroutine the HTTP client code happens to use.
			t.Error("transport must not be used without an API key ID")
			return nil, xerrors.New("unexpected request without an API key ID")
		})}
		server := &Server{
			aiGatewayRoutingEnabled:  true,
			aibridgeTransportFactory: aibridgeTestFactoryPointer(factory),
		}
		_, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aiProvider), modelBuildOptions{})
		require.ErrorContains(t, err, "active turn API key ID")

		// The error must already classify as a non-retryable missing-key
		// error.
		classified := chaterror.Classify(err)
		require.Equal(t, codersdk.ChatErrorKindMissingKey, classified.Kind,
			"production path must return a pre-classified missing_key error")
		require.False(t, classified.Retryable)
	})

	t.Run("StaticModel", func(t *testing.T) {
		t.Parallel()
		server := &Server{aiGatewayRoutingEnabled: true}
		// A route built from a zero-value provider has no concrete provider.
		_, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), newAIGatewayModelRoute(database.AIProvider{}, "", aiGatewayProviderAuth{}), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()})
		require.ErrorContains(t, err, "concrete AI provider")
	})
}
// TestDirectModelBuildDoesNotRequireActiveAPIKeyID asserts that newModel builds
// a model over a direct OpenAI route when the build options carry no active
// API key ID.
func TestDirectModelBuildDoesNotRequireActiveAPIKeyID(t *testing.T) {
	t.Parallel()

	var srv Server
	request := modelClientRequest{
		Chat:      database.Chat{ID: uuid.New(), OwnerID: uuid.New()},
		ModelName: "gpt-4",
		UserAgent: chatprovider.UserAgent(),
	}
	directRoute := newDirectModelRoute("openai", chatprovider.ProviderAPIKeys{OpenAI: "sk-test"})

	model, err := srv.newModel(t.Context(), request, directRoute, modelBuildOptions{})
	require.NoError(t, err)
	require.NotNil(t, model)
}
// TestAIBridgeComputerUseModelUsesRoute asserts that resolveComputerUseModel,
// given an AI gateway route, builds a model without sending any request and
// asks the transport factory for the route's provider name with the agents
// source.
func TestAIBridgeComputerUseModelUsesRoute(t *testing.T) {
	t.Parallel()

	providerID := uuid.New()
	apiKeyID := uuid.NewString()
	factory := &aibridgeTestFactory{rt: roundTripFunc(func(*http.Request) (*http.Response, error) {
		// Use t.Error, not t.Fatal: FailNow is only valid on the goroutine
		// running the test, and this callback is invoked by whatever goroutine
		// the HTTP client code happens to use.
		t.Error("computer use model construction must not send a request")
		return nil, xerrors.New("unexpected request during model construction")
	})}
	chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()}
	server := &Server{
		aiGatewayRoutingEnabled:  true,
		aibridgeTransportFactory: aibridgeTestFactoryPointer(factory),
	}

	provider := chattool.ComputerUseProviderOpenAI
	modelProvider, modelName, ok := chattool.DefaultComputerUseModel(provider)
	require.True(t, ok)

	// The incoming context carries a different delegated key than the one in
	// the build options.
	ctx := aibridge.WithDelegatedAPIKeyID(t.Context(), "context-key-must-be-ignored")
	model, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveComputerUseModel(
		ctx,
		chat,
		aibridgeTestRoute(aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai)),
		provider,
		modelProvider,
		modelName,
		modelBuildOptions{ActiveAPIKeyID: apiKeyID},
	)
	require.NoError(t, err)
	require.NotNil(t, model)
	require.False(t, debugEnabled)
	require.Equal(t, chattool.ComputerUseProviderOpenAI, resolvedProvider)
	require.Equal(t, modelName, resolvedModel)

	// The factory records the arguments of its latest TransportFor call.
	require.Equal(t, "primary-openai", factory.providerName)
	require.Equal(t, aibridge.SourceAgents, factory.source)
}
// TestAIBridgeDelegatedContextPropagation builds a model over an AI gateway
// route and asserts that the request reaching the transport carries the
// delegated API key ID from the build options in its context, and that the
// factory was asked for the route's provider with the agents source.
func TestAIBridgeDelegatedContextPropagation(t *testing.T) {
	t.Parallel()

	// seenRequest is what the capturing transport records per request.
	type seenRequest struct {
		apiKeyID string
		ok       bool
		path     string
	}

	const responseBody = `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`

	providerID := uuid.New()
	apiKeyID := uuid.NewString()
	seen := make(chan seenRequest, 1)

	transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
		delegatedID, found := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
		seen <- seenRequest{
			apiKeyID: delegatedID,
			ok:       found,
			path:     req.URL.Path,
		}
		resp := &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(responseBody)),
			Request:    req,
		}
		return resp, nil
	})
	factory := &aibridgeTestFactory{rt: transport}

	chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()}
	srv := &Server{
		aiGatewayRoutingEnabled:  true,
		aibridgeTransportFactory: aibridgeTestFactoryPointer(factory),
	}

	// The build context carries a different delegated key than the one in
	// the build options.
	buildCtx := aibridge.WithDelegatedAPIKeyID(t.Context(), "context-key-must-be-ignored")
	route := aibridgeTestRoute(aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai))
	model, err := srv.newModel(buildCtx, aibridgeTestRequest(chat, "gpt-4"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID, RecordHTTP: true})
	require.NoError(t, err)

	userPrompt := []fantasy.Message{{
		Role:    fantasy.MessageRoleUser,
		Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
	}}
	_, err = model.Generate(t.Context(), fantasy.Call{Prompt: userPrompt})
	require.NoError(t, err)

	got := <-seen
	require.Equal(t, "primary-openai", factory.providerName)
	require.Equal(t, aibridge.SourceAgents, factory.source)
	require.True(t, got.ok)
	require.Equal(t, "/v1/responses", got.path)
	require.Equal(t, apiKeyID, got.apiKeyID)
}