diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 8c8e20abe2..1f88828c6a 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -4244,6 +4244,25 @@ func shouldClearRetryPhaseForStatus(status codersdk.ChatStatus) bool { } } +func (p *Server) clearProvisionalStreamParts(chatID uuid.UUID) { + val, ok := p.chatStreams.Load(chatID) + if !ok { + return + } + rs, ok := val.(*chatStreamState) + if !ok { + return + } + + // Streamed parts are provisional until a durable message commits + // them. A retry rolls back the failed attempt before replacement + // parts are streamed. + rs.mu.Lock() + rs.buffer = nil + rs.resetDropCounters() + rs.mu.Unlock() +} + func (p *Server) publishToStream(chatID uuid.UUID, event codersdk.ChatStreamEvent) { state := p.getOrCreateStreamState(chatID) state.mu.Lock() @@ -7796,14 +7815,7 @@ func (p *Server) runChat( classified chatretry.ClassifiedError, delay time.Duration, ) { - if val, ok := p.chatStreams.Load(chat.ID); ok { - if rs, ok := val.(*chatStreamState); ok { - rs.mu.Lock() - rs.buffer = nil - rs.resetDropCounters() - rs.mu.Unlock() - } - } + p.clearProvisionalStreamParts(chat.ID) logger.Warn(ctx, "retrying LLM stream", slog.F("attempt", attempt), slog.F("delay", delay.String()), diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index bbe63b8613..45a2940882 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -2332,6 +2332,51 @@ func TestSubscribeDoesNotReplayRetryAfterStreamResumes(t *testing.T) { requireNoStreamEvent(t, events, 200*time.Millisecond) } +func TestSubscribeDoesNotReplayFailedAttemptPartsAfterRetry(t *testing.T) { + t.Parallel() + + ctx, cancelCtx := context.WithCancel(context.Background()) + defer cancelCtx() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + chatID := uuid.New() + chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} + + gomock.InOrder( + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), + db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + AfterID: 0, + }).Return(nil, nil), + db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), + ) + + server := newBufferedSubscribeTestServer(t, db, chatID) + + server.publishMessagePart(chatID, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("failed partial")) + server.clearProvisionalStreamParts(chatID) + server.publishRetry(chatID, newTestRetryPayload()) + server.publishMessagePart(chatID, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("retry recovered")) + + snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) + require.True(t, ok) + defer cancel() + + requireNoSnapshotRetryEvent(t, snapshot) + partEvent := requireSnapshotMessagePartEvent(t, snapshot) + require.Equal(t, "retry recovered", partEvent.MessagePart.Part.Text) + for _, event := range snapshot { + if event.Type != codersdk.ChatStreamEventTypeMessagePart { + continue + } + require.NotEqual(t, "failed partial", event.MessagePart.Part.Text) + } + requireNoStreamEvent(t, events, 200*time.Millisecond) +} + func TestSubscribeDoesNotReplayRetryAfterTerminalError(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chaterror/classify.go b/coderd/x/chatd/chaterror/classify.go index 926b058a3b..c02cf71f1b 100644 --- a/coderd/x/chatd/chaterror/classify.go +++ b/coderd/x/chatd/chaterror/classify.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "golang.org/x/net/http2" + "github.com/coder/coder/v2/codersdk" ) @@ -33,6 +35,10 @@ type ClassifiedError struct { ChainBroken bool } +// http2PeerResetCause mirrors golang.org/x/net/http2's unexported +// errFromPeer message. +const http2PeerResetCause = "received from peer" + const responsesAPIDiagnosticMessage = "The chat continuation failed due to an " + "internal state mismatch. This is not a configuration or billing issue." @@ -188,15 +194,22 @@ func Classify(err error) ClassifiedError { return classified } + retryableHTTP2StreamReset, hasHTTP2StreamReset := classifyHTTP2StreamReset(err) deadline := errors.Is(err, context.DeadlineExceeded) || strings.Contains(lower, "context deadline exceeded") overloadedMatch := statusCode == 529 || containsAny(lower, overloadedPatterns...) authStrong := statusCode == 401 || containsAny(lower, authStrongPatterns...) configMatch := containsAny(lower, configPatterns...) authWeak := statusCode == 403 || containsAny(lower, authWeakPatterns...) rateLimitMatch := statusCode == 429 || containsAny(lower, rateLimitPatterns...) + timeoutPatternMatch := containsAny(lower, timeoutPatterns...) + if hasHTTP2StreamReset && !retryableHTTP2StreamReset { + // A typed HTTP/2 stream error gives us the reset code. Trust it + // over broader string fallbacks so protocol bugs do not retry. + timeoutPatternMatch = false + } timeoutMatch := deadline || statusCode == 408 || statusCode == 502 || statusCode == 503 || statusCode == 504 || - containsAny(lower, timeoutPatterns...) + retryableHTTP2StreamReset || timeoutPatternMatch genericRetryableMatch := statusCode == 500 || containsAny(lower, genericRetryablePatterns...) // Config signals should beat ambiguous wrapper signals so @@ -269,6 +282,46 @@ func Classify(err error) ClassifiedError { }) } +func classifyHTTP2StreamReset(err error) (retryable bool, found bool) { + streamErr, ok := findHTTP2StreamError(err) + if !ok { + return false, false + } + if !isPeerHTTP2StreamError(streamErr) { + return false, true + } + return isRetryableHTTP2StreamCode(streamErr.Code), true +} + +func findHTTP2StreamError(err error) (http2.StreamError, bool) { + var streamErr http2.StreamError + if errors.As(err, &streamErr) { + return streamErr, true + } + var streamErrPtr *http2.StreamError + if errors.As(err, &streamErrPtr) && streamErrPtr != nil { + return *streamErrPtr, true + } + return http2.StreamError{}, false +} + +func isPeerHTTP2StreamError(streamErr http2.StreamError) bool { + return streamErr.Cause != nil && streamErr.Cause.Error() == http2PeerResetCause +} + +func isRetryableHTTP2StreamCode(code http2.ErrCode) bool { + switch code { + case http2.ErrCodeNo, + http2.ErrCodeInternal, + http2.ErrCodeRefusedStream, + http2.ErrCodeCancel, + http2.ErrCodeEnhanceYourCalm: + return true + default: + return false + } +} + func streamIncompleteClassification( lowerMessage string, provider string, diff --git a/coderd/x/chatd/chaterror/classify_test.go b/coderd/x/chatd/chaterror/classify_test.go index a599e158ae..6fd036cae5 100644 --- a/coderd/x/chatd/chaterror/classify_test.go +++ b/coderd/x/chatd/chaterror/classify_test.go @@ -10,6 +10,7 @@ import ( "charm.land/fantasy" "github.com/stretchr/testify/require" + "golang.org/x/net/http2" "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/x/chatd/chaterror" @@ -332,6 +333,11 @@ func TestClassify_PatternCoverage(t *testing.T) { {name: "GOAWAYLiteral", err: "goaway", wantKind: codersdk.ChatErrorKindTimeout, wantRetry: true}, {name: "HTTP2StreamClosedLiteral", err: "http2: stream closed", wantKind: codersdk.ChatErrorKindTimeout, wantRetry: true}, {name: "UseOfClosedNetworkConnectionLiteral", err: "use of closed network connection", wantKind: codersdk.ChatErrorKindTimeout, wantRetry: true}, + {name: "HTTP2InternalErrorReceivedFromPeerLiteral", err: "internal_error; received from peer", wantKind: codersdk.ChatErrorKindTimeout, wantRetry: true}, + {name: "HTTP2RefusedStreamReceivedFromPeerLiteral", err: "refused_stream; received from peer", wantKind: codersdk.ChatErrorKindTimeout, wantRetry: true}, + {name: "HTTP2CancelReceivedFromPeerLiteral", err: "cancel; received from peer", wantKind: codersdk.ChatErrorKindTimeout, wantRetry: true}, + {name: "HTTP2EnhanceYourCalmReceivedFromPeerLiteral", err: "enhance_your_calm; received from peer", wantKind: codersdk.ChatErrorKindTimeout, wantRetry: true}, + {name: "HTTP2NoErrorReceivedFromPeerLiteral", err: "no_error; received from peer", wantKind: codersdk.ChatErrorKindTimeout, wantRetry: true}, {name: "AuthenticationLiteral", err: "authentication", wantKind: codersdk.ChatErrorKindAuth, wantRetry: false}, {name: "UnauthorizedLiteral", err: "unauthorized", wantKind: codersdk.ChatErrorKindAuth, wantRetry: false}, {name: "InvalidAPIKeyLiteral", err: "invalid api key", wantKind: codersdk.ChatErrorKindAuth, wantRetry: false}, @@ -429,6 +435,10 @@ func TestClassify_HTTP2TransportErrors(t *testing.T) { name: "HTTP2StreamClosed", err: "http2: stream closed", }, + { + name: "HTTP2PeerInternalStreamReset", + err: "stream error: stream ID 455; INTERNAL_ERROR; received from peer", + }, { name: "UseOfClosedNetworkConnectionOnPOST", err: `Post "https://example.com/v1/messages": use of closed network connection`, @@ -487,6 +497,12 @@ func TestClassify_HTTP2TransportErrors(t *testing.T) { provider: "openai", wantMessage: "OpenAI is temporarily unavailable.", }, + { + name: "AnthropicPeerInternalStreamReset", + err: `stream response: Post "https://api.anthropic.com/v1/messages": stream error: stream ID 455; INTERNAL_ERROR; received from peer`, + provider: "anthropic", + wantMessage: "Anthropic is temporarily unavailable.", + }, { name: "GoogleGOAWAY", err: `stream response: Post "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent": http2: server sent GOAWAY and closed the connection`, @@ -508,6 +524,208 @@ func TestClassify_HTTP2TransportErrors(t *testing.T) { } } +func TestClassify_HTTP2StreamErrorValues(t *testing.T) { + t.Parallel() + + peerReset := func(code http2.ErrCode) http2.StreamError { + return http2.StreamError{ + StreamID: 455, + Code: code, + Cause: xerrors.New("received from peer"), + } + } + + retryable := []struct { + name string + err error + want chaterror.ClassifiedError + }{ + { + name: "Internal", + err: peerReset(http2.ErrCodeInternal), + want: chaterror.ClassifiedError{ + Message: "The AI provider is temporarily unavailable.", + Kind: codersdk.ChatErrorKindTimeout, + Retryable: true, + }, + }, + { + name: "RefusedStream", + err: peerReset(http2.ErrCodeRefusedStream), + want: chaterror.ClassifiedError{ + Message: "The AI provider is temporarily unavailable.", + Kind: codersdk.ChatErrorKindTimeout, + Retryable: true, + }, + }, + { + name: "CancelPointer", + err: &http2.StreamError{ + StreamID: 455, + Code: http2.ErrCodeCancel, + Cause: xerrors.New("received from peer"), + }, + want: chaterror.ClassifiedError{ + Message: "The AI provider is temporarily unavailable.", + Kind: codersdk.ChatErrorKindTimeout, + Retryable: true, + }, + }, + { + name: "EnhanceYourCalm", + err: peerReset(http2.ErrCodeEnhanceYourCalm), + want: chaterror.ClassifiedError{ + Message: "The AI provider is temporarily unavailable.", + Kind: codersdk.ChatErrorKindTimeout, + Retryable: true, + }, + }, + { + name: "NoError", + err: peerReset(http2.ErrCodeNo), + want: chaterror.ClassifiedError{ + Message: "The AI provider is temporarily unavailable.", + Kind: codersdk.ChatErrorKindTimeout, + Retryable: true, + }, + }, + } + + for _, tt := range retryable { + t.Run("Retryable/"+tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, chaterror.Classify(tt.err)) + }) + } + + localNonRetryable := []struct { + name string + err error + }{ + { + name: "CancelWithoutPeerCause", + err: http2.StreamError{ + StreamID: 455, + Code: http2.ErrCodeCancel, + }, + }, + { + name: "InternalWithLocalCause", + err: http2.StreamError{ + StreamID: 455, + Code: http2.ErrCodeInternal, + Cause: xerrors.New("local transport reset"), + }, + }, + } + for _, tt := range localNonRetryable { + t.Run("NonRetryable/"+tt.name, func(t *testing.T) { + t.Parallel() + classified := chaterror.Classify(tt.err) + require.Equal(t, codersdk.ChatErrorKindGeneric, classified.Kind) + require.False(t, classified.Retryable) + }) + } + + nonRetryable := []struct { + name string + code http2.ErrCode + }{ + {name: "Protocol", code: http2.ErrCodeProtocol}, + {name: "FlowControl", code: http2.ErrCodeFlowControl}, + {name: "FrameSize", code: http2.ErrCodeFrameSize}, + {name: "Compression", code: http2.ErrCodeCompression}, + } + for _, tt := range nonRetryable { + t.Run("NonRetryable/"+tt.name, func(t *testing.T) { + t.Parallel() + classified := chaterror.Classify(peerReset(tt.code)) + require.Equal(t, codersdk.ChatErrorKindGeneric, classified.Kind) + require.False(t, classified.Retryable) + }) + } +} + +func TestClassify_HTTP2StreamIDDoesNotBecomeStatusCode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want chaterror.ClassifiedError + }{ + { + name: "RetryableInternalWithAuthLikeStreamID", + err: http2.StreamError{ + StreamID: 401, + Code: http2.ErrCodeInternal, + Cause: xerrors.New("received from peer"), + }, + want: chaterror.ClassifiedError{ + Message: "The AI provider is temporarily unavailable.", + Kind: codersdk.ChatErrorKindTimeout, + Retryable: true, + }, + }, + { + name: "NonRetryableProtocolWithTimeoutLikeStreamID", + err: http2.StreamError{ + StreamID: 503, + Code: http2.ErrCodeProtocol, + Cause: xerrors.New("received from peer"), + }, + want: chaterror.ClassifiedError{ + Message: "The chat request failed unexpectedly.", + Kind: codersdk.ChatErrorKindGeneric, + }, + }, + { + name: "StringFallbackInternalWithAuthLikeStreamID", + err: xerrors.New("stream error: stream ID 401; INTERNAL_ERROR; received from peer"), + want: chaterror.ClassifiedError{ + Message: "The AI provider is temporarily unavailable.", + Kind: codersdk.ChatErrorKindTimeout, + Retryable: true, + }, + }, + { + name: "StringProtocolWithTimeoutLikeStreamID", + err: xerrors.New("stream error: stream ID 503; PROTOCOL_ERROR; received from peer"), + want: chaterror.ClassifiedError{ + Message: "The chat request failed unexpectedly.", + Kind: codersdk.ChatErrorKindGeneric, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, chaterror.Classify(tt.err)) + }) + } +} + +func TestClassify_StatusCodeBeatsTypedHTTP2StreamError(t *testing.T) { + t.Parallel() + + err := xerrors.Errorf( + "provider returned status 401: %w", + http2.StreamError{ + StreamID: 455, + Code: http2.ErrCodeInternal, + Cause: xerrors.New("received from peer"), + }, + ) + + require.Equal(t, chaterror.ClassifiedError{ + Message: "Authentication with the AI provider failed. Check the API key, permissions, and billing settings.", + Kind: codersdk.ChatErrorKindAuth, + Retryable: false, + StatusCode: 401, + }, chaterror.Classify(err)) +} + // TestClassify_StatusCodeBeatsHTTP2Transport ensures explicit status // codes still win over the new HTTP/2 patterns. func TestClassify_StatusCodeBeatsHTTP2Transport(t *testing.T) { diff --git a/coderd/x/chatd/chaterror/signals.go b/coderd/x/chatd/chaterror/signals.go index 9f91e97e06..f23d86307e 100644 --- a/coderd/x/chatd/chaterror/signals.go +++ b/coderd/x/chatd/chaterror/signals.go @@ -48,6 +48,14 @@ var ( "goaway", "http2: stream closed", "use of closed network connection", + // Stringified HTTP/2 RST_STREAM errors. Classify uses + // typed http2.StreamError values when they survive wrapping; + // these patterns cover bridge layers that flatten errors. + "internal_error; received from peer", + "refused_stream; received from peer", + "cancel; received from peer", + "enhance_your_calm; received from peer", + "no_error; received from peer", } authStrongPatterns = []string{ "authentication", @@ -83,9 +91,7 @@ func extractStatusCode(lower string) int { return 0 } for _, loc := range standaloneStatusPattern.FindAllStringIndex(lower, -1) { - // Skip values in host:port text. A later standalone status code in the - // same message may still be valid, so keep scanning. - if loc[0] > 0 && lower[loc[0]-1] == ':' { + if shouldSkipStandaloneStatusMatch(lower, loc[0]) { continue } if code, err := strconv.Atoi(lower[loc[0]:loc[1]]); err == nil { @@ -96,6 +102,21 @@ func extractStatusCode(lower string) int { return 0 } +func shouldSkipStandaloneStatusMatch(lower string, start int) bool { + // Skip values in host:port text. A later standalone status code in the + // same message may still be valid, so keep scanning. + if start > 0 && lower[start-1] == ':' { + return true + } + + // Go's HTTP/2 stream reset errors include "stream ID N". Those IDs are + // not HTTP status codes, even when they happen to equal 401, 429, or 503. + prefix := strings.TrimRight(lower[:start], " \t\r\n") + prefix = strings.TrimRight(prefix, ":=") + prefix = strings.TrimRight(prefix, " \t\r\n") + return strings.HasSuffix(prefix, "stream id") +} + func detectProvider(lower string) string { for _, hint := range providerHints { if containsAny(lower, hint.patterns...) { diff --git a/coderd/x/chatd/chaterror/signals_test.go b/coderd/x/chatd/chaterror/signals_test.go index 799bc1033d..4d79ded548 100644 --- a/coderd/x/chatd/chaterror/signals_test.go +++ b/coderd/x/chatd/chaterror/signals_test.go @@ -27,6 +27,9 @@ func TestExtractStatusCode(t *testing.T) { {name: "PortNumberHostIsNotStatus", input: "proxy.internal:502 unreachable", want: 0}, {name: "PortNumberDialIsNotStatus", input: "dial tcp 172.16.0.5:429: refused", want: 0}, {name: "PortThenRealStatusReturnsRealStatus", input: "proxy at 10.0.0.1:500 returned 503", want: 503}, + {name: "HTTP2StreamIDIsNotStatus", input: "stream error: stream ID 401; INTERNAL_ERROR; received from peer", want: 0}, + {name: "HTTP2StreamIDWithPunctuationIsNotStatus", input: "stream error: stream ID: 503; PROTOCOL_ERROR; received from peer", want: 0}, + {name: "HTTP2StreamIDThenExplicitStatusReturnsStatus", input: "stream error: stream ID 455; status 503 from upstream", want: 503}, {name: "NoFabricatedOverloadStatus", input: "anthropic overloaded_error", want: 0}, {name: "NoFabricatedRateLimitStatus", input: "too many requests", want: 0}, {name: "NoFabricatedBadGatewayStatus", input: "bad gateway", want: 0},