diff --git a/scaletest/chat/provider.go b/scaletest/chat/provider.go index cea691e5a3..8ff35f4322 100644 --- a/scaletest/chat/provider.go +++ b/scaletest/chat/provider.go @@ -12,6 +12,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/scaletest/llmmock" "github.com/coder/retry" ) @@ -23,7 +24,7 @@ const ( scaletestModelName = "scaletest-model" scaletestModelDisplayName = "Scaletest Model" scaletestModelContextLimit = int64(4096) - scaletestAIProviderProbePath = "/api/v2/aibridge/" + scaletestAIProviderName + "/v1/chat/completions" + scaletestAIProviderProbePath = "/api/v2/aibridge/" + scaletestAIProviderName + "/v1/models" scaletestAIProviderProbeWait = 15 * time.Second scaletestAIProviderProbePeriod = 100 * time.Millisecond ) @@ -67,7 +68,7 @@ func EnsureScaletestModelConfig(ctx context.Context, client *codersdk.Client, lo } if providerAction != scaletestAIProviderActionReused { - if err := waitForScaletestAIProviderRoute(ctx, client, logger); err != nil { + if err := waitForScaletestAIProviderRoute(ctx, client, logger, llmMockURL); err != nil { return uuid.Nil, xerrors.Errorf("wait for mock LLM provider reload: %w", err) } } @@ -75,7 +76,7 @@ func EnsureScaletestModelConfig(ctx context.Context, client *codersdk.Client, lo return ensureScaletestChatModelConfig(ctx, codersdk.NewExperimentalClient(client), logger, provider) } -func waitForScaletestAIProviderRoute(ctx context.Context, client *codersdk.Client, logger slog.Logger) error { +func waitForScaletestAIProviderRoute(ctx context.Context, client *codersdk.Client, logger slog.Logger, llmMockURL string) error { deploymentConfig, err := client.DeploymentConfig(ctx) if err != nil { return xerrors.Errorf("get deployment config: %w", err) @@ -86,38 +87,39 @@ func waitForScaletestAIProviderRoute(ctx context.Context, client *codersdk.Clien return nil } + expectedMarker, err := llmmock.ProbeMarkerForBaseURL(llmMockURL) + if err != nil { + return xerrors.Errorf("build mock LLM provider probe marker: %w", err) + } + logger.Info(ctx, "waiting for mock LLM provider reload", slog.F("provider_name", scaletestAIProviderName)) ctx, cancel := context.WithTimeout(ctx, scaletestAIProviderProbeWait) defer cancel() + var lastStatus int + var lastBody string for retrier := retry.New(scaletestAIProviderProbePeriod, scaletestAIProviderProbePeriod); retrier.Wait(ctx); { - res, err := client.Request(ctx, http.MethodPost, scaletestAIProviderProbePath, map[string]any{ - "model": scaletestModelName, - "messages": []map[string]string{{ - "role": "user", - "content": "ping", - }}, - "stream": false, - }, func(r *http.Request) { + res, err := client.Request(ctx, http.MethodGet, scaletestAIProviderProbePath, nil, func(r *http.Request) { r.Header.Set("Authorization", "Bearer "+client.SessionToken()) }) if err != nil { - return err + return xerrors.Errorf("probe mock LLM provider route: %w", err) } body, err := io.ReadAll(io.LimitReader(res.Body, 4096)) _ = res.Body.Close() if err != nil { return xerrors.Errorf("read probe response: %w", err) } - if res.StatusCode == http.StatusOK { + lastStatus = res.StatusCode + lastBody = strings.TrimSpace(string(body)) + if res.StatusCode == http.StatusOK && strings.Contains(lastBody, expectedMarker) { return nil } - if res.StatusCode != http.StatusNotFound || !strings.Contains(string(body), "route not supported") { - return xerrors.Errorf("status %d: %s", res.StatusCode, strings.TrimSpace(string(body))) - } - logger.Debug(ctx, "mock LLM provider route is not ready") + logger.Debug(ctx, "mock LLM provider route is not ready", + slog.F("status_code", res.StatusCode), + ) } - return xerrors.Errorf("timed out waiting for mock LLM provider route") + return xerrors.Errorf("timed out waiting for mock LLM provider route to report marker %q (last status %d: %s)", expectedMarker, lastStatus, lastBody) } func ensureScaletestChatModelConfig(ctx context.Context, client *codersdk.ExperimentalClient, logger slog.Logger, provider codersdk.AIProvider) (uuid.UUID, error) { diff --git a/scaletest/chat/provider_internal_test.go b/scaletest/chat/provider_internal_test.go new file mode 100644 index 0000000000..765124b82a --- /dev/null +++ b/scaletest/chat/provider_internal_test.go @@ -0,0 +1,192 @@ +package chat + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "sync/atomic" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/scaletest/llmmock" +) + +func TestEnsureScaletestChatModelConfig(t *testing.T) { + t.Parallel() + + ctx := context.Background() + providerID := uuid.MustParse("44444444-4444-4444-4444-444444444444") + wrongProviderID := uuid.MustParse("55555555-5555-5555-5555-555555555555") + matchingConfigID := uuid.MustParse("66666666-6666-6666-6666-666666666666") + createdConfigID := uuid.MustParse("77777777-7777-7777-7777-777777777777") + provider := codersdk.AIProvider{ID: providerID} + + t.Run("ReusesMatchingProviderAndModel", func(t *testing.T) { + t.Parallel() + + client := codersdk.NewExperimentalClient(newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/chats/model-configs": + writeJSON(t, rw, http.StatusOK, []codersdk.ChatModelConfig{ + { + ID: uuid.MustParse("88888888-8888-8888-8888-888888888888"), + AIProviderID: &wrongProviderID, + Model: scaletestModelName, + Enabled: true, + }, + { + ID: matchingConfigID, + AIProviderID: &providerID, + Model: scaletestModelName, + Enabled: true, + }, + }) + default: + t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path) + } + }))) + + gotID, err := ensureScaletestChatModelConfig(ctx, client, testLogger(), provider) + require.NoError(t, err) + require.Equal(t, matchingConfigID, gotID) + }) + + t.Run("CreatesWhenNoConfigMatches", func(t *testing.T) { + t.Parallel() + + var createReq codersdk.CreateChatModelConfigRequest + client := codersdk.NewExperimentalClient(newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/chats/model-configs": + writeJSON(t, rw, http.StatusOK, []codersdk.ChatModelConfig{}) + case r.Method == http.MethodPost && r.URL.Path == "/api/experimental/chats/model-configs": + require.NoError(t, json.NewDecoder(r.Body).Decode(&createReq)) + writeJSON(t, rw, http.StatusCreated, codersdk.ChatModelConfig{ + ID: createdConfigID, + AIProviderID: &providerID, + Model: scaletestModelName, + Enabled: true, + }) + default: + t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path) + } + }))) + + gotID, err := ensureScaletestChatModelConfig(ctx, client, testLogger(), provider) + require.NoError(t, err) + require.Equal(t, createdConfigID, gotID) + require.Equal(t, &providerID, createReq.AIProviderID) + require.Equal(t, scaletestModelName, createReq.Model) + require.Equal(t, scaletestModelDisplayName, createReq.DisplayName) + require.NotNil(t, createReq.Enabled) + require.True(t, *createReq.Enabled) + require.NotNil(t, createReq.IsDefault) + require.False(t, *createReq.IsDefault) + require.NotNil(t, createReq.ContextLimit) + require.Equal(t, scaletestModelContextLimit, *createReq.ContextLimit) + }) +} + +func TestWaitForScaletestAIProviderRoute(t *testing.T) { + t.Parallel() + + llmMockURL := "http://new.example.test/v1" + expectedMarker, err := llmmock.ProbeMarkerForBaseURL(llmMockURL) + require.NoError(t, err) + + t.Run("WaitsUntilExpectedMarker", func(t *testing.T) { + t.Parallel() + + var probeCount atomic.Int64 + client := newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/api/v2/deployment/config": + writeDeploymentConfig(t, rw, true) + case r.Method == http.MethodGet && r.URL.Path == scaletestAIProviderProbePath: + require.Equal(t, "Bearer test-session", r.Header.Get("Authorization")) + switch probeCount.Add(1) { + case 1: + writeJSON(t, rw, http.StatusOK, map[string]string{"scaletest_llm_mock": "coder-scaletest-llm-mock:old.example.test/v1"}) + case 2: + writeText(rw, http.StatusBadGateway, "upstream proxy error") + case 3: + writeText(rw, http.StatusNotFound, "route not supported") + default: + writeJSON(t, rw, http.StatusOK, map[string]string{"scaletest_llm_mock": expectedMarker}) + } + default: + t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path) + } + })) + + err := waitForScaletestAIProviderRoute(context.Background(), client, testLogger(), llmMockURL) + require.NoError(t, err) + require.Equal(t, int64(4), probeCount.Load()) + }) + + t.Run("SkipsWhenAIGatewayRoutingDisabled", func(t *testing.T) { + t.Parallel() + + client := newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/api/v2/deployment/config": + writeDeploymentConfig(t, rw, false) + case r.URL.Path == scaletestAIProviderProbePath: + t.Fatal("probe route should not be called when AI Gateway routing is disabled") + default: + t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path) + } + })) + + err := waitForScaletestAIProviderRoute(context.Background(), client, testLogger(), llmMockURL) + require.NoError(t, err) + }) +} + +func newProviderTestClient(t *testing.T, handler http.Handler) *codersdk.Client { + t.Helper() + + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + serverURL, err := url.Parse(srv.URL) + require.NoError(t, err) + client := codersdk.New(serverURL) + client.SetSessionToken("test-session") + return client +} + +func writeDeploymentConfig(t *testing.T, rw http.ResponseWriter, enabled bool) { + t.Helper() + + writeJSON(t, rw, http.StatusOK, map[string]any{ + "config": map[string]any{ + "ai": map[string]any{ + "bridge": map[string]any{ + "enabled": enabled, + }, + "chat": map[string]any{ + "ai_gateway_routing_enabled": enabled, + }, + }, + }, + }) +} + +func writeJSON(t *testing.T, rw http.ResponseWriter, status int, body any) { + t.Helper() + + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(status) + require.NoError(t, json.NewEncoder(rw).Encode(body)) +} + +func writeText(rw http.ResponseWriter, status int, body string) { + rw.Header().Set("Content-Type", "text/plain") + rw.WriteHeader(status) + _, _ = rw.Write([]byte(body)) +} diff --git a/scaletest/llmmock/server.go b/scaletest/llmmock/server.go index 8c9bdfe3c9..2d01a41d75 100644 --- a/scaletest/llmmock/server.go +++ b/scaletest/llmmock/server.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/http" + "net/url" "strings" "time" @@ -50,6 +51,24 @@ type Config struct { TraceEnable bool } +const probeMarkerPrefix = "coder-scaletest-llm-mock:" + +// ProbeMarkerForBaseURL returns the readiness marker emitted by the mock for a provider base URL. +func ProbeMarkerForBaseURL(rawURL string) (string, error) { + parsed, err := url.Parse(rawURL) + if err != nil { + return "", xerrors.Errorf("parse LLM mock URL: %w", err) + } + if parsed.Host == "" { + return "", xerrors.Errorf("LLM mock URL %q is missing a host", rawURL) + } + return probeMarker(parsed.Host, parsed.Path), nil +} + +func probeMarker(host string, basePath string) string { + return probeMarkerPrefix + host + strings.TrimRight(basePath, "/") +} + type llmRequest struct { Model string `json:"model"` Stream bool `json:"stream,omitempty"` @@ -172,6 +191,7 @@ func (s *Server) APIAddress() string { func (s *Server) startAPIServer(ctx context.Context) error { mux := http.NewServeMux() + mux.HandleFunc("GET /v1/models", s.handleOpenAIModels) mux.HandleFunc("POST /v1/chat/completions", s.handleOpenAI) mux.HandleFunc("POST /v1/responses", s.handleResponses) mux.HandleFunc("POST /v1/messages", s.handleAnthropic) @@ -201,6 +221,32 @@ func (s *Server) startAPIServer(ctx context.Context) error { return nil } +func (s *Server) handleOpenAIModels(w http.ResponseWriter, r *http.Request) { + pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), func(ctx context.Context) { + s.handleOpenAIModelsWithLabels(w, r.WithContext(ctx)) + }) +} + +func (s *Server) handleOpenAIModelsWithLabels(w http.ResponseWriter, r *http.Request) { + basePath := strings.TrimSuffix(r.URL.Path, "/") + basePath = strings.TrimSuffix(basePath, "/models") + resp := struct { + Object string `json:"object"` + Data []struct{} `json:"data"` + ScaletestLLMMock string `json:"scaletest_llm_mock"` + }{ + Object: "list", + Data: []struct{}{}, + ScaletestLLMMock: probeMarker(r.Host, basePath), + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(resp); err != nil { + s.logger.Error(r.Context(), "failed to write OpenAI models response", slog.Error(err)) + } +} + func (s *Server) handleOpenAI(w http.ResponseWriter, r *http.Request) { pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), func(ctx context.Context) { s.handleOpenAIWithLabels(w, r.WithContext(ctx)) diff --git a/scaletest/llmmock/server_test.go b/scaletest/llmmock/server_test.go new file mode 100644 index 0000000000..5f6fb9d857 --- /dev/null +++ b/scaletest/llmmock/server_test.go @@ -0,0 +1,42 @@ +package llmmock_test + +import ( + "context" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/sloghuman" + "github.com/coder/coder/v2/scaletest/llmmock" +) + +func TestServerOpenAIModelsReportsProbeMarker(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server := new(llmmock.Server) + err := server.Start(ctx, llmmock.Config{ + Address: "127.0.0.1:0", + Logger: slog.Make(sloghuman.Sink(io.Discard)).Leveled(slog.LevelDebug), + }) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, server.Stop()) + }) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.APIAddress()+"/v1/models", nil) + require.NoError(t, err) + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + expectedMarker, err := llmmock.ProbeMarkerForBaseURL(server.APIAddress() + "/v1") + require.NoError(t, err) + require.Contains(t, string(body), expectedMarker) +}