mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(scaletest): wait for mock AI provider reload
This commit is contained in:
+20
-18
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/llmmock"
|
||||
"github.com/coder/retry"
|
||||
)
|
||||
|
||||
@@ -23,7 +24,7 @@ const (
|
||||
scaletestModelName = "scaletest-model"
|
||||
scaletestModelDisplayName = "Scaletest Model"
|
||||
scaletestModelContextLimit = int64(4096)
|
||||
scaletestAIProviderProbePath = "/api/v2/aibridge/" + scaletestAIProviderName + "/v1/chat/completions"
|
||||
scaletestAIProviderProbePath = "/api/v2/aibridge/" + scaletestAIProviderName + "/v1/models"
|
||||
scaletestAIProviderProbeWait = 15 * time.Second
|
||||
scaletestAIProviderProbePeriod = 100 * time.Millisecond
|
||||
)
|
||||
@@ -67,7 +68,7 @@ func EnsureScaletestModelConfig(ctx context.Context, client *codersdk.Client, lo
|
||||
}
|
||||
|
||||
if providerAction != scaletestAIProviderActionReused {
|
||||
if err := waitForScaletestAIProviderRoute(ctx, client, logger); err != nil {
|
||||
if err := waitForScaletestAIProviderRoute(ctx, client, logger, llmMockURL); err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("wait for mock LLM provider reload: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -75,7 +76,7 @@ func EnsureScaletestModelConfig(ctx context.Context, client *codersdk.Client, lo
|
||||
return ensureScaletestChatModelConfig(ctx, codersdk.NewExperimentalClient(client), logger, provider)
|
||||
}
|
||||
|
||||
func waitForScaletestAIProviderRoute(ctx context.Context, client *codersdk.Client, logger slog.Logger) error {
|
||||
func waitForScaletestAIProviderRoute(ctx context.Context, client *codersdk.Client, logger slog.Logger, llmMockURL string) error {
|
||||
deploymentConfig, err := client.DeploymentConfig(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get deployment config: %w", err)
|
||||
@@ -86,38 +87,39 @@ func waitForScaletestAIProviderRoute(ctx context.Context, client *codersdk.Clien
|
||||
return nil
|
||||
}
|
||||
|
||||
expectedMarker, err := llmmock.ProbeMarkerForBaseURL(llmMockURL)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("build mock LLM provider probe marker: %w", err)
|
||||
}
|
||||
|
||||
logger.Info(ctx, "waiting for mock LLM provider reload", slog.F("provider_name", scaletestAIProviderName))
|
||||
ctx, cancel := context.WithTimeout(ctx, scaletestAIProviderProbeWait)
|
||||
defer cancel()
|
||||
|
||||
var lastStatus int
|
||||
var lastBody string
|
||||
for retrier := retry.New(scaletestAIProviderProbePeriod, scaletestAIProviderProbePeriod); retrier.Wait(ctx); {
|
||||
res, err := client.Request(ctx, http.MethodPost, scaletestAIProviderProbePath, map[string]any{
|
||||
"model": scaletestModelName,
|
||||
"messages": []map[string]string{{
|
||||
"role": "user",
|
||||
"content": "ping",
|
||||
}},
|
||||
"stream": false,
|
||||
}, func(r *http.Request) {
|
||||
res, err := client.Request(ctx, http.MethodGet, scaletestAIProviderProbePath, nil, func(r *http.Request) {
|
||||
r.Header.Set("Authorization", "Bearer "+client.SessionToken())
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return xerrors.Errorf("probe mock LLM provider route: %w", err)
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(res.Body, 4096))
|
||||
_ = res.Body.Close()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read probe response: %w", err)
|
||||
}
|
||||
if res.StatusCode == http.StatusOK {
|
||||
lastStatus = res.StatusCode
|
||||
lastBody = strings.TrimSpace(string(body))
|
||||
if res.StatusCode == http.StatusOK && strings.Contains(lastBody, expectedMarker) {
|
||||
return nil
|
||||
}
|
||||
if res.StatusCode != http.StatusNotFound || !strings.Contains(string(body), "route not supported") {
|
||||
return xerrors.Errorf("status %d: %s", res.StatusCode, strings.TrimSpace(string(body)))
|
||||
logger.Debug(ctx, "mock LLM provider route is not ready",
|
||||
slog.F("status_code", res.StatusCode),
|
||||
)
|
||||
}
|
||||
logger.Debug(ctx, "mock LLM provider route is not ready")
|
||||
}
|
||||
return xerrors.Errorf("timed out waiting for mock LLM provider route")
|
||||
return xerrors.Errorf("timed out waiting for mock LLM provider route to report marker %q (last status %d: %s)", expectedMarker, lastStatus, lastBody)
|
||||
}
|
||||
|
||||
func ensureScaletestChatModelConfig(ctx context.Context, client *codersdk.ExperimentalClient, logger slog.Logger, provider codersdk.AIProvider) (uuid.UUID, error) {
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/llmmock"
|
||||
)
|
||||
|
||||
func TestEnsureScaletestChatModelConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
providerID := uuid.MustParse("44444444-4444-4444-4444-444444444444")
|
||||
wrongProviderID := uuid.MustParse("55555555-5555-5555-5555-555555555555")
|
||||
matchingConfigID := uuid.MustParse("66666666-6666-6666-6666-666666666666")
|
||||
createdConfigID := uuid.MustParse("77777777-7777-7777-7777-777777777777")
|
||||
provider := codersdk.AIProvider{ID: providerID}
|
||||
|
||||
t.Run("ReusesMatchingProviderAndModel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := codersdk.NewExperimentalClient(newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/chats/model-configs":
|
||||
writeJSON(t, rw, http.StatusOK, []codersdk.ChatModelConfig{
|
||||
{
|
||||
ID: uuid.MustParse("88888888-8888-8888-8888-888888888888"),
|
||||
AIProviderID: &wrongProviderID,
|
||||
Model: scaletestModelName,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: matchingConfigID,
|
||||
AIProviderID: &providerID,
|
||||
Model: scaletestModelName,
|
||||
Enabled: true,
|
||||
},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
})))
|
||||
|
||||
gotID, err := ensureScaletestChatModelConfig(ctx, client, testLogger(), provider)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, matchingConfigID, gotID)
|
||||
})
|
||||
|
||||
t.Run("CreatesWhenNoConfigMatches", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var createReq codersdk.CreateChatModelConfigRequest
|
||||
client := codersdk.NewExperimentalClient(newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/chats/model-configs":
|
||||
writeJSON(t, rw, http.StatusOK, []codersdk.ChatModelConfig{})
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/experimental/chats/model-configs":
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&createReq))
|
||||
writeJSON(t, rw, http.StatusCreated, codersdk.ChatModelConfig{
|
||||
ID: createdConfigID,
|
||||
AIProviderID: &providerID,
|
||||
Model: scaletestModelName,
|
||||
Enabled: true,
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
})))
|
||||
|
||||
gotID, err := ensureScaletestChatModelConfig(ctx, client, testLogger(), provider)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, createdConfigID, gotID)
|
||||
require.Equal(t, &providerID, createReq.AIProviderID)
|
||||
require.Equal(t, scaletestModelName, createReq.Model)
|
||||
require.Equal(t, scaletestModelDisplayName, createReq.DisplayName)
|
||||
require.NotNil(t, createReq.Enabled)
|
||||
require.True(t, *createReq.Enabled)
|
||||
require.NotNil(t, createReq.IsDefault)
|
||||
require.False(t, *createReq.IsDefault)
|
||||
require.NotNil(t, createReq.ContextLimit)
|
||||
require.Equal(t, scaletestModelContextLimit, *createReq.ContextLimit)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForScaletestAIProviderRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
llmMockURL := "http://new.example.test/v1"
|
||||
expectedMarker, err := llmmock.ProbeMarkerForBaseURL(llmMockURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("WaitsUntilExpectedMarker", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var probeCount atomic.Int64
|
||||
client := newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/deployment/config":
|
||||
writeDeploymentConfig(t, rw, true)
|
||||
case r.Method == http.MethodGet && r.URL.Path == scaletestAIProviderProbePath:
|
||||
require.Equal(t, "Bearer test-session", r.Header.Get("Authorization"))
|
||||
switch probeCount.Add(1) {
|
||||
case 1:
|
||||
writeJSON(t, rw, http.StatusOK, map[string]string{"scaletest_llm_mock": "coder-scaletest-llm-mock:old.example.test/v1"})
|
||||
case 2:
|
||||
writeText(rw, http.StatusBadGateway, "upstream proxy error")
|
||||
case 3:
|
||||
writeText(rw, http.StatusNotFound, "route not supported")
|
||||
default:
|
||||
writeJSON(t, rw, http.StatusOK, map[string]string{"scaletest_llm_mock": expectedMarker})
|
||||
}
|
||||
default:
|
||||
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
|
||||
err := waitForScaletestAIProviderRoute(context.Background(), client, testLogger(), llmMockURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(4), probeCount.Load())
|
||||
})
|
||||
|
||||
t.Run("SkipsWhenAIGatewayRoutingDisabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := newProviderTestClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/deployment/config":
|
||||
writeDeploymentConfig(t, rw, false)
|
||||
case r.URL.Path == scaletestAIProviderProbePath:
|
||||
t.Fatal("probe route should not be called when AI Gateway routing is disabled")
|
||||
default:
|
||||
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
|
||||
err := waitForScaletestAIProviderRoute(context.Background(), client, testLogger(), llmMockURL)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func newProviderTestClient(t *testing.T, handler http.Handler) *codersdk.Client {
|
||||
t.Helper()
|
||||
|
||||
srv := httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
serverURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
client := codersdk.New(serverURL)
|
||||
client.SetSessionToken("test-session")
|
||||
return client
|
||||
}
|
||||
|
||||
func writeDeploymentConfig(t *testing.T, rw http.ResponseWriter, enabled bool) {
|
||||
t.Helper()
|
||||
|
||||
writeJSON(t, rw, http.StatusOK, map[string]any{
|
||||
"config": map[string]any{
|
||||
"ai": map[string]any{
|
||||
"bridge": map[string]any{
|
||||
"enabled": enabled,
|
||||
},
|
||||
"chat": map[string]any{
|
||||
"ai_gateway_routing_enabled": enabled,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// writeJSON writes body to rw as a JSON response with the given status code
// and an application/json Content-Type.
//
// It is invoked from HTTP handler goroutines, so an encoding failure is
// reported with t.Errorf rather than require: FailNow must only be called
// from the goroutine running the test.
func writeJSON(t *testing.T, rw http.ResponseWriter, status int, body any) {
	t.Helper()

	rw.Header().Set("Content-Type", "application/json")
	rw.WriteHeader(status)
	if err := json.NewEncoder(rw).Encode(body); err != nil {
		t.Errorf("encode JSON response: %v", err)
	}
}
|
||||
|
||||
// writeText sends body verbatim as a text/plain response with the given
// status code.
func writeText(rw http.ResponseWriter, status int, body string) {
	headers := rw.Header()
	headers.Set("Content-Type", "text/plain")
	rw.WriteHeader(status)

	// The write result is deliberately discarded.
	payload := []byte(body)
	_, _ = rw.Write(payload)
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -50,6 +51,24 @@ type Config struct {
|
||||
TraceEnable bool
|
||||
}
|
||||
|
||||
const probeMarkerPrefix = "coder-scaletest-llm-mock:"
|
||||
|
||||
// ProbeMarkerForBaseURL returns the readiness marker emitted by the mock for a provider base URL.
|
||||
func ProbeMarkerForBaseURL(rawURL string) (string, error) {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("parse LLM mock URL: %w", err)
|
||||
}
|
||||
if parsed.Host == "" {
|
||||
return "", xerrors.Errorf("LLM mock URL %q is missing a host", rawURL)
|
||||
}
|
||||
return probeMarker(parsed.Host, parsed.Path), nil
|
||||
}
|
||||
|
||||
func probeMarker(host string, basePath string) string {
|
||||
return probeMarkerPrefix + host + strings.TrimRight(basePath, "/")
|
||||
}
|
||||
|
||||
type llmRequest struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
@@ -172,6 +191,7 @@ func (s *Server) APIAddress() string {
|
||||
func (s *Server) startAPIServer(ctx context.Context) error {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("GET /v1/models", s.handleOpenAIModels)
|
||||
mux.HandleFunc("POST /v1/chat/completions", s.handleOpenAI)
|
||||
mux.HandleFunc("POST /v1/responses", s.handleResponses)
|
||||
mux.HandleFunc("POST /v1/messages", s.handleAnthropic)
|
||||
@@ -201,6 +221,32 @@ func (s *Server) startAPIServer(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleOpenAIModels(w http.ResponseWriter, r *http.Request) {
|
||||
pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), func(ctx context.Context) {
|
||||
s.handleOpenAIModelsWithLabels(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleOpenAIModelsWithLabels(w http.ResponseWriter, r *http.Request) {
|
||||
basePath := strings.TrimSuffix(r.URL.Path, "/")
|
||||
basePath = strings.TrimSuffix(basePath, "/models")
|
||||
resp := struct {
|
||||
Object string `json:"object"`
|
||||
Data []struct{} `json:"data"`
|
||||
ScaletestLLMMock string `json:"scaletest_llm_mock"`
|
||||
}{
|
||||
Object: "list",
|
||||
Data: []struct{}{},
|
||||
ScaletestLLMMock: probeMarker(r.Host, basePath),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error(r.Context(), "failed to write OpenAI models response", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleOpenAI(w http.ResponseWriter, r *http.Request) {
|
||||
pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), func(ctx context.Context) {
|
||||
s.handleOpenAIWithLabels(w, r.WithContext(ctx))
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package llmmock_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/scaletest/llmmock"
|
||||
)
|
||||
|
||||
func TestServerOpenAIModelsReportsProbeMarker(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
server := new(llmmock.Server)
|
||||
err := server.Start(ctx, llmmock.Config{
|
||||
Address: "127.0.0.1:0",
|
||||
Logger: slog.Make(sloghuman.Sink(io.Discard)).Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Stop())
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.APIAddress()+"/v1/models", nil)
|
||||
require.NoError(t, err)
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
expectedMarker, err := llmmock.ProbeMarkerForBaseURL(server.APIAddress() + "/v1")
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body), expectedMarker)
|
||||
}
|
||||
Reference in New Issue
Block a user