fix(chatd): prevent chat re-acquisition during server shutdown (#22497)

Fixes https://github.com/coder/internal/issues/1371

## Problem

`TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica` flakes
intermittently in CI. The observed failure is that the chat never
reaches `pending` status after `serverA.Close()`.

## Root cause

Race between context cancellation and the mock OpenAI server's stream
completion marker.

When `Close()` cancels the server context, the in-flight HTTP streaming
request is canceled. The mock server's handler detects this via
`req.Context().Done()` and closes its chunks channel. The mock's
`writeChatCompletionsStreaming` then writes `data: [DONE]` — the SSE
completion marker. On a loopback connection, this marker can reach the
client **before** the client's HTTP transport honors the context
cancellation.

When this happens:
1. The client sees a successful stream completion (not an error)
2. `chatloop.Run` returns `nil`
3. `processChat` falls through without error → status stays `waiting`
(the default)
4. The test expects `pending` → **flake**

## Fix

Skip writing the `[DONE]` marker when the request context is already
canceled, in both `writeChatCompletionsStreaming` and
`writeResponsesAPIStreaming`.
This commit is contained in:
Kyle Carberry
2026-03-02 13:00:21 -05:00
committed by GitHub
parent 49aefdd973
commit a33ca95df2
+35 -14
View File
@@ -3,6 +3,7 @@ package chattest
import (
"encoding/json"
"fmt"
"log"
"net/http"
"net/http/httptest"
"sync"
@@ -183,7 +184,7 @@ func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
s.writeChatCompletionsStreaming(w, resp.StreamingChunks)
writeChatCompletionsStreaming(w, req.Request, resp.StreamingChunks)
default:
s.writeChatCompletionsNonStreaming(w, resp.Response)
}
@@ -212,14 +213,13 @@ func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *Ope
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
s.writeResponsesAPIStreaming(w, resp.StreamingChunks)
writeResponsesAPIStreaming(w, req.Request, resp.StreamingChunks)
default:
s.writeResponsesAPINonStreaming(w, resp.Response)
}
}
func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) {
_ = s // receiver unused but kept for consistency
func writeChatCompletionsStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
@@ -231,7 +231,21 @@ func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chun
return
}
for chunk := range chunks {
for {
var chunk OpenAIChunk
var ok bool
select {
case <-r.Context().Done():
log.Printf("writeChatCompletionsStreaming: request context canceled, stopping stream")
return
case chunk, ok = <-chunks:
if !ok {
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
return
}
}
choicesData := make([]map[string]interface{}, len(chunk.Choices))
for i, choice := range chunk.Choices {
choiceData := map[string]interface{}{
@@ -278,13 +292,9 @@ func (s *openAIServer) writeChatCompletionsStreaming(w http.ResponseWriter, chun
}
flusher.Flush()
}
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
}
func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks <-chan OpenAIChunk) {
_ = s // receiver unused but kept for consistency
func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
@@ -298,7 +308,21 @@ func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks
itemIDs := make(map[int]string)
for chunk := range chunks {
for {
var chunk OpenAIChunk
var ok bool
select {
case <-r.Context().Done():
log.Printf("writeResponsesAPIStreaming: request context canceled, stopping stream")
return
case chunk, ok = <-chunks:
if !ok {
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
return
}
}
// Responses API sends one event per choice
for outputIndex, choice := range chunk.Choices {
if choice.Index != 0 {
@@ -331,9 +355,6 @@ func (s *openAIServer) writeResponsesAPIStreaming(w http.ResponseWriter, chunks
flusher.Flush()
}
}
_, _ = fmt.Fprintf(w, "data: [DONE]\n\n")
flusher.Flush()
}
func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {