From f529577beec3d35ff6f8974bf20fe404269efcf3 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Thu, 28 May 2026 10:27:32 +0200 Subject: [PATCH] fix(coderd/x/chatd): harden openai-compatible chat calls (#25737) OpenAI-compatible chat paths hit two provider compatibility issues. Some compatible endpoints reject a named `tool_choice` when there is only one tool, and Gemini's OpenAI-compatible endpoint requires thought signatures on current-turn tool calls. Centralize OpenAI-compatible request patches in the chat provider: rewrite single named tool choices to `"required"`, and add the documented dummy Google thought signature to the first tool call in each current-turn tool step for Gemini routes. Vercel OpenAI-compatible requests are left unchanged for the thought-signature patch. > Mux created this PR on behalf of Mike. --- coderd/x/chatd/chatprovider/chatprovider.go | 1 + .../chatprovider/openai_compat_patches.go | 237 ++++++++++++++++++ .../openai_compat_patches_internal_test.go | 156 ++++++++++++ .../openai_compat_patches_test.go | 186 ++++++++++++++ coderd/x/chatd/quickgen_internal_test.go | 118 ++++++++- 5 files changed, 695 insertions(+), 3 deletions(-) create mode 100644 coderd/x/chatd/chatprovider/openai_compat_patches.go create mode 100644 coderd/x/chatd/chatprovider/openai_compat_patches_internal_test.go create mode 100644 coderd/x/chatd/chatprovider/openai_compat_patches_test.go diff --git a/coderd/x/chatd/chatprovider/chatprovider.go b/coderd/x/chatd/chatprovider/chatprovider.go index 768aa5774e..fec0840b08 100644 --- a/coderd/x/chatd/chatprovider/chatprovider.go +++ b/coderd/x/chatd/chatprovider/chatprovider.go @@ -1243,6 +1243,7 @@ func ModelFromConfig( } providerClient, err = fantasyopenai.New(options...) case fantasyopenaicompat.Name: + httpClient = withOpenAICompatRequestPatches(httpClient, baseURL, modelID) options := []fantasyopenaicompat.Option{ fantasyopenaicompat.WithAPIKey(apiKey), fantasyopenaicompat.WithUserAgent(userAgent), diff --git a/coderd/x/chatd/chatprovider/openai_compat_patches.go b/coderd/x/chatd/chatprovider/openai_compat_patches.go new file mode 100644 index 0000000000..beac165bb5 --- /dev/null +++ b/coderd/x/chatd/chatprovider/openai_compat_patches.go @@ -0,0 +1,237 @@ +package chatprovider + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" +) + +// OpenAI-compatible providers share an API shape but differ in the exact JSON +// they accept. These patches adjust Fantasy's serialized request body at the +// transport boundary so higher-level generation code can stay provider agnostic. +// +// googleOpenAICompatDummyThoughtSignature is Google's documented last-resort +// bypass for callers that cannot preserve a real Gemini thought signature. +// See https://ai.google.dev/gemini-api/docs/thought-signatures. +const googleOpenAICompatDummyThoughtSignature = "skip_thought_signature_validator" + +func withOpenAICompatRequestPatches( + client *http.Client, + baseURL string, + modelID string, +) *http.Client { + if client == nil { + client = &http.Client{} + } else { + clone := *client + client = &clone + } + client.Transport = &openAICompatRequestPatchTransport{ + Base: client.Transport, + BaseURL: baseURL, + ModelID: modelID, + } + return client +} + +type openAICompatRequestPatchTransport struct { + Base http.RoundTripper + // BaseURL is the configured provider base URL, used to detect direct Gemini endpoints. + BaseURL string + // ModelID is the configured model ID, used to detect Gemini routes through Coder AI Bridge. + ModelID string +} + +func (t *openAICompatRequestPatchTransport) RoundTrip(req *http.Request) (*http.Response, error) { + base := t.base() + if !shouldPatchOpenAICompatRequest(req) { + return base.RoundTrip(req) + } + + body, err := io.ReadAll(req.Body) + closeErr := req.Body.Close() + if err != nil { + return nil, err + } + if closeErr != nil { + return nil, closeErr + } + + patched := patchOpenAICompatChatCompletionsBody(body, t.BaseURL, t.ModelID) + patchedReq := req.Clone(req.Context()) + patchedReq.Body = io.NopCloser(bytes.NewReader(patched)) + patchedReq.ContentLength = int64(len(patched)) + patchedReq.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(patched)), nil + } + + return base.RoundTrip(patchedReq) +} + +func (t *openAICompatRequestPatchTransport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +func shouldPatchOpenAICompatRequest(req *http.Request) bool { + return req != nil && + req.Method == http.MethodPost && + req.Body != nil && + strings.HasSuffix(req.URL.Path, "/chat/completions") +} + +func patchOpenAICompatChatCompletionsBody(body []byte, baseURL string, modelID string) []byte { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return body + } + + changed := rewriteOpenAICompatSingleToolChoice(payload) + if shouldAddGoogleOpenAICompatThoughtSignatures(baseURL, modelID) { + changed = addGoogleOpenAICompatThoughtSignatures(payload) || changed + } + if !changed { + return body + } + + patched, err := json.Marshal(payload) + if err != nil { + return body + } + return patched +} + +// rewriteOpenAICompatSingleToolChoice replaces a single named tool choice with +// "required" because some compatible endpoints reject the named object form. +func rewriteOpenAICompatSingleToolChoice(payload map[string]any) bool { + tools, ok := payload["tools"].([]any) + if !ok || len(tools) != 1 { + return false + } + tool, ok := tools[0].(map[string]any) + if !ok { + return false + } + function, ok := tool["function"].(map[string]any) + if !ok { + return false + } + toolName, _ := function["name"].(string) + if toolName == "" { + return false + } + + toolChoice, ok := payload["tool_choice"].(map[string]any) + if !ok { + return false + } + if toolType, _ := toolChoice["type"].(string); toolType != "function" { + return false + } + choiceFunction, ok := toolChoice["function"].(map[string]any) + if !ok { + return false + } + choiceName, _ := choiceFunction["name"].(string) + if choiceName != toolName { + return false + } + + payload["tool_choice"] = "required" + return true +} + +// shouldAddGoogleOpenAICompatThoughtSignatures detects direct Gemini OpenAI +// endpoints and Coder AI Bridge Gemini routes. Other gateways, such as Vercel, +// keep their own provider-specific compatibility behavior. +func shouldAddGoogleOpenAICompatThoughtSignatures(baseURL string, modelID string) bool { + parsed, err := url.Parse(baseURL) + if err != nil { + return false + } + host := strings.ToLower(parsed.Hostname()) + path := strings.ToLower(parsed.EscapedPath()) + if host == "generativelanguage.googleapis.com" && strings.Contains(path, "/openai") { + return true + } + return host == "coder-aibridge" && isGeminiModelID(modelID) +} + +func isGeminiModelID(modelID string) bool { + modelID = strings.ToLower(strings.TrimSpace(modelID)) + return strings.HasPrefix(modelID, "gemini-") || strings.Contains(modelID, "/gemini-") +} + +// addGoogleOpenAICompatThoughtSignatures adds a dummy thought signature to the +// first tool call on each assistant tool-call message in the latest user turn. +// Gemini validates tool-call history with thought signatures, but +// OpenAI-compatible serialization can drop the original provider metadata. +func addGoogleOpenAICompatThoughtSignatures(payload map[string]any) bool { + messages, ok := payload["messages"].([]any) + if !ok { + return false + } + + currentTurnStart := -1 + for i, raw := range messages { + message, ok := raw.(map[string]any) + if !ok { + continue + } + if role, _ := message["role"].(string); role == "user" { + currentTurnStart = i + } + } + + if currentTurnStart == -1 { + return false + } + + changed := false + for _, raw := range messages[currentTurnStart+1:] { + message, ok := raw.(map[string]any) + if !ok || !isOpenAICompatAssistantRole(message["role"]) { + continue + } + toolCalls, ok := message["tool_calls"].([]any) + if !ok || len(toolCalls) == 0 { + continue + } + firstToolCall, ok := toolCalls[0].(map[string]any) + if !ok { + continue + } + if ensureGoogleOpenAICompatThoughtSignature(firstToolCall) { + changed = true + } + } + return changed +} + +func isOpenAICompatAssistantRole(role any) bool { + roleValue, _ := role.(string) + return roleValue == "assistant" || roleValue == "model" +} + +func ensureGoogleOpenAICompatThoughtSignature(toolCall map[string]any) bool { + extraContent, _ := toolCall["extra_content"].(map[string]any) + google, _ := extraContent["google"].(map[string]any) + if signature, _ := google["thought_signature"].(string); signature != "" { + return false + } + if extraContent == nil { + extraContent = map[string]any{} + toolCall["extra_content"] = extraContent + } + if google == nil { + google = map[string]any{} + extraContent["google"] = google + } + google["thought_signature"] = googleOpenAICompatDummyThoughtSignature + return true +} diff --git a/coderd/x/chatd/chatprovider/openai_compat_patches_internal_test.go b/coderd/x/chatd/chatprovider/openai_compat_patches_internal_test.go new file mode 100644 index 0000000000..eace6c4173 --- /dev/null +++ b/coderd/x/chatd/chatprovider/openai_compat_patches_internal_test.go @@ -0,0 +1,156 @@ +//nolint:testpackage // These tests cover unexported request-patch guards. +package chatprovider + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPatchOpenAICompatChatCompletionsBody_Guards(t *testing.T) { + t.Parallel() + + t.Run("leaves multi tool specific choice unchanged", func(t *testing.T) { + t.Parallel() + + payload := map[string]any{ + "tools": []any{ + functionTool("first_tool"), + functionTool("second_tool"), + }, + "tool_choice": map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "first_tool", + }, + }, + } + + patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "http://example.com/v1", "test-model") + body := decodeJSONMap(t, patched) + toolChoice, ok := body["tool_choice"].(map[string]any) + require.True(t, ok) + function, ok := toolChoice["function"].(map[string]any) + require.True(t, ok) + require.Equal(t, "first_tool", function["name"]) + }) + + t.Run("leaves string tool choice unchanged", func(t *testing.T) { + t.Parallel() + + payload := map[string]any{ + "tools": []any{functionTool("first_tool")}, + "tool_choice": "auto", + } + + patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "http://example.com/v1", "test-model") + body := decodeJSONMap(t, patched) + require.Equal(t, "auto", body["tool_choice"]) + }) + + t.Run("leaves Gemini assistant history without a user turn unchanged", func(t *testing.T) { + t.Parallel() + + payload := map[string]any{ + "messages": []any{ + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + functionToolCall("call_without_user", "history_tool"), + }, + }, + }, + } + + patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash") + body := decodeJSONMap(t, patched) + messages := body["messages"].([]any) + require.Empty(t, googleThoughtSignature(t, messages[0], 0)) + }) + + t.Run("preserves existing Gemini thought signature", func(t *testing.T) { + t.Parallel() + + payload := map[string]any{ + "messages": []any{ + map[string]any{"role": "user", "content": "current turn"}, + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": "call_with_signature", + "type": "function", + "function": map[string]any{ + "name": "signed_tool", + "arguments": `{}`, + }, + "extra_content": map[string]any{ + "google": map[string]any{ + "thought_signature": "real-signature", + }, + }, + }, + }, + }, + }, + } + + patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash") + body := decodeJSONMap(t, patched) + messages := body["messages"].([]any) + require.Equal(t, "real-signature", googleThoughtSignature(t, messages[1], 0)) + }) +} + +func functionTool(name string) map[string]any { + return map[string]any{ + "type": "function", + "function": map[string]any{ + "name": name, + }, + } +} + +func functionToolCall(id string, name string) map[string]any { + return map[string]any{ + "id": id, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": `{}`, + }, + } +} + +func mustJSON(t *testing.T, payload map[string]any) []byte { + t.Helper() + + body, err := json.Marshal(payload) + require.NoError(t, err) + return body +} + +func decodeJSONMap(t *testing.T, body []byte) map[string]any { + t.Helper() + + var payload map[string]any + require.NoError(t, json.Unmarshal(body, &payload)) + return payload +} + +func googleThoughtSignature(t *testing.T, rawMessage any, toolCallIndex int) string { + t.Helper() + + message, ok := rawMessage.(map[string]any) + require.True(t, ok) + toolCalls, ok := message["tool_calls"].([]any) + require.True(t, ok) + require.Greater(t, len(toolCalls), toolCallIndex) + toolCall, ok := toolCalls[toolCallIndex].(map[string]any) + require.True(t, ok) + extraContent, _ := toolCall["extra_content"].(map[string]any) + google, _ := extraContent["google"].(map[string]any) + signature, _ := google["thought_signature"].(string) + return signature +} diff --git a/coderd/x/chatd/chatprovider/openai_compat_patches_test.go b/coderd/x/chatd/chatprovider/openai_compat_patches_test.go new file mode 100644 index 0000000000..c6042c0c63 --- /dev/null +++ b/coderd/x/chatd/chatprovider/openai_compat_patches_test.go @@ -0,0 +1,186 @@ +package chatprovider_test + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "testing" + + "charm.land/fantasy" + fantasyopenaicompat "charm.land/fantasy/providers/openaicompat" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" +) + +const dummyThoughtSignature = "skip_thought_signature_validator" + +func TestModelFromConfig_GeminiOpenAICompatThoughtSignatures(t *testing.T) { + t.Parallel() + + t.Run("Gemini endpoint receives current turn thought signature", func(t *testing.T) { + t.Parallel() + + body := generateOpenAICompatRequest(t, "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash") + messages := body["messages"].([]any) + + require.Empty(t, thoughtSignature(t, messages[1], 0)) + require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[4], 0)) + require.Empty(t, thoughtSignature(t, messages[4], 1)) + require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[6], 0)) + }) + + t.Run("Coder AI Bridge Gemini route receives current turn thought signature", func(t *testing.T) { + t.Parallel() + + body := generateOpenAICompatRequest(t, "http://coder-aibridge/v1", "gemini-3.5-flash") + messages := body["messages"].([]any) + + require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[4], 0)) + }) + + t.Run("Vercel OpenAI-compatible Gemini route is unchanged", func(t *testing.T) { + t.Parallel() + + body := generateOpenAICompatRequest(t, "https://gateway.vercel.ai/v1", "google/gemini-3.5-flash") + messages := body["messages"].([]any) + + require.Empty(t, thoughtSignature(t, messages[4], 0)) + }) +} + +func generateOpenAICompatRequest(t *testing.T, baseURL string, modelID string) map[string]any { + t.Helper() + + transport := &captureChatCompletionTransport{} + model, err := chatprovider.ModelFromConfig( + fantasyopenaicompat.Name, + modelID, + chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{ + fantasyopenaicompat.Name: "test-key", + }, + BaseURLByProvider: map[string]string{ + fantasyopenaicompat.Name: baseURL, + }, + }, + chatprovider.UserAgent(), + nil, + &http.Client{Transport: transport}, + ) + require.NoError(t, err) + + _, err = model.Generate(t.Context(), fantasy.Call{ + Prompt: geminiOpenAICompatToolPrompt(), + }) + require.NoError(t, err) + require.NotNil(t, transport.body) + return transport.body +} + +type captureChatCompletionTransport struct { + body map[string]any +} + +func (ct *captureChatCompletionTransport) RoundTrip(req *http.Request) (*http.Response, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + _ = req.Body.Close() + if strings.HasSuffix(req.URL.Path, "/chat/completions") { + ct.body = map[string]any{} + if err := json.Unmarshal(body, &ct.body); err != nil { + return nil, err + } + } + + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + Body: io.NopCloser(strings.NewReader(`{ + "id":"chatcmpl-test", + "object":"chat.completion", + "created":0, + "model":"gemini-3.5-flash", + "choices":[{"index":0,"message":{"role":"assistant","content":"done"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2} + }`)), + }, nil +} + +func geminiOpenAICompatToolPrompt() []fantasy.Message { + return []fantasy.Message{ + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "previous turn"}, + }, + }, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ToolCallID: "previous-call", ToolName: "previous_tool", Input: `{}`}, + }, + }, + { + Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + fantasy.ToolResultPart{ + ToolCallID: "previous-call", + Output: fantasy.ToolResultOutputContentText{Text: `{}`}, + }, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "current turn"}, + }, + }, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ToolCallID: "current-call-a", ToolName: "first_tool", Input: `{}`}, + fantasy.ToolCallPart{ToolCallID: "current-call-b", ToolName: "parallel_tool", Input: `{}`}, + }, + }, + { + Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + fantasy.ToolResultPart{ + ToolCallID: "current-call-a", + Output: fantasy.ToolResultOutputContentText{Text: `{}`}, + }, + }, + }, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.ToolCallPart{ + ToolCallID: "current-call-c", + ToolName: "second_step_tool", + Input: `{}`, + }, + }, + }, + } +} + +func thoughtSignature(t *testing.T, rawMessage any, toolCallIndex int) string { + t.Helper() + message, ok := rawMessage.(map[string]any) + require.True(t, ok) + toolCalls, ok := message["tool_calls"].([]any) + require.True(t, ok) + require.Greater(t, len(toolCalls), toolCallIndex) + toolCall, ok := toolCalls[toolCallIndex].(map[string]any) + require.True(t, ok) + extraContent, _ := toolCall["extra_content"].(map[string]any) + google, _ := extraContent["google"].(map[string]any) + signature, _ := google["thought_signature"].(string) + return signature +} diff --git a/coderd/x/chatd/quickgen_internal_test.go b/coderd/x/chatd/quickgen_internal_test.go index 09fc8001ab..0e46ccc0f7 100644 --- a/coderd/x/chatd/quickgen_internal_test.go +++ b/coderd/x/chatd/quickgen_internal_test.go @@ -3,11 +3,14 @@ package chatd import ( "context" "encoding/json" + "net/http" + "net/http/httptest" "strings" "testing" "time" "charm.land/fantasy" + fantasyopenaicompat "charm.land/fantasy/providers/openaicompat" "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" @@ -667,6 +670,100 @@ func TestFallbackTurnStatusLabel(t *testing.T) { } } +func TestGenerateStructuredTitleWithUsage_OpenAICompatibleRequiredToolChoice(t *testing.T) { + t.Parallel() + + server, requests := newOpenAICompatStructuredOutputServer(t, "propose_title", `{"title":"Failed workspace logs"}`) + model := openAICompatTestModel(t, server.URL) + + title, _, err := generateStructuredTitleWithUsage( + t.Context(), + model, + titleGenerationPrompt, + "summarize failed workspace build logs", + ) + require.NoError(t, err) + require.Equal(t, "Failed workspace logs", title) + + body := testutil.TryReceive(t.Context(), t, requests) + require.Equal(t, "required", body["tool_choice"]) +} + +func newOpenAICompatStructuredOutputServer( + t *testing.T, + toolName string, + arguments string, +) (*httptest.Server, <-chan map[string]any) { + t.Helper() + + requests := make(chan map[string]any, 10) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + requests <- body + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "chatcmpl-structured-output", + "object": "chat.completion", + "created": time.Now().Unix(), + "model": "anthropic/claude-4-5-sonnet", + "choices": []map[string]any{ + { + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": "", + "tool_calls": []map[string]any{ + { + "id": "call_structured_output", + "type": "function", + "function": map[string]any{ + "name": toolName, + "arguments": arguments, + }, + }, + }, + }, + "finish_reason": "tool_calls", + }, + }, + "usage": map[string]any{ + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + }) + })) + t.Cleanup(server.Close) + return server, requests +} + +func openAICompatTestModel(t *testing.T, baseURL string) fantasy.LanguageModel { + t.Helper() + + model, err := chatprovider.ModelFromConfig( + fantasyopenaicompat.Name, + "anthropic/claude-4-5-sonnet", + chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{ + fantasyopenaicompat.Name: "test-key", + }, + BaseURLByProvider: map[string]string{ + fantasyopenaicompat.Name: baseURL, + }, + }, + chatprovider.UserAgent(), + nil, + nil, + ) + require.NoError(t, err) + return model +} + func TestGenerateStructuredTurnStatusLabel(t *testing.T) { t.Parallel() @@ -682,11 +779,26 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) { }, } - label, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, "done") + label, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done") require.NoError(t, err) require.Equal(t, "Submitted PR", label) }) + t.Run("sends required tool_choice to openai-compatible provider", func(t *testing.T) { + t.Parallel() + + server, requests := newOpenAICompatStructuredOutputServer(t, "propose_turn_status_label", `{"label":"Submitted PR"}`) + model := openAICompatTestModel(t, server.URL) + + label, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done") + require.NoError(t, err) + require.Equal(t, "Submitted PR", label) + require.Len(t, requests, 1) + + body := testutil.TryReceive(t.Context(), t, requests) + require.Equal(t, "required", body["tool_choice"]) + }) + t.Run("rejects narrative label", func(t *testing.T) { t.Parallel() @@ -698,7 +810,7 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) { }, } - _, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, "done") + _, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done") require.ErrorContains(t, err, "generated turn status label was invalid") }) @@ -706,7 +818,7 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) { t.Parallel() model := &chattest.FakeModel{} - _, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, " ") + _, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, " ") require.ErrorContains(t, err, "turn status label input was empty") }) }