mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
72689c2552
- Use t.Errorf in chattest non-streaming helpers so encoding failures fail the test - Thread testing.TB into writeResponsesAPIStreaming and log SSE write errors instead of silently dropping them - Bump createworkspace DB error log from Warn to Error - Use errors.Join for timeout + output error in execute.go
560 lines
16 KiB
Go
560 lines
16 KiB
Go
package chattest
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/openai/openai-go/v3/responses"
|
|
)
|
|
|
|
// OpenAIHandler handles OpenAI API requests and returns a response.
|
|
type OpenAIHandler func(req *OpenAIRequest) OpenAIResponse
|
|
|
|
// OpenAIResponse represents a response to an OpenAI request.
|
|
// Either StreamingChunks or Response should be set, not both.
|
|
type OpenAIResponse struct {
|
|
StreamingChunks <-chan OpenAIChunk
|
|
Response *OpenAICompletion
|
|
Error *ErrorResponse // If set, server returns this HTTP error instead of streaming/JSON.
|
|
}
|
|
|
|
// OpenAIRequest represents an OpenAI chat completion request.
|
|
type OpenAIRequest struct {
|
|
*http.Request
|
|
Model string `json:"model"`
|
|
Messages []OpenAIMessage `json:"messages"`
|
|
Stream bool `json:"stream,omitempty"`
|
|
Tools []OpenAITool `json:"tools,omitempty"`
|
|
Prompt []interface{} `json:"prompt,omitempty"` // For responses API
|
|
// TODO: encoding/json ignores inline tags. Add custom UnmarshalJSON to capture unknown keys.
|
|
Options map[string]interface{} `json:",inline"` //nolint:revive
|
|
}
|
|
|
|
// OpenAIMessage is a single chat message: the speaker's role paired
// with the message text. It is used both in incoming requests and in
// non-streaming completion choices.
type OpenAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
|
|
|
// OpenAIToolFunction is the "function" object nested inside a tool
// definition. Only the function name is captured.
type OpenAIToolFunction struct {
	Name string `json:"name"`
}
|
|
|
|
// OpenAITool represents a tool definition in an OpenAI request.
|
|
type OpenAITool struct {
|
|
Type string `json:"type"`
|
|
Function OpenAIToolFunction `json:"function"`
|
|
}
|
|
|
|
// OpenAIToolCallFunction carries the function name and its arguments
// (as a string) for a tool call. Both fields are omitted from the JSON
// encoding when empty.
type OpenAIToolCallFunction struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}
|
|
|
|
// OpenAIToolCall represents a tool call in a streaming chunk or completion.
|
|
type OpenAIToolCall struct {
|
|
ID string `json:"id,omitempty"`
|
|
Type string `json:"type,omitempty"`
|
|
Function OpenAIToolCallFunction `json:"function,omitempty"`
|
|
Index int `json:"index,omitempty"` // For streaming deltas
|
|
}
|
|
|
|
// OpenAIChunkChoice represents a choice in a streaming chunk.
|
|
type OpenAIChunkChoice struct {
|
|
Index int `json:"index"`
|
|
Delta string `json:"delta,omitempty"`
|
|
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
|
|
FinishReason string `json:"finish_reason,omitempty"`
|
|
}
|
|
|
|
// OpenAIChunk represents a streaming chunk from OpenAI.
|
|
type OpenAIChunk struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
Choices []OpenAIChunkChoice `json:"choices"`
|
|
}
|
|
|
|
// OpenAICompletionChoice represents a choice in a completion response.
|
|
type OpenAICompletionChoice struct {
|
|
Index int `json:"index"`
|
|
Message OpenAIMessage `json:"message"`
|
|
ToolCalls []OpenAIToolCall `json:"tool_calls,omitempty"`
|
|
FinishReason string `json:"finish_reason"`
|
|
}
|
|
|
|
// OpenAICompletionUsage reports the token counts attached to a
// completion response.
type OpenAICompletionUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
|
|
|
// OpenAICompletion represents a non-streaming OpenAI completion response.
|
|
type OpenAICompletion struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
Choices []OpenAICompletionChoice `json:"choices"`
|
|
Usage OpenAICompletionUsage `json:"usage"`
|
|
}
|
|
|
|
// openAIServer is a test server that mocks the OpenAI API.
|
|
type openAIServer struct {
|
|
mu sync.Mutex
|
|
t testing.TB
|
|
server *httptest.Server
|
|
handler OpenAIHandler
|
|
request *OpenAIRequest
|
|
}
|
|
|
|
// NewOpenAI creates a new OpenAI test server with a handler function.
|
|
// The handler is called for each request and should return either a streaming
|
|
// response (via channel) or a non-streaming response.
|
|
// Returns the base URL of the server.
|
|
func NewOpenAI(t testing.TB, handler OpenAIHandler) string {
|
|
t.Helper()
|
|
|
|
s := &openAIServer{
|
|
t: t,
|
|
handler: handler,
|
|
}
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("POST /chat/completions", s.handleChatCompletions)
|
|
mux.HandleFunc("POST /responses", s.handleResponses)
|
|
|
|
s.server = httptest.NewServer(mux)
|
|
|
|
t.Cleanup(func() {
|
|
s.server.Close()
|
|
})
|
|
|
|
return s.server.URL
|
|
}
|
|
|
|
func (s *openAIServer) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|
var req OpenAIRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
req.Request = r
|
|
|
|
s.mu.Lock()
|
|
s.request = &req
|
|
s.mu.Unlock()
|
|
|
|
resp := s.handler(&req)
|
|
s.writeChatCompletionsResponse(w, &req, resp)
|
|
}
|
|
|
|
func (s *openAIServer) handleResponses(w http.ResponseWriter, r *http.Request) {
|
|
var req OpenAIRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
req.Request = r
|
|
|
|
s.mu.Lock()
|
|
s.request = &req
|
|
s.mu.Unlock()
|
|
|
|
resp := s.handler(&req)
|
|
s.writeResponsesAPIResponse(w, &req, resp)
|
|
}
|
|
|
|
func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
|
|
if resp.Error != nil {
|
|
writeErrorResponse(s.t, w, resp.Error)
|
|
return
|
|
}
|
|
|
|
hasStreaming := resp.StreamingChunks != nil
|
|
hasNonStreaming := resp.Response != nil
|
|
|
|
switch {
|
|
case hasStreaming && hasNonStreaming:
|
|
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
|
|
return
|
|
case !hasStreaming && !hasNonStreaming:
|
|
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
|
|
return
|
|
case req.Stream && !hasStreaming:
|
|
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
|
|
return
|
|
case !req.Stream && !hasNonStreaming:
|
|
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
|
|
return
|
|
case hasStreaming:
|
|
writeChatCompletionsStreaming(w, req.Request, resp.StreamingChunks)
|
|
default:
|
|
s.writeChatCompletionsNonStreaming(w, resp.Response)
|
|
}
|
|
}
|
|
|
|
func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
|
|
if resp.Error != nil {
|
|
writeErrorResponse(s.t, w, resp.Error)
|
|
return
|
|
}
|
|
|
|
hasStreaming := resp.StreamingChunks != nil
|
|
hasNonStreaming := resp.Response != nil
|
|
|
|
switch {
|
|
case hasStreaming && hasNonStreaming:
|
|
http.Error(w, "handler returned both streaming and non-streaming responses", http.StatusInternalServerError)
|
|
return
|
|
case !hasStreaming && !hasNonStreaming:
|
|
http.Error(w, "handler returned empty response", http.StatusInternalServerError)
|
|
return
|
|
case req.Stream && !hasStreaming:
|
|
http.Error(w, "handler returned non-streaming response for streaming request", http.StatusInternalServerError)
|
|
return
|
|
case !req.Stream && !hasNonStreaming:
|
|
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
|
|
return
|
|
case hasStreaming:
|
|
writeResponsesAPIStreaming(s.t, w, req.Request, resp.StreamingChunks)
|
|
default:
|
|
s.writeResponsesAPINonStreaming(w, resp.Response)
|
|
}
|
|
}
|
|
|
|
func writeChatCompletionsStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
for {
|
|
var chunk OpenAIChunk
|
|
var ok bool
|
|
select {
|
|
case <-r.Context().Done():
|
|
log.Printf("writeChatCompletionsStreaming: request context canceled, stopping stream")
|
|
return
|
|
case chunk, ok = <-chunks:
|
|
if !ok {
|
|
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
flusher.Flush()
|
|
return
|
|
}
|
|
}
|
|
|
|
choicesData := make([]map[string]interface{}, len(chunk.Choices))
|
|
for i, choice := range chunk.Choices {
|
|
choiceData := map[string]interface{}{
|
|
"index": choice.Index,
|
|
}
|
|
if choice.Delta != "" {
|
|
choiceData["delta"] = map[string]interface{}{
|
|
"content": choice.Delta,
|
|
}
|
|
}
|
|
if len(choice.ToolCalls) > 0 {
|
|
// Tool calls come in the delta
|
|
if choiceData["delta"] == nil {
|
|
choiceData["delta"] = make(map[string]interface{})
|
|
}
|
|
delta, ok := choiceData["delta"].(map[string]interface{})
|
|
if !ok {
|
|
delta = make(map[string]interface{})
|
|
choiceData["delta"] = delta
|
|
}
|
|
delta["tool_calls"] = choice.ToolCalls
|
|
}
|
|
if choice.FinishReason != "" {
|
|
choiceData["finish_reason"] = choice.FinishReason
|
|
}
|
|
choicesData[i] = choiceData
|
|
}
|
|
|
|
chunkData := map[string]interface{}{
|
|
"id": chunk.ID,
|
|
"object": chunk.Object,
|
|
"created": chunk.Created,
|
|
"model": chunk.Model,
|
|
"choices": choicesData,
|
|
}
|
|
|
|
chunkBytes, err := json.Marshal(chunkData)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
|
|
// writeSSEEvent encodes v as JSON and emits it to w as a single
// server-sent-event frame of the form "data: <json>\n\n". It returns
// the marshal error, if any, otherwise the result of the write. Nothing
// is written when marshaling fails.
func writeSSEEvent(w http.ResponseWriter, v interface{}) error {
	payload, marshalErr := json.Marshal(v)
	if marshalErr != nil {
		return marshalErr
	}
	if _, writeErr := fmt.Fprintf(w, "data: %s\n\n", payload); writeErr != nil {
		return writeErr
	}
	return nil
}
|
|
|
|
func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
itemIDs := make(map[int]string)
|
|
|
|
for {
|
|
var chunk OpenAIChunk
|
|
var ok bool
|
|
select {
|
|
case <-r.Context().Done():
|
|
log.Printf("writeResponsesAPIStreaming: request context canceled, stopping stream")
|
|
return
|
|
case chunk, ok = <-chunks:
|
|
if !ok {
|
|
// Emit Responses API lifecycle events so
|
|
// the fantasy client closes open text
|
|
// blocks and persists the step content.
|
|
for outputIndex, itemID := range itemIDs {
|
|
if err := writeSSEEvent(w, responses.ResponseTextDoneEvent{
|
|
ItemID: itemID,
|
|
OutputIndex: int64(outputIndex),
|
|
}); err != nil {
|
|
t.Logf("writeResponsesAPIStreaming: failed to write ResponseTextDoneEvent: %v", err)
|
|
return
|
|
}
|
|
if err := writeSSEEvent(w, responses.ResponseOutputItemDoneEvent{
|
|
OutputIndex: int64(outputIndex),
|
|
Item: responses.ResponseOutputItemUnion{
|
|
ID: itemID,
|
|
Type: "message",
|
|
},
|
|
}); err != nil {
|
|
t.Logf("writeResponsesAPIStreaming: failed to write ResponseOutputItemDoneEvent: %v", err)
|
|
return
|
|
}
|
|
}
|
|
if err := writeSSEEvent(w, responses.ResponseCompletedEvent{}); err != nil {
|
|
t.Logf("writeResponsesAPIStreaming: failed to write ResponseCompletedEvent: %v", err)
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
return
|
|
}
|
|
}
|
|
|
|
// Responses API sends one event per choice
|
|
for outputIndex, choice := range chunk.Choices {
|
|
if choice.Index != 0 {
|
|
outputIndex = choice.Index
|
|
}
|
|
itemID, found := itemIDs[outputIndex]
|
|
if !found {
|
|
itemID = fmt.Sprintf("msg_%s", uuid.New().String()[:8])
|
|
itemIDs[outputIndex] = itemID
|
|
|
|
// Emit response.output_item.added so the
|
|
// fantasy client triggers TextStart.
|
|
if err := writeSSEEvent(w, responses.ResponseOutputItemAddedEvent{
|
|
OutputIndex: int64(outputIndex),
|
|
Item: responses.ResponseOutputItemUnion{
|
|
ID: itemID,
|
|
Type: "message",
|
|
},
|
|
}); err != nil {
|
|
t.Logf("writeResponsesAPIStreaming: failed to write ResponseOutputItemAddedEvent: %v", err)
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
}
|
|
|
|
chunkData := map[string]interface{}{
|
|
"type": "response.output_text.delta",
|
|
"item_id": itemID,
|
|
"output_index": outputIndex,
|
|
"created": chunk.Created,
|
|
"model": chunk.Model,
|
|
"content_index": 0,
|
|
"delta": choice.Delta,
|
|
}
|
|
|
|
chunkBytes, err := json.Marshal(chunkData)
|
|
if err != nil {
|
|
t.Logf("writeResponsesAPIStreaming: failed to marshal chunk data: %v", err)
|
|
return
|
|
}
|
|
|
|
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
|
|
t.Logf("writeResponsesAPIStreaming: failed to write chunk data: %v", err)
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
|
s.t.Errorf("writeChatCompletionsNonStreaming: failed to encode response: %v", err)
|
|
}
|
|
}
|
|
|
|
func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
|
|
// Convert all choices to output format
|
|
outputs := make([]map[string]interface{}, len(resp.Choices))
|
|
for i, choice := range resp.Choices {
|
|
outputs[i] = map[string]interface{}{
|
|
"id": uuid.New().String(),
|
|
"type": "message",
|
|
"role": "assistant",
|
|
"content": []map[string]interface{}{
|
|
{
|
|
"type": "output_text",
|
|
"text": choice.Message.Content,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
response := map[string]interface{}{
|
|
"id": resp.ID,
|
|
"object": "response",
|
|
"created": resp.Created,
|
|
"model": resp.Model,
|
|
"output": outputs,
|
|
"usage": resp.Usage,
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
s.t.Errorf("writeResponsesAPINonStreaming: failed to encode response: %v", err)
|
|
}
|
|
}
|
|
|
|
// OpenAIStreamingResponse creates a streaming response from chunks.
|
|
func OpenAIStreamingResponse(chunks ...OpenAIChunk) OpenAIResponse {
|
|
ch := make(chan OpenAIChunk, len(chunks))
|
|
go func() {
|
|
for _, chunk := range chunks {
|
|
ch <- chunk
|
|
}
|
|
close(ch)
|
|
}()
|
|
return OpenAIResponse{StreamingChunks: ch}
|
|
}
|
|
|
|
// OpenAINonStreamingResponse creates a non-streaming response with the given text.
|
|
func OpenAINonStreamingResponse(text string) OpenAIResponse {
|
|
return OpenAIResponse{
|
|
Response: &OpenAICompletion{
|
|
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
|
|
Object: "chat.completion",
|
|
Created: time.Now().Unix(),
|
|
Model: "gpt-4",
|
|
Choices: []OpenAICompletionChoice{
|
|
{
|
|
Index: 0,
|
|
Message: OpenAIMessage{
|
|
Role: "assistant",
|
|
Content: text,
|
|
},
|
|
FinishReason: "stop",
|
|
},
|
|
},
|
|
Usage: OpenAICompletionUsage{
|
|
PromptTokens: 10,
|
|
CompletionTokens: 5,
|
|
TotalTokens: 15,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// OpenAITextChunks creates streaming chunks with text deltas.
|
|
// Each delta string becomes a separate chunk with a single choice.
|
|
// Returns a slice of chunks, one per delta, with each choice having its index (0, 1, 2, ...).
|
|
func OpenAITextChunks(deltas ...string) []OpenAIChunk {
|
|
if len(deltas) == 0 {
|
|
return nil
|
|
}
|
|
|
|
chunkID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8])
|
|
now := time.Now().Unix()
|
|
chunks := make([]OpenAIChunk, len(deltas))
|
|
|
|
for i, delta := range deltas {
|
|
chunks[i] = OpenAIChunk{
|
|
ID: chunkID,
|
|
Object: "chat.completion.chunk",
|
|
Created: now,
|
|
Model: "gpt-4",
|
|
Choices: []OpenAIChunkChoice{
|
|
{
|
|
Index: i,
|
|
Delta: delta,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
return chunks
|
|
}
|
|
|
|
// OpenAIToolCallChunk creates a streaming chunk with a tool call.
|
|
// Takes the tool name and arguments JSON string, creates a tool call for choice index 0.
|
|
func OpenAIToolCallChunk(toolName, arguments string) OpenAIChunk {
|
|
return OpenAIChunk{
|
|
ID: fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:8]),
|
|
Object: "chat.completion.chunk",
|
|
Created: time.Now().Unix(),
|
|
Model: "gpt-4",
|
|
Choices: []OpenAIChunkChoice{
|
|
{
|
|
Index: 0,
|
|
ToolCalls: []OpenAIToolCall{
|
|
{
|
|
Index: 0,
|
|
ID: fmt.Sprintf("call_%s", uuid.New().String()[:8]),
|
|
Type: "function",
|
|
Function: OpenAIToolCallFunction{
|
|
Name: toolName,
|
|
Arguments: arguments,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|