mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd/x/chatd): harden openai-compatible chat calls (#25737)
OpenAI-compatible chat paths hit two provider compatibility issues. Some compatible endpoints reject a named `tool_choice` when there is only one tool, and Gemini's OpenAI-compatible endpoint requires thought signatures on current-turn tool calls. Centralize OpenAI-compatible request patches in the chat provider: rewrite single named tool choices to `"required"`, and add the documented dummy Google thought signature to the first tool call in each current-turn tool step for Gemini routes. Vercel OpenAI-compatible requests are left unchanged for the thought-signature patch. > Mux created this PR on behalf of Mike.
This commit is contained in:
@@ -1243,6 +1243,7 @@ func ModelFromConfig(
|
||||
}
|
||||
providerClient, err = fantasyopenai.New(options...)
|
||||
case fantasyopenaicompat.Name:
|
||||
httpClient = withOpenAICompatRequestPatches(httpClient, baseURL, modelID)
|
||||
options := []fantasyopenaicompat.Option{
|
||||
fantasyopenaicompat.WithAPIKey(apiKey),
|
||||
fantasyopenaicompat.WithUserAgent(userAgent),
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
package chatprovider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OpenAI-compatible providers share an API shape but differ in the exact JSON
|
||||
// they accept. These patches adjust Fantasy's serialized request body at the
|
||||
// transport boundary so higher-level generation code can stay provider agnostic.
|
||||
//
|
||||
// googleOpenAICompatDummyThoughtSignature is Google's documented last-resort
|
||||
// bypass for callers that cannot preserve a real Gemini thought signature.
|
||||
// See https://ai.google.dev/gemini-api/docs/thought-signatures.
|
||||
const googleOpenAICompatDummyThoughtSignature = "skip_thought_signature_validator"
|
||||
|
||||
func withOpenAICompatRequestPatches(
|
||||
client *http.Client,
|
||||
baseURL string,
|
||||
modelID string,
|
||||
) *http.Client {
|
||||
if client == nil {
|
||||
client = &http.Client{}
|
||||
} else {
|
||||
clone := *client
|
||||
client = &clone
|
||||
}
|
||||
client.Transport = &openAICompatRequestPatchTransport{
|
||||
Base: client.Transport,
|
||||
BaseURL: baseURL,
|
||||
ModelID: modelID,
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
type openAICompatRequestPatchTransport struct {
|
||||
Base http.RoundTripper
|
||||
// BaseURL is the configured provider base URL, used to detect direct Gemini endpoints.
|
||||
BaseURL string
|
||||
// ModelID is the configured model ID, used to detect Gemini routes through Coder AI Bridge.
|
||||
ModelID string
|
||||
}
|
||||
|
||||
func (t *openAICompatRequestPatchTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
base := t.base()
|
||||
if !shouldPatchOpenAICompatRequest(req) {
|
||||
return base.RoundTrip(req)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
closeErr := req.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if closeErr != nil {
|
||||
return nil, closeErr
|
||||
}
|
||||
|
||||
patched := patchOpenAICompatChatCompletionsBody(body, t.BaseURL, t.ModelID)
|
||||
patchedReq := req.Clone(req.Context())
|
||||
patchedReq.Body = io.NopCloser(bytes.NewReader(patched))
|
||||
patchedReq.ContentLength = int64(len(patched))
|
||||
patchedReq.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewReader(patched)), nil
|
||||
}
|
||||
|
||||
return base.RoundTrip(patchedReq)
|
||||
}
|
||||
|
||||
func (t *openAICompatRequestPatchTransport) base() http.RoundTripper {
|
||||
if t.Base != nil {
|
||||
return t.Base
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
func shouldPatchOpenAICompatRequest(req *http.Request) bool {
|
||||
return req != nil &&
|
||||
req.Method == http.MethodPost &&
|
||||
req.Body != nil &&
|
||||
strings.HasSuffix(req.URL.Path, "/chat/completions")
|
||||
}
|
||||
|
||||
func patchOpenAICompatChatCompletionsBody(body []byte, baseURL string, modelID string) []byte {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
changed := rewriteOpenAICompatSingleToolChoice(payload)
|
||||
if shouldAddGoogleOpenAICompatThoughtSignatures(baseURL, modelID) {
|
||||
changed = addGoogleOpenAICompatThoughtSignatures(payload) || changed
|
||||
}
|
||||
if !changed {
|
||||
return body
|
||||
}
|
||||
|
||||
patched, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return patched
|
||||
}
|
||||
|
||||
// rewriteOpenAICompatSingleToolChoice replaces a single named tool choice with
|
||||
// "required" because some compatible endpoints reject the named object form.
|
||||
func rewriteOpenAICompatSingleToolChoice(payload map[string]any) bool {
|
||||
tools, ok := payload["tools"].([]any)
|
||||
if !ok || len(tools) != 1 {
|
||||
return false
|
||||
}
|
||||
tool, ok := tools[0].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
function, ok := tool["function"].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
toolName, _ := function["name"].(string)
|
||||
if toolName == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
toolChoice, ok := payload["tool_choice"].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if toolType, _ := toolChoice["type"].(string); toolType != "function" {
|
||||
return false
|
||||
}
|
||||
choiceFunction, ok := toolChoice["function"].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
choiceName, _ := choiceFunction["name"].(string)
|
||||
if choiceName != toolName {
|
||||
return false
|
||||
}
|
||||
|
||||
payload["tool_choice"] = "required"
|
||||
return true
|
||||
}
|
||||
|
||||
// shouldAddGoogleOpenAICompatThoughtSignatures detects direct Gemini OpenAI
|
||||
// endpoints and Coder AI Bridge Gemini routes. Other gateways, such as Vercel,
|
||||
// keep their own provider-specific compatibility behavior.
|
||||
func shouldAddGoogleOpenAICompatThoughtSignatures(baseURL string, modelID string) bool {
|
||||
parsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := strings.ToLower(parsed.Hostname())
|
||||
path := strings.ToLower(parsed.EscapedPath())
|
||||
if host == "generativelanguage.googleapis.com" && strings.Contains(path, "/openai") {
|
||||
return true
|
||||
}
|
||||
return host == "coder-aibridge" && isGeminiModelID(modelID)
|
||||
}
|
||||
|
||||
func isGeminiModelID(modelID string) bool {
|
||||
modelID = strings.ToLower(strings.TrimSpace(modelID))
|
||||
return strings.HasPrefix(modelID, "gemini-") || strings.Contains(modelID, "/gemini-")
|
||||
}
|
||||
|
||||
// addGoogleOpenAICompatThoughtSignatures adds a dummy thought signature to the
|
||||
// first tool call on each assistant tool-call message in the latest user turn.
|
||||
// Gemini validates tool-call history with thought signatures, but
|
||||
// OpenAI-compatible serialization can drop the original provider metadata.
|
||||
func addGoogleOpenAICompatThoughtSignatures(payload map[string]any) bool {
|
||||
messages, ok := payload["messages"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
currentTurnStart := -1
|
||||
for i, raw := range messages {
|
||||
message, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if role, _ := message["role"].(string); role == "user" {
|
||||
currentTurnStart = i
|
||||
}
|
||||
}
|
||||
|
||||
if currentTurnStart == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
changed := false
|
||||
for _, raw := range messages[currentTurnStart+1:] {
|
||||
message, ok := raw.(map[string]any)
|
||||
if !ok || !isOpenAICompatAssistantRole(message["role"]) {
|
||||
continue
|
||||
}
|
||||
toolCalls, ok := message["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
continue
|
||||
}
|
||||
firstToolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if ensureGoogleOpenAICompatThoughtSignature(firstToolCall) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
func isOpenAICompatAssistantRole(role any) bool {
|
||||
roleValue, _ := role.(string)
|
||||
return roleValue == "assistant" || roleValue == "model"
|
||||
}
|
||||
|
||||
func ensureGoogleOpenAICompatThoughtSignature(toolCall map[string]any) bool {
|
||||
extraContent, _ := toolCall["extra_content"].(map[string]any)
|
||||
google, _ := extraContent["google"].(map[string]any)
|
||||
if signature, _ := google["thought_signature"].(string); signature != "" {
|
||||
return false
|
||||
}
|
||||
if extraContent == nil {
|
||||
extraContent = map[string]any{}
|
||||
toolCall["extra_content"] = extraContent
|
||||
}
|
||||
if google == nil {
|
||||
google = map[string]any{}
|
||||
extraContent["google"] = google
|
||||
}
|
||||
google["thought_signature"] = googleOpenAICompatDummyThoughtSignature
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
//nolint:testpackage // These tests cover unexported request-patch guards.
|
||||
package chatprovider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPatchOpenAICompatChatCompletionsBody_Guards(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("leaves multi tool specific choice unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := map[string]any{
|
||||
"tools": []any{
|
||||
functionTool("first_tool"),
|
||||
functionTool("second_tool"),
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "first_tool",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "http://example.com/v1", "test-model")
|
||||
body := decodeJSONMap(t, patched)
|
||||
toolChoice, ok := body["tool_choice"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
function, ok := toolChoice["function"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "first_tool", function["name"])
|
||||
})
|
||||
|
||||
t.Run("leaves string tool choice unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := map[string]any{
|
||||
"tools": []any{functionTool("first_tool")},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "http://example.com/v1", "test-model")
|
||||
body := decodeJSONMap(t, patched)
|
||||
require.Equal(t, "auto", body["tool_choice"])
|
||||
})
|
||||
|
||||
t.Run("leaves Gemini assistant history without a user turn unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := map[string]any{
|
||||
"messages": []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{
|
||||
functionToolCall("call_without_user", "history_tool"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash")
|
||||
body := decodeJSONMap(t, patched)
|
||||
messages := body["messages"].([]any)
|
||||
require.Empty(t, googleThoughtSignature(t, messages[0], 0))
|
||||
})
|
||||
|
||||
t.Run("preserves existing Gemini thought signature", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := map[string]any{
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "current turn"},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"id": "call_with_signature",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "signed_tool",
|
||||
"arguments": `{}`,
|
||||
},
|
||||
"extra_content": map[string]any{
|
||||
"google": map[string]any{
|
||||
"thought_signature": "real-signature",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash")
|
||||
body := decodeJSONMap(t, patched)
|
||||
messages := body["messages"].([]any)
|
||||
require.Equal(t, "real-signature", googleThoughtSignature(t, messages[1], 0))
|
||||
})
|
||||
}
|
||||
|
||||
func functionTool(name string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func functionToolCall(id string, name string) map[string]any {
|
||||
return map[string]any{
|
||||
"id": id,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
"arguments": `{}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func mustJSON(t *testing.T, payload map[string]any) []byte {
|
||||
t.Helper()
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
return body
|
||||
}
|
||||
|
||||
func decodeJSONMap(t *testing.T, body []byte) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &payload))
|
||||
return payload
|
||||
}
|
||||
|
||||
func googleThoughtSignature(t *testing.T, rawMessage any, toolCallIndex int) string {
|
||||
t.Helper()
|
||||
|
||||
message, ok := rawMessage.(map[string]any)
|
||||
require.True(t, ok)
|
||||
toolCalls, ok := message["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Greater(t, len(toolCalls), toolCallIndex)
|
||||
toolCall, ok := toolCalls[toolCallIndex].(map[string]any)
|
||||
require.True(t, ok)
|
||||
extraContent, _ := toolCall["extra_content"].(map[string]any)
|
||||
google, _ := extraContent["google"].(map[string]any)
|
||||
signature, _ := google["thought_signature"].(string)
|
||||
return signature
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package chatprovider_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenaicompat "charm.land/fantasy/providers/openaicompat"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
)
|
||||
|
||||
const dummyThoughtSignature = "skip_thought_signature_validator"
|
||||
|
||||
func TestModelFromConfig_GeminiOpenAICompatThoughtSignatures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Gemini endpoint receives current turn thought signature", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := generateOpenAICompatRequest(t, "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash")
|
||||
messages := body["messages"].([]any)
|
||||
|
||||
require.Empty(t, thoughtSignature(t, messages[1], 0))
|
||||
require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[4], 0))
|
||||
require.Empty(t, thoughtSignature(t, messages[4], 1))
|
||||
require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[6], 0))
|
||||
})
|
||||
|
||||
t.Run("Coder AI Bridge Gemini route receives current turn thought signature", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := generateOpenAICompatRequest(t, "http://coder-aibridge/v1", "gemini-3.5-flash")
|
||||
messages := body["messages"].([]any)
|
||||
|
||||
require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[4], 0))
|
||||
})
|
||||
|
||||
t.Run("Vercel OpenAI-compatible Gemini route is unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := generateOpenAICompatRequest(t, "https://gateway.vercel.ai/v1", "google/gemini-3.5-flash")
|
||||
messages := body["messages"].([]any)
|
||||
|
||||
require.Empty(t, thoughtSignature(t, messages[4], 0))
|
||||
})
|
||||
}
|
||||
|
||||
func generateOpenAICompatRequest(t *testing.T, baseURL string, modelID string) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
transport := &captureChatCompletionTransport{}
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
fantasyopenaicompat.Name,
|
||||
modelID,
|
||||
chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenaicompat.Name: "test-key",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenaicompat.Name: baseURL,
|
||||
},
|
||||
},
|
||||
chatprovider.UserAgent(),
|
||||
nil,
|
||||
&http.Client{Transport: transport},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(t.Context(), fantasy.Call{
|
||||
Prompt: geminiOpenAICompatToolPrompt(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, transport.body)
|
||||
return transport.body
|
||||
}
|
||||
|
||||
type captureChatCompletionTransport struct {
|
||||
body map[string]any
|
||||
}
|
||||
|
||||
func (ct *captureChatCompletionTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = req.Body.Close()
|
||||
if strings.HasSuffix(req.URL.Path, "/chat/completions") {
|
||||
ct.body = map[string]any{}
|
||||
if err := json.Unmarshal(body, &ct.body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{
|
||||
"id":"chatcmpl-test",
|
||||
"object":"chat.completion",
|
||||
"created":0,
|
||||
"model":"gemini-3.5-flash",
|
||||
"choices":[{"index":0,"message":{"role":"assistant","content":"done"},"finish_reason":"stop"}],
|
||||
"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}
|
||||
}`)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func geminiOpenAICompatToolPrompt() []fantasy.Message {
|
||||
return []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "previous turn"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.ToolCallPart{ToolCallID: "previous-call", ToolName: "previous_tool", Input: `{}`},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.ToolResultPart{
|
||||
ToolCallID: "previous-call",
|
||||
Output: fantasy.ToolResultOutputContentText{Text: `{}`},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "current turn"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.ToolCallPart{ToolCallID: "current-call-a", ToolName: "first_tool", Input: `{}`},
|
||||
fantasy.ToolCallPart{ToolCallID: "current-call-b", ToolName: "parallel_tool", Input: `{}`},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.ToolResultPart{
|
||||
ToolCallID: "current-call-a",
|
||||
Output: fantasy.ToolResultOutputContentText{Text: `{}`},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.ToolCallPart{
|
||||
ToolCallID: "current-call-c",
|
||||
ToolName: "second_step_tool",
|
||||
Input: `{}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func thoughtSignature(t *testing.T, rawMessage any, toolCallIndex int) string {
|
||||
t.Helper()
|
||||
message, ok := rawMessage.(map[string]any)
|
||||
require.True(t, ok)
|
||||
toolCalls, ok := message["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Greater(t, len(toolCalls), toolCallIndex)
|
||||
toolCall, ok := toolCalls[toolCallIndex].(map[string]any)
|
||||
require.True(t, ok)
|
||||
extraContent, _ := toolCall["extra_content"].(map[string]any)
|
||||
google, _ := extraContent["google"].(map[string]any)
|
||||
signature, _ := google["thought_signature"].(string)
|
||||
return signature
|
||||
}
|
||||
@@ -3,11 +3,14 @@ package chatd
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyopenaicompat "charm.land/fantasy/providers/openaicompat"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -667,6 +670,100 @@ func TestFallbackTurnStatusLabel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateStructuredTitleWithUsage_OpenAICompatibleRequiredToolChoice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, requests := newOpenAICompatStructuredOutputServer(t, "propose_title", `{"title":"Failed workspace logs"}`)
|
||||
model := openAICompatTestModel(t, server.URL)
|
||||
|
||||
title, _, err := generateStructuredTitleWithUsage(
|
||||
t.Context(),
|
||||
model,
|
||||
titleGenerationPrompt,
|
||||
"summarize failed workspace build logs",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Failed workspace logs", title)
|
||||
|
||||
body := testutil.TryReceive(t.Context(), t, requests)
|
||||
require.Equal(t, "required", body["tool_choice"])
|
||||
}
|
||||
|
||||
func newOpenAICompatStructuredOutputServer(
|
||||
t *testing.T,
|
||||
toolName string,
|
||||
arguments string,
|
||||
) (*httptest.Server, <-chan map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
requests := make(chan map[string]any, 10)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
requests <- body
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chatcmpl-structured-output",
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": "anthropic/claude-4-5-sonnet",
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"message": map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": []map[string]any{
|
||||
{
|
||||
"id": "call_structured_output",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": toolName,
|
||||
"arguments": arguments,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
},
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
},
|
||||
})
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
return server, requests
|
||||
}
|
||||
|
||||
func openAICompatTestModel(t *testing.T, baseURL string) fantasy.LanguageModel {
|
||||
t.Helper()
|
||||
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
fantasyopenaicompat.Name,
|
||||
"anthropic/claude-4-5-sonnet",
|
||||
chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{
|
||||
fantasyopenaicompat.Name: "test-key",
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyopenaicompat.Name: baseURL,
|
||||
},
|
||||
},
|
||||
chatprovider.UserAgent(),
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return model
|
||||
}
|
||||
|
||||
func TestGenerateStructuredTurnStatusLabel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -682,11 +779,26 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
label, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, "done")
|
||||
label, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Submitted PR", label)
|
||||
})
|
||||
|
||||
t.Run("sends required tool_choice to openai-compatible provider", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, requests := newOpenAICompatStructuredOutputServer(t, "propose_turn_status_label", `{"label":"Submitted PR"}`)
|
||||
model := openAICompatTestModel(t, server.URL)
|
||||
|
||||
label, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Submitted PR", label)
|
||||
require.Len(t, requests, 1)
|
||||
|
||||
body := testutil.TryReceive(t.Context(), t, requests)
|
||||
require.Equal(t, "required", body["tool_choice"])
|
||||
})
|
||||
|
||||
t.Run("rejects narrative label", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -698,7 +810,7 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, "done")
|
||||
_, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done")
|
||||
require.ErrorContains(t, err, "generated turn status label was invalid")
|
||||
})
|
||||
|
||||
@@ -706,7 +818,7 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{}
|
||||
_, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, " ")
|
||||
_, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, " ")
|
||||
require.ErrorContains(t, err, "turn status label input was empty")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user