diff --git a/coderd/chatd/chattest/openai.go b/coderd/chatd/chattest/openai.go index 6354e5caa5..6abdfefd5e 100644 --- a/coderd/chatd/chattest/openai.go +++ b/coderd/chatd/chattest/openai.go @@ -3,6 +3,7 @@ package chattest import ( "encoding/json" "fmt" + "log" "net/http" "net/http/httptest" "sync" @@ -183,7 +184,7 @@ func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req * http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError) return case hasStreaming: - s.writeChatCompletionsStreaming(w, resp.StreamingChunks) + writeChatCompletionsStreaming(w, req.Request, resp.StreamingChunks) default: s.writeChatCompletionsNonStreaming(w, resp.Response) } @@ -212,14 +213,13 @@ func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *Ope http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError) return case hasStreaming: - s.writeResponsesAPIStreaming(w, resp.StreamingChunks) + writeResponsesAPIStreaming(w, req.Request, resp.StreamingChunks) default: s.writeResponsesAPINonStreaming(w, resp.Response) } } -func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) { - _ = s // receiver unused but kept for consistency +func writeChatCompletionsStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") @@ -231,7 +231,21 @@ func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chun return } - for chunk := range chunks { + for { + var chunk OpenAIChunk + var ok bool + select { + case <-r.Context().Done(): + log.Printf("writeChatCompletionsStreaming: request context canceled, stopping stream") + return + case chunk, ok = <-chunks: + if !ok { + _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") + flusher.Flush() + return + } + } + choicesData := make([]map[string]interface{}, len(chunk.Choices)) for i, choice := range chunk.Choices { choiceData := map[string]interface{}{ @@ -278,13 +292,9 @@ func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chun } flusher.Flush() } - - _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() } -func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) { - _ = s // receiver unused but kept for consistency +func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") @@ -298,7 +308,21 @@ func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks itemIDs := make(map[int]string) - for chunk := range chunks { + for { + var chunk OpenAIChunk + var ok bool + select { + case <-r.Context().Done(): + log.Printf("writeResponsesAPIStreaming: request context canceled, stopping stream") + return + case chunk, ok = <-chunks: + if !ok { + _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") + flusher.Flush() + return + } + } + // Responses API sends one event per choice for outputIndex, choice := range chunk.Choices { if choice.Index != 0 { @@ -331,9 +355,6 @@ func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks flusher.Flush() } } - - _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() } func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {