From a33ca95df2148dfeac3d8d024785112e302df5b9 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Mon, 2 Mar 2026 13:00:21 -0500 Subject: [PATCH] fix(chatd): prevent chat re-acquisition during server shutdown (#22497) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes https://github.com/coder/internal/issues/1371 ## Problem `TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica` flakes intermittently in CI. The observed failure is that the chat never reaches `pending` status after `serverA.Close()`. ## Root cause There is a race between context cancellation and the mock OpenAI server's stream completion marker. When `Close()` cancels the server context, the in-flight HTTP streaming request is canceled. The mock server's handler detects this via `req.Context().Done()` and closes its chunks channel. The mock's `writeChatCompletionsStreaming` then writes `data: [DONE]` — the SSE completion marker. On a loopback connection, this marker can reach the client **before** the client's HTTP transport honors the context cancellation. When this happens: 1. The client sees a successful stream completion (not an error) 2. `chatloop.Run` returns `nil` 3. `processChat` falls through without error → status stays `waiting` (the default) 4. The test expects `pending` → **flake** ## Fix Skip writing the `[DONE]` marker when the request context is already canceled, in both `writeChatCompletionsStreaming` and `writeResponsesAPIStreaming`. Both streaming loops now `select` on `r.Context().Done()` alongside the chunks channel and return without writing the marker when the cancellation case is selected. 
--- coderd/chatd/chattest/openai.go | 49 +++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/coderd/chatd/chattest/openai.go b/coderd/chatd/chattest/openai.go index 6354e5caa5..6abdfefd5e 100644 --- a/coderd/chatd/chattest/openai.go +++ b/coderd/chatd/chattest/openai.go @@ -3,6 +3,7 @@ package chattest import ( "encoding/json" "fmt" + "log" "net/http" "net/http/httptest" "sync" @@ -183,7 +184,7 @@ func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req * http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError) return case hasStreaming: - s.writeChatCompletionsStreaming(w, resp.StreamingChunks) + writeChatCompletionsStreaming(w, req.Request, resp.StreamingChunks) default: s.writeChatCompletionsNonStreaming(w, resp.Response) } @@ -212,14 +213,13 @@ func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *Ope http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError) return case hasStreaming: - s.writeResponsesAPIStreaming(w, resp.StreamingChunks) + writeResponsesAPIStreaming(w, req.Request, resp.StreamingChunks) default: s.writeResponsesAPINonStreaming(w, resp.Response) } } -func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) { - _ = s // receiver unused but kept for consistency +func writeChatCompletionsStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") @@ -231,7 +231,21 @@ func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chun return } - for chunk := range chunks { + for { + var chunk OpenAIChunk + var ok bool + select { + case <-r.Context().Done(): + log.Printf("writeChatCompletionsStreaming: request context canceled, stopping 
stream") + return + case chunk, ok = <-chunks: + if !ok { + _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") + flusher.Flush() + return + } + } + choicesData := make([]map[string]interface{}, len(chunk.Choices)) for i, choice := range chunk.Choices { choiceData := map[string]interface{}{ @@ -278,13 +292,9 @@ func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chun } flusher.Flush() } - - _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() } -func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) { - _ = s // receiver unused but kept for consistency +func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") @@ -298,7 +308,21 @@ func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks itemIDs := make(map[int]string) - for chunk := range chunks { + for { + var chunk OpenAIChunk + var ok bool + select { + case <-r.Context().Done(): + log.Printf("writeResponsesAPIStreaming: request context canceled, stopping stream") + return + case chunk, ok = <-chunks: + if !ok { + _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") + flusher.Flush() + return + } + } + // Responses API sends one event per choice for outputIndex, choice := range chunk.Choices { if choice.Index != 0 { @@ -331,9 +355,6 @@ func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks flusher.Flush() } } - - _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") - flusher.Flush() } func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {