fix(scaletest): wait for mock AI provider reload

This commit is contained in:
Ethan Dickson
2026-06-02 13:18:13 +00:00
parent 8aac8380ce
commit 2769d5f125
4 changed files with 300 additions and 18 deletions
+20 -18
View File
@@ -12,6 +12,7 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/llmmock"
"github.com/coder/retry"
)
@@ -23,7 +24,7 @@ const (
scaletestModelName = "scaletest-model"
scaletestModelDisplayName = "Scaletest Model"
scaletestModelContextLimit = int64(4096)
scaletestAIProviderProbePath = "/api/v2/aibridge/" + scaletestAIProviderName + "/v1/chat/completions"
scaletestAIProviderProbePath = "/api/v2/aibridge/" + scaletestAIProviderName + "/v1/models"
scaletestAIProviderProbeWait = 15 * time.Second
scaletestAIProviderProbePeriod = 100 * time.Millisecond
)
@@ -67,7 +68,7 @@ func EnsureScaletestModelConfig(ctx context.Context, client *codersdk.Client, lo
}
if providerAction != scaletestAIProviderActionReused {
if err := waitForScaletestAIProviderRoute(ctx, client, logger); err != nil {
if err := waitForScaletestAIProviderRoute(ctx, client, logger, llmMockURL); err != nil {
return uuid.Nil, xerrors.Errorf("wait for mock LLM provider reload: %w", err)
}
}
@@ -75,7 +76,7 @@ func EnsureScaletestModelConfig(ctx context.Context, client *codersdk.Client, lo
return ensureScaletestChatModelConfig(ctx, codersdk.NewExperimentalClient(client), logger, provider)
}
func waitForScaletestAIProviderRoute(ctx context.Context, client *codersdk.Client, logger slog.Logger) error {
func waitForScaletestAIProviderRoute(ctx context.Context, client *codersdk.Client, logger slog.Logger, llmMockURL string) error {
deploymentConfig, err := client.DeploymentConfig(ctx)
if err != nil {
return xerrors.Errorf("get deployment config: %w", err)
@@ -86,38 +87,39 @@ func waitForScaletestAIProviderRoute(ctx context.Context, client *codersdk.Clien
return nil
}
expectedMarker, err := llmmock.ProbeMarkerForBaseURL(llmMockURL)
if err != nil {
return xerrors.Errorf("build mock LLM provider probe marker: %w", err)
}
logger.Info(ctx, "waiting for mock LLM provider reload", slog.F("provider_name", scaletestAIProviderName))
ctx, cancel := context.WithTimeout(ctx, scaletestAIProviderProbeWait)
defer cancel()
var lastStatus int
var lastBody string
for retrier := retry.New(scaletestAIProviderProbePeriod, scaletestAIProviderProbePeriod); retrier.Wait(ctx); {
res, err := client.Request(ctx, http.MethodPost, scaletestAIProviderProbePath, map[string]any{
"model": scaletestModelName,
"messages": []map[string]string{{
"role": "user",
"content": "ping",
}},
"stream": false,
}, func(r *http.Request) {
res, err := client.Request(ctx, http.MethodGet, scaletestAIProviderProbePath, nil, func(r *http.Request) {
r.Header.Set("Authorization", "Bearer "+client.SessionToken())
})
if err != nil {
return err
return xerrors.Errorf("probe mock LLM provider route: %w", err)
}
body, err := io.ReadAll(io.LimitReader(res.Body, 4096))
_ = res.Body.Close()
if err != nil {
return xerrors.Errorf("read probe response: %w", err)
}
if res.StatusCode == http.StatusOK {
lastStatus = res.StatusCode
lastBody = strings.TrimSpace(string(body))
if res.StatusCode == http.StatusOK && strings.Contains(lastBody, expectedMarker) {
return nil
}
if res.StatusCode != http.StatusNotFound || !strings.Contains(string(body), "route not supported") {
return xerrors.Errorf("status %d: %s", res.StatusCode, strings.TrimSpace(string(body)))
}
logger.Debug(ctx, "mock LLM provider route is not ready")
logger.Debug(ctx, "mock LLM provider route is not ready",
slog.F("status_code", res.StatusCode),
)
}
return xerrors.Errorf("timed out waiting for mock LLM provider route")
return xerrors.Errorf("timed out waiting for mock LLM provider route to report marker %q (last status %d: %s)", expectedMarker, lastStatus, lastBody)
}
func ensureScaletestChatModelConfig(ctx context.Context, client *codersdk.ExperimentalClient, logger slog.Logger, provider codersdk.AIProvider) (uuid.UUID, error) {
+192
View File
@@ -0,0 +1,192 @@
package chat
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/llmmock"
)
func TestEnsureScaletestChatModelConfig(t *testing.T) {
t.Parallel()
ctx := context.Background()
providerID := uuid.MustParse("44444444-4444-4444-4444-444444444444")
wrongProviderID := uuid.MustParse("55555555-5555-5555-5555-555555555555")
matchingConfigID := uuid.MustParse("66666666-6666-6666-6666-666666666666")
createdConfigID := uuid.MustParse("77777777-7777-7777-7777-777777777777")
provider := codersdk.AIProvider{ID: providerID}
t.Run("ReusesMatchingProviderAndModel", func(t *testing.T) {
t.Parallel()
client := codersdk.NewExperimentalClient(newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/chats/model-configs":
writeJSON(t, rw, http.StatusOK, []codersdk.ChatModelConfig{
{
ID: uuid.MustParse("88888888-8888-8888-8888-888888888888"),
AIProviderID: &wrongProviderID,
Model: scaletestModelName,
Enabled: true,
},
{
ID: matchingConfigID,
AIProviderID: &providerID,
Model: scaletestModelName,
Enabled: true,
},
})
default:
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
}
})))
gotID, err := ensureScaletestChatModelConfig(ctx, client, testLogger(), provider)
require.NoError(t, err)
require.Equal(t, matchingConfigID, gotID)
})
t.Run("CreatesWhenNoConfigMatches", func(t *testing.T) {
t.Parallel()
var createReq codersdk.CreateChatModelConfigRequest
client := codersdk.NewExperimentalClient(newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/chats/model-configs":
writeJSON(t, rw, http.StatusOK, []codersdk.ChatModelConfig{})
case r.Method == http.MethodPost && r.URL.Path == "/api/experimental/chats/model-configs":
require.NoError(t, json.NewDecoder(r.Body).Decode(&createReq))
writeJSON(t, rw, http.StatusCreated, codersdk.ChatModelConfig{
ID: createdConfigID,
AIProviderID: &providerID,
Model: scaletestModelName,
Enabled: true,
})
default:
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
}
})))
gotID, err := ensureScaletestChatModelConfig(ctx, client, testLogger(), provider)
require.NoError(t, err)
require.Equal(t, createdConfigID, gotID)
require.Equal(t, &providerID, createReq.AIProviderID)
require.Equal(t, scaletestModelName, createReq.Model)
require.Equal(t, scaletestModelDisplayName, createReq.DisplayName)
require.NotNil(t, createReq.Enabled)
require.True(t, *createReq.Enabled)
require.NotNil(t, createReq.IsDefault)
require.False(t, *createReq.IsDefault)
require.NotNil(t, createReq.ContextLimit)
require.Equal(t, scaletestModelContextLimit, *createReq.ContextLimit)
})
}
func TestWaitForScaletestAIProviderRoute(t *testing.T) {
t.Parallel()
llmMockURL := "http://new.example.test/v1"
expectedMarker, err := llmmock.ProbeMarkerForBaseURL(llmMockURL)
require.NoError(t, err)
t.Run("WaitsUntilExpectedMarker", func(t *testing.T) {
t.Parallel()
var probeCount atomic.Int64
client := newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/deployment/config":
writeDeploymentConfig(t, rw, true)
case r.Method == http.MethodGet && r.URL.Path == scaletestAIProviderProbePath:
require.Equal(t, "Bearer test-session", r.Header.Get("Authorization"))
switch probeCount.Add(1) {
case 1:
writeJSON(t, rw, http.StatusOK, map[string]string{"scaletest_llm_mock": "coder-scaletest-llm-mock:old.example.test/v1"})
case 2:
writeText(rw, http.StatusBadGateway, "upstream proxy error")
case 3:
writeText(rw, http.StatusNotFound, "route not supported")
default:
writeJSON(t, rw, http.StatusOK, map[string]string{"scaletest_llm_mock": expectedMarker})
}
default:
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
}
}))
err := waitForScaletestAIProviderRoute(context.Background(), client, testLogger(), llmMockURL)
require.NoError(t, err)
require.Equal(t, int64(4), probeCount.Load())
})
t.Run("SkipsWhenAIGatewayRoutingDisabled", func(t *testing.T) {
t.Parallel()
client := newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/deployment/config":
writeDeploymentConfig(t, rw, false)
case r.URL.Path == scaletestAIProviderProbePath:
t.Fatal("probe route should not be called when AI Gateway routing is disabled")
default:
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
}
}))
err := waitForScaletestAIProviderRoute(context.Background(), client, testLogger(), llmMockURL)
require.NoError(t, err)
})
}
func newProviderTestClient(t *testing.T, handler http.Handler) *codersdk.Client {
t.Helper()
srv := httptest.NewServer(handler)
t.Cleanup(srv.Close)
serverURL, err := url.Parse(srv.URL)
require.NoError(t, err)
client := codersdk.New(serverURL)
client.SetSessionToken("test-session")
return client
}
func writeDeploymentConfig(t *testing.T, rw http.ResponseWriter, enabled bool) {
t.Helper()
writeJSON(t, rw, http.StatusOK, map[string]any{
"config": map[string]any{
"ai": map[string]any{
"bridge": map[string]any{
"enabled": enabled,
},
"chat": map[string]any{
"ai_gateway_routing_enabled": enabled,
},
},
},
})
}
func writeJSON(t *testing.T, rw http.ResponseWriter, status int, body any) {
t.Helper()
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(status)
require.NoError(t, json.NewEncoder(rw).Encode(body))
}
func writeText(rw http.ResponseWriter, status int, body string) {
rw.Header().Set("Content-Type", "text/plain")
rw.WriteHeader(status)
_, _ = rw.Write([]byte(body))
}
+46
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
@@ -50,6 +51,24 @@ type Config struct {
TraceEnable bool
}
const probeMarkerPrefix = "coder-scaletest-llm-mock:"
// ProbeMarkerForBaseURL returns the readiness marker emitted by the mock for a provider base URL.
func ProbeMarkerForBaseURL(rawURL string) (string, error) {
parsed, err := url.Parse(rawURL)
if err != nil {
return "", xerrors.Errorf("parse LLM mock URL: %w", err)
}
if parsed.Host == "" {
return "", xerrors.Errorf("LLM mock URL %q is missing a host", rawURL)
}
return probeMarker(parsed.Host, parsed.Path), nil
}
func probeMarker(host string, basePath string) string {
return probeMarkerPrefix + host + strings.TrimRight(basePath, "/")
}
type llmRequest struct {
Model string `json:"model"`
Stream bool `json:"stream,omitempty"`
@@ -172,6 +191,7 @@ func (s *Server) APIAddress() string {
func (s *Server) startAPIServer(ctx context.Context) error {
mux := http.NewServeMux()
mux.HandleFunc("GET /v1/models", s.handleOpenAIModels)
mux.HandleFunc("POST /v1/chat/completions", s.handleOpenAI)
mux.HandleFunc("POST /v1/responses", s.handleResponses)
mux.HandleFunc("POST /v1/messages", s.handleAnthropic)
@@ -201,6 +221,32 @@ func (s *Server) startAPIServer(ctx context.Context) error {
return nil
}
func (s *Server) handleOpenAIModels(w http.ResponseWriter, r *http.Request) {
pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), func(ctx context.Context) {
s.handleOpenAIModelsWithLabels(w, r.WithContext(ctx))
})
}
func (s *Server) handleOpenAIModelsWithLabels(w http.ResponseWriter, r *http.Request) {
basePath := strings.TrimSuffix(r.URL.Path, "/")
basePath = strings.TrimSuffix(basePath, "/models")
resp := struct {
Object string `json:"object"`
Data []struct{} `json:"data"`
ScaletestLLMMock string `json:"scaletest_llm_mock"`
}{
Object: "list",
Data: []struct{}{},
ScaletestLLMMock: probeMarker(r.Host, basePath),
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error(r.Context(), "failed to write OpenAI models response", slog.Error(err))
}
}
func (s *Server) handleOpenAI(w http.ResponseWriter, r *http.Request) {
pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), func(ctx context.Context) {
s.handleOpenAIWithLabels(w, r.WithContext(ctx))
+42
View File
@@ -0,0 +1,42 @@
package llmmock_test
import (
"context"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/sloghuman"
"github.com/coder/coder/v2/scaletest/llmmock"
)
func TestServerOpenAIModelsReportsProbeMarker(t *testing.T) {
t.Parallel()
ctx := context.Background()
server := new(llmmock.Server)
err := server.Start(ctx, llmmock.Config{
Address: "127.0.0.1:0",
Logger: slog.Make(sloghuman.Sink(io.Discard)).Leveled(slog.LevelDebug),
})
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, server.Stop())
})
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.APIAddress()+"/v1/models", nil)
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
expectedMarker, err := llmmock.ProbeMarkerForBaseURL(server.APIAddress() + "/v1")
require.NoError(t, err)
require.Contains(t, string(body), expectedMarker)
}