fix(coderd/x/chatd): harden openai-compatible chat calls (#25737)

OpenAI-compatible chat paths hit two provider compatibility issues. Some
compatible endpoints reject a named `tool_choice` when there is only one
tool, and Gemini's OpenAI-compatible endpoint requires thought
signatures on current-turn tool calls.

Centralize OpenAI-compatible request patches in the chat provider:
rewrite single named tool choices to `"required"`, and add the
documented dummy Google thought signature to the first tool call in each
current-turn tool step for Gemini routes. Vercel OpenAI-compatible
requests are left unchanged for the thought-signature patch.

> Mux created this PR on behalf of Mike.
This commit is contained in:
Michael Suchacz
2026-05-28 10:27:32 +02:00
committed by GitHub
parent cfa343e456
commit f529577bee
5 changed files with 695 additions and 3 deletions
@@ -1243,6 +1243,7 @@ func ModelFromConfig(
}
providerClient, err = fantasyopenai.New(options...)
case fantasyopenaicompat.Name:
httpClient = withOpenAICompatRequestPatches(httpClient, baseURL, modelID)
options := []fantasyopenaicompat.Option{
fantasyopenaicompat.WithAPIKey(apiKey),
fantasyopenaicompat.WithUserAgent(userAgent),
@@ -0,0 +1,237 @@
package chatprovider
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
)
// OpenAI-compatible providers share an API shape but differ in the exact JSON
// they accept. These patches adjust Fantasy's serialized request body at the
// transport boundary so higher-level generation code can stay provider agnostic.
//
// googleOpenAICompatDummyThoughtSignature is Google's documented last-resort
// bypass for callers that cannot preserve a real Gemini thought signature.
// See https://ai.google.dev/gemini-api/docs/thought-signatures.
const googleOpenAICompatDummyThoughtSignature = "skip_thought_signature_validator"
func withOpenAICompatRequestPatches(
client *http.Client,
baseURL string,
modelID string,
) *http.Client {
if client == nil {
client = &http.Client{}
} else {
clone := *client
client = &clone
}
client.Transport = &openAICompatRequestPatchTransport{
Base: client.Transport,
BaseURL: baseURL,
ModelID: modelID,
}
return client
}
type openAICompatRequestPatchTransport struct {
Base http.RoundTripper
// BaseURL is the configured provider base URL, used to detect direct Gemini endpoints.
BaseURL string
// ModelID is the configured model ID, used to detect Gemini routes through Coder AI Bridge.
ModelID string
}
func (t *openAICompatRequestPatchTransport) RoundTrip(req *http.Request) (*http.Response, error) {
base := t.base()
if !shouldPatchOpenAICompatRequest(req) {
return base.RoundTrip(req)
}
body, err := io.ReadAll(req.Body)
closeErr := req.Body.Close()
if err != nil {
return nil, err
}
if closeErr != nil {
return nil, closeErr
}
patched := patchOpenAICompatChatCompletionsBody(body, t.BaseURL, t.ModelID)
patchedReq := req.Clone(req.Context())
patchedReq.Body = io.NopCloser(bytes.NewReader(patched))
patchedReq.ContentLength = int64(len(patched))
patchedReq.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(patched)), nil
}
return base.RoundTrip(patchedReq)
}
func (t *openAICompatRequestPatchTransport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
func shouldPatchOpenAICompatRequest(req *http.Request) bool {
return req != nil &&
req.Method == http.MethodPost &&
req.Body != nil &&
strings.HasSuffix(req.URL.Path, "/chat/completions")
}
func patchOpenAICompatChatCompletionsBody(body []byte, baseURL string, modelID string) []byte {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return body
}
changed := rewriteOpenAICompatSingleToolChoice(payload)
if shouldAddGoogleOpenAICompatThoughtSignatures(baseURL, modelID) {
changed = addGoogleOpenAICompatThoughtSignatures(payload) || changed
}
if !changed {
return body
}
patched, err := json.Marshal(payload)
if err != nil {
return body
}
return patched
}
// rewriteOpenAICompatSingleToolChoice replaces a single named tool choice with
// "required" because some compatible endpoints reject the named object form.
func rewriteOpenAICompatSingleToolChoice(payload map[string]any) bool {
tools, ok := payload["tools"].([]any)
if !ok || len(tools) != 1 {
return false
}
tool, ok := tools[0].(map[string]any)
if !ok {
return false
}
function, ok := tool["function"].(map[string]any)
if !ok {
return false
}
toolName, _ := function["name"].(string)
if toolName == "" {
return false
}
toolChoice, ok := payload["tool_choice"].(map[string]any)
if !ok {
return false
}
if toolType, _ := toolChoice["type"].(string); toolType != "function" {
return false
}
choiceFunction, ok := toolChoice["function"].(map[string]any)
if !ok {
return false
}
choiceName, _ := choiceFunction["name"].(string)
if choiceName != toolName {
return false
}
payload["tool_choice"] = "required"
return true
}
// shouldAddGoogleOpenAICompatThoughtSignatures detects direct Gemini OpenAI
// endpoints and Coder AI Bridge Gemini routes. Other gateways, such as Vercel,
// keep their own provider-specific compatibility behavior.
func shouldAddGoogleOpenAICompatThoughtSignatures(baseURL string, modelID string) bool {
parsed, err := url.Parse(baseURL)
if err != nil {
return false
}
host := strings.ToLower(parsed.Hostname())
path := strings.ToLower(parsed.EscapedPath())
if host == "generativelanguage.googleapis.com" && strings.Contains(path, "/openai") {
return true
}
return host == "coder-aibridge" && isGeminiModelID(modelID)
}
func isGeminiModelID(modelID string) bool {
modelID = strings.ToLower(strings.TrimSpace(modelID))
return strings.HasPrefix(modelID, "gemini-") || strings.Contains(modelID, "/gemini-")
}
// addGoogleOpenAICompatThoughtSignatures adds a dummy thought signature to the
// first tool call on each assistant tool-call message in the latest user turn.
// Gemini validates tool-call history with thought signatures, but
// OpenAI-compatible serialization can drop the original provider metadata.
func addGoogleOpenAICompatThoughtSignatures(payload map[string]any) bool {
messages, ok := payload["messages"].([]any)
if !ok {
return false
}
currentTurnStart := -1
for i, raw := range messages {
message, ok := raw.(map[string]any)
if !ok {
continue
}
if role, _ := message["role"].(string); role == "user" {
currentTurnStart = i
}
}
if currentTurnStart == -1 {
return false
}
changed := false
for _, raw := range messages[currentTurnStart+1:] {
message, ok := raw.(map[string]any)
if !ok || !isOpenAICompatAssistantRole(message["role"]) {
continue
}
toolCalls, ok := message["tool_calls"].([]any)
if !ok || len(toolCalls) == 0 {
continue
}
firstToolCall, ok := toolCalls[0].(map[string]any)
if !ok {
continue
}
if ensureGoogleOpenAICompatThoughtSignature(firstToolCall) {
changed = true
}
}
return changed
}
func isOpenAICompatAssistantRole(role any) bool {
roleValue, _ := role.(string)
return roleValue == "assistant" || roleValue == "model"
}
func ensureGoogleOpenAICompatThoughtSignature(toolCall map[string]any) bool {
extraContent, _ := toolCall["extra_content"].(map[string]any)
google, _ := extraContent["google"].(map[string]any)
if signature, _ := google["thought_signature"].(string); signature != "" {
return false
}
if extraContent == nil {
extraContent = map[string]any{}
toolCall["extra_content"] = extraContent
}
if google == nil {
google = map[string]any{}
extraContent["google"] = google
}
google["thought_signature"] = googleOpenAICompatDummyThoughtSignature
return true
}
@@ -0,0 +1,156 @@
//nolint:testpackage // These tests cover unexported request-patch guards.
package chatprovider
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestPatchOpenAICompatChatCompletionsBody_Guards(t *testing.T) {
t.Parallel()
t.Run("leaves multi tool specific choice unchanged", func(t *testing.T) {
t.Parallel()
payload := map[string]any{
"tools": []any{
functionTool("first_tool"),
functionTool("second_tool"),
},
"tool_choice": map[string]any{
"type": "function",
"function": map[string]any{
"name": "first_tool",
},
},
}
patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "http://example.com/v1", "test-model")
body := decodeJSONMap(t, patched)
toolChoice, ok := body["tool_choice"].(map[string]any)
require.True(t, ok)
function, ok := toolChoice["function"].(map[string]any)
require.True(t, ok)
require.Equal(t, "first_tool", function["name"])
})
t.Run("leaves string tool choice unchanged", func(t *testing.T) {
t.Parallel()
payload := map[string]any{
"tools": []any{functionTool("first_tool")},
"tool_choice": "auto",
}
patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "http://example.com/v1", "test-model")
body := decodeJSONMap(t, patched)
require.Equal(t, "auto", body["tool_choice"])
})
t.Run("leaves Gemini assistant history without a user turn unchanged", func(t *testing.T) {
t.Parallel()
payload := map[string]any{
"messages": []any{
map[string]any{
"role": "assistant",
"tool_calls": []any{
functionToolCall("call_without_user", "history_tool"),
},
},
},
}
patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash")
body := decodeJSONMap(t, patched)
messages := body["messages"].([]any)
require.Empty(t, googleThoughtSignature(t, messages[0], 0))
})
t.Run("preserves existing Gemini thought signature", func(t *testing.T) {
t.Parallel()
payload := map[string]any{
"messages": []any{
map[string]any{"role": "user", "content": "current turn"},
map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
"id": "call_with_signature",
"type": "function",
"function": map[string]any{
"name": "signed_tool",
"arguments": `{}`,
},
"extra_content": map[string]any{
"google": map[string]any{
"thought_signature": "real-signature",
},
},
},
},
},
},
}
patched := patchOpenAICompatChatCompletionsBody(mustJSON(t, payload), "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash")
body := decodeJSONMap(t, patched)
messages := body["messages"].([]any)
require.Equal(t, "real-signature", googleThoughtSignature(t, messages[1], 0))
})
}
func functionTool(name string) map[string]any {
return map[string]any{
"type": "function",
"function": map[string]any{
"name": name,
},
}
}
func functionToolCall(id string, name string) map[string]any {
return map[string]any{
"id": id,
"type": "function",
"function": map[string]any{
"name": name,
"arguments": `{}`,
},
}
}
func mustJSON(t *testing.T, payload map[string]any) []byte {
t.Helper()
body, err := json.Marshal(payload)
require.NoError(t, err)
return body
}
func decodeJSONMap(t *testing.T, body []byte) map[string]any {
t.Helper()
var payload map[string]any
require.NoError(t, json.Unmarshal(body, &payload))
return payload
}
func googleThoughtSignature(t *testing.T, rawMessage any, toolCallIndex int) string {
t.Helper()
message, ok := rawMessage.(map[string]any)
require.True(t, ok)
toolCalls, ok := message["tool_calls"].([]any)
require.True(t, ok)
require.Greater(t, len(toolCalls), toolCallIndex)
toolCall, ok := toolCalls[toolCallIndex].(map[string]any)
require.True(t, ok)
extraContent, _ := toolCall["extra_content"].(map[string]any)
google, _ := extraContent["google"].(map[string]any)
signature, _ := google["thought_signature"].(string)
return signature
}
@@ -0,0 +1,186 @@
package chatprovider_test
import (
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"charm.land/fantasy"
fantasyopenaicompat "charm.land/fantasy/providers/openaicompat"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
)
const dummyThoughtSignature = "skip_thought_signature_validator"
func TestModelFromConfig_GeminiOpenAICompatThoughtSignatures(t *testing.T) {
t.Parallel()
t.Run("Gemini endpoint receives current turn thought signature", func(t *testing.T) {
t.Parallel()
body := generateOpenAICompatRequest(t, "https://generativelanguage.googleapis.com/v1beta/openai/", "gemini-3.5-flash")
messages := body["messages"].([]any)
require.Empty(t, thoughtSignature(t, messages[1], 0))
require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[4], 0))
require.Empty(t, thoughtSignature(t, messages[4], 1))
require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[6], 0))
})
t.Run("Coder AI Bridge Gemini route receives current turn thought signature", func(t *testing.T) {
t.Parallel()
body := generateOpenAICompatRequest(t, "http://coder-aibridge/v1", "gemini-3.5-flash")
messages := body["messages"].([]any)
require.Equal(t, dummyThoughtSignature, thoughtSignature(t, messages[4], 0))
})
t.Run("Vercel OpenAI-compatible Gemini route is unchanged", func(t *testing.T) {
t.Parallel()
body := generateOpenAICompatRequest(t, "https://gateway.vercel.ai/v1", "google/gemini-3.5-flash")
messages := body["messages"].([]any)
require.Empty(t, thoughtSignature(t, messages[4], 0))
})
}
func generateOpenAICompatRequest(t *testing.T, baseURL string, modelID string) map[string]any {
t.Helper()
transport := &captureChatCompletionTransport{}
model, err := chatprovider.ModelFromConfig(
fantasyopenaicompat.Name,
modelID,
chatprovider.ProviderAPIKeys{
ByProvider: map[string]string{
fantasyopenaicompat.Name: "test-key",
},
BaseURLByProvider: map[string]string{
fantasyopenaicompat.Name: baseURL,
},
},
chatprovider.UserAgent(),
nil,
&http.Client{Transport: transport},
)
require.NoError(t, err)
_, err = model.Generate(t.Context(), fantasy.Call{
Prompt: geminiOpenAICompatToolPrompt(),
})
require.NoError(t, err)
require.NotNil(t, transport.body)
return transport.body
}
type captureChatCompletionTransport struct {
body map[string]any
}
func (ct *captureChatCompletionTransport) RoundTrip(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
_ = req.Body.Close()
if strings.HasSuffix(req.URL.Path, "/chat/completions") {
ct.body = map[string]any{}
if err := json.Unmarshal(body, &ct.body); err != nil {
return nil, err
}
}
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
},
Body: io.NopCloser(strings.NewReader(`{
"id":"chatcmpl-test",
"object":"chat.completion",
"created":0,
"model":"gemini-3.5-flash",
"choices":[{"index":0,"message":{"role":"assistant","content":"done"},"finish_reason":"stop"}],
"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}
}`)),
}, nil
}
func geminiOpenAICompatToolPrompt() []fantasy.Message {
return []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "previous turn"},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.ToolCallPart{ToolCallID: "previous-call", ToolName: "previous_tool", Input: `{}`},
},
},
{
Role: fantasy.MessageRoleTool,
Content: []fantasy.MessagePart{
fantasy.ToolResultPart{
ToolCallID: "previous-call",
Output: fantasy.ToolResultOutputContentText{Text: `{}`},
},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "current turn"},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.ToolCallPart{ToolCallID: "current-call-a", ToolName: "first_tool", Input: `{}`},
fantasy.ToolCallPart{ToolCallID: "current-call-b", ToolName: "parallel_tool", Input: `{}`},
},
},
{
Role: fantasy.MessageRoleTool,
Content: []fantasy.MessagePart{
fantasy.ToolResultPart{
ToolCallID: "current-call-a",
Output: fantasy.ToolResultOutputContentText{Text: `{}`},
},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.ToolCallPart{
ToolCallID: "current-call-c",
ToolName: "second_step_tool",
Input: `{}`,
},
},
},
}
}
func thoughtSignature(t *testing.T, rawMessage any, toolCallIndex int) string {
t.Helper()
message, ok := rawMessage.(map[string]any)
require.True(t, ok)
toolCalls, ok := message["tool_calls"].([]any)
require.True(t, ok)
require.Greater(t, len(toolCalls), toolCallIndex)
toolCall, ok := toolCalls[toolCallIndex].(map[string]any)
require.True(t, ok)
extraContent, _ := toolCall["extra_content"].(map[string]any)
google, _ := extraContent["google"].(map[string]any)
signature, _ := google["thought_signature"].(string)
return signature
}
+115 -3
View File
@@ -3,11 +3,14 @@ package chatd
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"charm.land/fantasy"
fantasyopenaicompat "charm.land/fantasy/providers/openaicompat"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
@@ -667,6 +670,100 @@ func TestFallbackTurnStatusLabel(t *testing.T) {
}
}
func TestGenerateStructuredTitleWithUsage_OpenAICompatibleRequiredToolChoice(t *testing.T) {
t.Parallel()
server, requests := newOpenAICompatStructuredOutputServer(t, "propose_title", `{"title":"Failed workspace logs"}`)
model := openAICompatTestModel(t, server.URL)
title, _, err := generateStructuredTitleWithUsage(
t.Context(),
model,
titleGenerationPrompt,
"summarize failed workspace build logs",
)
require.NoError(t, err)
require.Equal(t, "Failed workspace logs", title)
body := testutil.TryReceive(t.Context(), t, requests)
require.Equal(t, "required", body["tool_choice"])
}
func newOpenAICompatStructuredOutputServer(
t *testing.T,
toolName string,
arguments string,
) (*httptest.Server, <-chan map[string]any) {
t.Helper()
requests := make(chan map[string]any, 10)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
requests <- body
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"id": "chatcmpl-structured-output",
"object": "chat.completion",
"created": time.Now().Unix(),
"model": "anthropic/claude-4-5-sonnet",
"choices": []map[string]any{
{
"index": 0,
"message": map[string]any{
"role": "assistant",
"content": "",
"tool_calls": []map[string]any{
{
"id": "call_structured_output",
"type": "function",
"function": map[string]any{
"name": toolName,
"arguments": arguments,
},
},
},
},
"finish_reason": "tool_calls",
},
},
"usage": map[string]any{
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
})
}))
t.Cleanup(server.Close)
return server, requests
}
func openAICompatTestModel(t *testing.T, baseURL string) fantasy.LanguageModel {
t.Helper()
model, err := chatprovider.ModelFromConfig(
fantasyopenaicompat.Name,
"anthropic/claude-4-5-sonnet",
chatprovider.ProviderAPIKeys{
ByProvider: map[string]string{
fantasyopenaicompat.Name: "test-key",
},
BaseURLByProvider: map[string]string{
fantasyopenaicompat.Name: baseURL,
},
},
chatprovider.UserAgent(),
nil,
nil,
)
require.NoError(t, err)
return model
}
func TestGenerateStructuredTurnStatusLabel(t *testing.T) {
t.Parallel()
@@ -682,11 +779,26 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) {
},
}
label, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, "done")
label, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done")
require.NoError(t, err)
require.Equal(t, "Submitted PR", label)
})
t.Run("sends required tool_choice to openai-compatible provider", func(t *testing.T) {
t.Parallel()
server, requests := newOpenAICompatStructuredOutputServer(t, "propose_turn_status_label", `{"label":"Submitted PR"}`)
model := openAICompatTestModel(t, server.URL)
label, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done")
require.NoError(t, err)
require.Equal(t, "Submitted PR", label)
require.Len(t, requests, 1)
body := testutil.TryReceive(t.Context(), t, requests)
require.Equal(t, "required", body["tool_choice"])
})
t.Run("rejects narrative label", func(t *testing.T) {
t.Parallel()
@@ -698,7 +810,7 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) {
},
}
_, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, "done")
_, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, "done")
require.ErrorContains(t, err, "generated turn status label was invalid")
})
@@ -706,7 +818,7 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{}
_, err := generateStructuredTurnStatusLabel(context.Background(), model, turnStatusLabelPrompt, " ")
_, err := generateStructuredTurnStatusLabel(t.Context(), model, turnStatusLabelPrompt, " ")
require.ErrorContains(t, err, "turn status label input was empty")
})
}