diff --git a/coderd/chatd/chatd.go b/coderd/chatd/chatd.go index 76c447cec9..566cbd0675 100644 --- a/coderd/chatd/chatd.go +++ b/coderd/chatd/chatd.go @@ -1020,8 +1020,30 @@ func (p *Server) start(ctx context.Context) { } func (p *Server) processOnce(ctx context.Context) { - // Try to acquire a pending chat. - chat, err := p.db.AcquireChat(ctx, database.AcquireChatParams{ + // Bail out early if the server is shutting down. The main + // loop's select can randomly pick the ticker over ctx.Done(), + // so we must guard against acquiring a chat we cannot process. + if ctx.Err() != nil { + return + } + + // Try to acquire a pending chat. We detach from the server + // lifetime to prevent a phantom-acquire race: when the server + // context is canceled, the pq driver's watchCancel goroutine + // races with the actual query on the wire. The UPDATE can + // commit in Postgres (setting the chat to "running") before + // the cancel request arrives via a second TCP connection, yet + // the Go driver still returns context.Canceled to the caller + // because the awaitDone goroutine in database/sql closes the + // Rows before Scan reads them. This leaves the chat stuck as + // "running" with no goroutine to process it. Using a context + // that cannot be canceled ensures the driver sees the query + // result if Postgres executed it. + acquireCtx, acquireCancel := context.WithTimeout( + context.WithoutCancel(ctx), 10*time.Second, + ) + defer acquireCancel() + chat, err := p.db.AcquireChat(acquireCtx, database.AcquireChatParams{ StartedAt: time.Now(), WorkerID: p.workerID, }) @@ -1033,6 +1055,29 @@ func (p *Server) processOnce(ctx context.Context) { return } + // If the server context was canceled while we were acquiring, + // release the chat back to pending immediately so another + // replica can pick it up. + if ctx.Err() != nil { + releaseCtx, releaseCancel := context.WithTimeout( + context.WithoutCancel(ctx), 10*time.Second, + ) + defer releaseCancel() + _, updateErr := p.db.UpdateChatStatus(releaseCtx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusPending, + WorkerID: uuid.NullUUID{}, + StartedAt: sql.NullTime{}, + HeartbeatAt: sql.NullTime{}, + LastError: sql.NullString{}, + }) + if updateErr != nil { + p.logger.Error(ctx, "failed to release chat acquired during shutdown", + slog.F("chat_id", chat.ID), slog.Error(updateErr)) + } + return + } + // Process the chat (don't block the main loop). p.inflight.Add(1) go func() { diff --git a/coderd/chatd/chatd_test.go b/coderd/chatd/chatd_test.go index 7fac515ecd..de93e366a9 100644 --- a/coderd/chatd/chatd_test.go +++ b/coderd/chatd/chatd_test.go @@ -1571,6 +1571,12 @@ func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) var requestCount atomic.Int32 streamStarted := make(chan struct{}) openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + // Ignore non-streaming requests (e.g. title generation) so + // they don't interfere with the request counter used to + // coordinate the streaming chat flow. + if !req.Stream { + return chattest.OpenAINonStreamingResponse("shutdown-retry") + } if requestCount.Add(1) == 1 { chunks := make(chan chattest.OpenAIChunk, 1) go func() {