diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index f44556cd72..f34566454a 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1294,7 +1294,8 @@ CREATE TABLE chat_messages ( content_version smallint NOT NULL, total_cost_micros bigint, runtime_ms bigint, - deleted boolean DEFAULT false NOT NULL + deleted boolean DEFAULT false NOT NULL, + provider_response_id text ); CREATE SEQUENCE chat_messages_id_seq diff --git a/coderd/database/migrations/000450_chat_messages_provider_response_id.down.sql b/coderd/database/migrations/000450_chat_messages_provider_response_id.down.sql new file mode 100644 index 0000000000..177afb1a81 --- /dev/null +++ b/coderd/database/migrations/000450_chat_messages_provider_response_id.down.sql @@ -0,0 +1 @@ +ALTER TABLE chat_messages DROP COLUMN provider_response_id; diff --git a/coderd/database/migrations/000450_chat_messages_provider_response_id.up.sql b/coderd/database/migrations/000450_chat_messages_provider_response_id.up.sql new file mode 100644 index 0000000000..707a12735b --- /dev/null +++ b/coderd/database/migrations/000450_chat_messages_provider_response_id.up.sql @@ -0,0 +1 @@ +ALTER TABLE chat_messages ADD COLUMN provider_response_id TEXT; diff --git a/coderd/database/models.go b/coderd/database/models.go index 4b7feb6b02..c7c558650a 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4229,6 +4229,7 @@ type ChatMessage struct { TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"` RuntimeMs sql.NullInt64 `db:"runtime_ms" json:"runtime_ms"` Deleted bool `db:"deleted" json:"deleted"` + ProviderResponseID sql.NullString `db:"provider_response_id" json:"provider_response_id"` } type ChatModelConfig struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index eb8b72bb6a..3543d1f8e7 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4623,7 +4623,7 @@ func (q *sqlQuerier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds [ const getChatMessageByID = `-- name: GetChatMessageByID :one SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id FROM chat_messages WHERE @@ -4655,13 +4655,14 @@ func (q *sqlQuerier) GetChatMessageByID(ctx context.Context, id int64) (ChatMess &i.TotalCostMicros, &i.RuntimeMs, &i.Deleted, + &i.ProviderResponseID, ) return i, err } const getChatMessagesByChatID = `-- name: GetChatMessagesByChatID :many SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id FROM chat_messages WHERE @@ -4708,6 +4709,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes &i.TotalCostMicros, &i.RuntimeMs, &i.Deleted, + &i.ProviderResponseID, ); err != nil { return nil, err } @@ -4724,7 +4726,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes const getChatMessagesByChatIDDescPaginated = `-- name: GetChatMessagesByChatIDDescPaginated :many SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id FROM chat_messages WHERE @@ -4777,6 +4779,7 @@ func (q *sqlQuerier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, a &i.TotalCostMicros, &i.RuntimeMs, &i.Deleted, + &i.ProviderResponseID, ); err != nil { return nil, err } @@ -4809,7 +4812,7 @@ WITH latest_compressed_summary AS ( 1 ) SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id FROM chat_messages WHERE @@ -4880,6 +4883,7 @@ func (q *sqlQuerier) GetChatMessagesForPromptByChatID(ctx context.Context, chatI &i.TotalCostMicros, &i.RuntimeMs, &i.Deleted, + &i.ProviderResponseID, ); err != nil { return nil, err } @@ -5085,7 +5089,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id FROM chat_messages WHERE @@ -5127,6 +5131,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh &i.TotalCostMicros, &i.RuntimeMs, &i.Deleted, + &i.ProviderResponseID, ) return i, err } @@ -5339,7 +5344,8 @@ INSERT INTO chat_messages ( context_limit, compressed, total_cost_micros, - runtime_ms + runtime_ms, + provider_response_id ) SELECT $1::uuid, @@ -5358,9 +5364,10 @@ SELECT NULLIF(UNNEST($14::bigint[]), 0), UNNEST($15::boolean[]), NULLIF(UNNEST($16::bigint[]), 0), - NULLIF(UNNEST($17::bigint[]), 0) + NULLIF(UNNEST($17::bigint[]), 0), + NULLIF(UNNEST($18::text[]), '') RETURNING - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id ` type InsertChatMessagesParams struct { @@ -5381,6 +5388,7 @@ type InsertChatMessagesParams struct { Compressed []bool `db:"compressed" json:"compressed"` TotalCostMicros []int64 `db:"total_cost_micros" json:"total_cost_micros"` RuntimeMs []int64 `db:"runtime_ms" json:"runtime_ms"` + ProviderResponseID []string `db:"provider_response_id" json:"provider_response_id"` } func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error) { @@ -5402,6 +5410,7 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa pq.Array(arg.Compressed), pq.Array(arg.TotalCostMicros), pq.Array(arg.RuntimeMs), + pq.Array(arg.ProviderResponseID), ) if err != nil { return nil, err @@ -5431,6 +5440,7 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa &i.TotalCostMicros, &i.RuntimeMs, &i.Deleted, + &i.ProviderResponseID, ); err != nil { return nil, err } @@ -5789,7 +5799,7 @@ SET WHERE id = $3::bigint RETURNING - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id ` type UpdateChatMessageByIDParams struct { @@ -5822,6 +5832,7 @@ func (q *sqlQuerier) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMe &i.TotalCostMicros, &i.RuntimeMs, &i.Deleted, + &i.ProviderResponseID, ) return i, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 4a8af3921f..440eeddb70 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -241,7 +241,8 @@ INSERT INTO chat_messages ( context_limit, compressed, total_cost_micros, - runtime_ms + runtime_ms, + provider_response_id ) SELECT @chat_id::uuid, @@ -260,7 +261,8 @@ SELECT NULLIF(UNNEST(@context_limit::bigint[]), 0), UNNEST(@compressed::boolean[]), NULLIF(UNNEST(@total_cost_micros::bigint[]), 0), - NULLIF(UNNEST(@runtime_ms::bigint[]), 0) + NULLIF(UNNEST(@runtime_ms::bigint[]), 0), + NULLIF(UNNEST(@provider_response_id::text[]), '') RETURNING *; diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index f689a54652..ecb7589808 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -1200,6 +1200,7 @@ type chatMessage struct { contextLimit int64 totalCostMicros int64 runtimeMs int64 + providerResponseID string } func newChatMessage( @@ -1256,6 +1257,101 @@ func (m chatMessage) withRuntimeMs(ms int64) chatMessage { return m } +func (m chatMessage) withProviderResponseID(id string) chatMessage { + m.providerResponseID = id + return m +} + +// chainModeInfo holds the information needed to determine whether +// a follow-up turn can use OpenAI's previous_response_id chaining +// instead of replaying full conversation history. +type chainModeInfo struct { + // previousResponseID is the provider response ID from the last + // assistant message, if any. + previousResponseID string + // modelConfigID is the model configuration used to produce the + // assistant message referenced by previousResponseID. + modelConfigID uuid.UUID + // trailingUserCount is the number of contiguous user messages + // at the end of the conversation that form the current turn. + trailingUserCount int +} + +// resolveChainMode scans DB messages from the end to count trailing user +// messages for the current turn and detect whether the immediately +// preceding assistant/tool block can chain from a provider response ID. +func resolveChainMode(messages []database.ChatMessage) chainModeInfo { + var info chainModeInfo + i := len(messages) - 1 + for ; i >= 0; i-- { + if messages[i].Role == database.ChatMessageRoleUser { + info.trailingUserCount++ + continue + } + break + } + for ; i >= 0; i-- { + switch messages[i].Role { + case database.ChatMessageRoleAssistant: + if messages[i].ProviderResponseID.Valid && + messages[i].ProviderResponseID.String != "" { + info.previousResponseID = messages[i].ProviderResponseID.String + if messages[i].ModelConfigID.Valid { + info.modelConfigID = messages[i].ModelConfigID.UUID + } + return info + } + return info + case database.ChatMessageRoleTool: + continue + default: + return info + } + } + return info +} + +// filterPromptForChainMode keeps only system messages and the last +// trailingUserCount user messages from the prompt. Assistant and tool +// messages are dropped because the provider already has them via the +// previous_response_id chain. +func filterPromptForChainMode( + prompt []fantasy.Message, + trailingUserCount int, +) []fantasy.Message { + if trailingUserCount <= 0 { + return prompt + } + + totalUsers := 0 + for _, msg := range prompt { + if msg.Role == "user" { + totalUsers++ + } + } + + usersToSkip := totalUsers - trailingUserCount + if usersToSkip < 0 { + usersToSkip = 0 + } + + filtered := make([]fantasy.Message, 0, len(prompt)) + usersSeen := 0 + for _, msg := range prompt { + switch msg.Role { + case "system": + filtered = append(filtered, msg) + case "user": + usersSeen++ + if usersSeen > usersToSkip { + filtered = append(filtered, msg) + } + } + } + + return filtered +} + // appendChatMessage appends a single message to the batch insert params. func appendChatMessage( params *database.InsertChatMessagesParams, @@ -1277,6 +1373,7 @@ func appendChatMessage( params.Compressed = append(params.Compressed, msg.compressed) params.TotalCostMicros = append(params.TotalCostMicros, msg.totalCostMicros) params.RuntimeMs = append(params.RuntimeMs, msg.runtimeMs) + params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID) } func insertUserMessageAndSetPending( @@ -2824,6 +2921,7 @@ func (p *Server) runChat( if err := g.Wait(); err != nil { return result, err } + chainInfo := resolveChainMode(messages) result.PushSummaryModel = model result.ProviderKeys = providerKeys // Fire title generation asynchronously so it doesn't block the @@ -3093,7 +3191,8 @@ func (p *Server) runChat( reasoningTokens, cacheCreationTokens, cacheReadTokens, ).withContextLimit(contextLimit). withTotalCostMicros(totalCostVal). - withRuntimeMs(runtimeMs)) + withRuntimeMs(runtimeMs). + withProviderResponseID(step.ProviderResponseID)) } for _, resultContent := range toolResultContents { @@ -3294,13 +3393,35 @@ func (p *Server) runChat( ), }) } + + providerOptions := chatprovider.ProviderOptionsFromChatModelConfig( + model, + callConfig.ProviderOptions, + ) + // When the OpenAI Responses API has store=true, the provider + // retains conversation history server-side. For follow-up turns, + // we set previous_response_id and send only system instructions + // plus the new user input, avoiding redundant replay of prior + // assistant and tool messages that the provider already has. + chainModeActive := chatprovider.IsResponsesStoreEnabled(providerOptions) && + chainInfo.previousResponseID != "" && + chainInfo.trailingUserCount > 0 && + chainInfo.modelConfigID == modelConfig.ID + if chainModeActive { + providerOptions = chatprovider.CloneWithPreviousResponseID( + providerOptions, + chainInfo.previousResponseID, + ) + prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount) + } + err = chatloop.Run(ctx, chatloop.RunOptions{ Model: model, Messages: prompt, Tools: tools, MaxSteps: maxChatSteps, ModelConfig: callConfig, - ProviderOptions: chatprovider.ProviderOptionsFromChatModelConfig(model, callConfig.ProviderOptions), + ProviderOptions: providerOptions, ProviderTools: providerTools, ContextLimitFallback: modelConfigContextLimit, @@ -3348,8 +3469,17 @@ func (p *Server) runChat( if reloadUserPrompt != "" { reloadedPrompt = chatprompt.InsertSystem(reloadedPrompt, reloadUserPrompt) } + if chainModeActive { + reloadedPrompt = filterPromptForChainMode( + reloadedPrompt, + chainInfo.trailingUserCount, + ) + } return reloadedPrompt, nil }, + DisableChainMode: func() { + chainModeActive = false + }, OnRetry: func(attempt int, retryErr error, delay time.Duration) { if val, ok := p.chatStreams.Load(chat.ID); ok { diff --git a/coderd/x/chatd/chatloop/chatloop.go b/coderd/x/chatd/chatloop/chatloop.go index 38f7237ec0..1203e354eb 100644 --- a/coderd/x/chatd/chatloop/chatloop.go +++ b/coderd/x/chatd/chatloop/chatloop.go @@ -13,6 +13,7 @@ import ( "charm.land/fantasy" fantasyanthropic "charm.land/fantasy/providers/anthropic" + fantasyopenai "charm.land/fantasy/providers/openai" "charm.land/fantasy/schema" "golang.org/x/xerrors" @@ -39,9 +40,10 @@ var ErrInterrupted = xerrors.New("chat interrupted") // persistence layer is responsible for splitting these into // separate database messages by role. type PersistedStep struct { - Content []fantasy.Content - Usage fantasy.Usage - ContextLimit sql.NullInt64 + Content []fantasy.Content + Usage fantasy.Usage + ContextLimit sql.NullInt64 + ProviderResponseID string // Runtime is the wall-clock duration of this step, // covering LLM streaming, tool execution, and retries. // Zero indicates the duration was not measured (e.g. @@ -80,8 +82,9 @@ type RunOptions struct { role codersdk.ChatMessageRole, part codersdk.ChatMessagePart, ) - Compaction *CompactionOptions - ReloadMessages func(context.Context) ([]fantasy.Message, error) + Compaction *CompactionOptions + ReloadMessages func(context.Context) ([]fantasy.Message, error) + DisableChainMode func() // OnRetry is called before each retry attempt when the LLM // stream fails with a retryable error. It provides the attempt @@ -245,6 +248,18 @@ func Run(ctx context.Context, opts RunOptions) error { messages := opts.Messages var lastUsage fantasy.Usage var lastProviderMetadata fantasy.ProviderMetadata + needsFullHistoryReload := false + reloadFullHistory := func(stage string) error { + if opts.ReloadMessages == nil { + return nil + } + reloaded, err := opts.ReloadMessages(ctx) + if err != nil { + return xerrors.Errorf("reload messages %s: %w", stage, err) + } + messages = reloaded + return nil + } totalSteps := 0 // When totalSteps reaches MaxSteps the inner loop exits immediately @@ -368,10 +383,11 @@ func Run(ctx context.Context, opts RunOptions) error { // check and here, fall back to the interrupt-safe // path so partial content is not lost. if err := opts.PersistStep(ctx, PersistedStep{ - Content: result.content, - Usage: result.usage, - ContextLimit: contextLimit, - Runtime: time.Since(stepStart), + Content: result.content, + Usage: result.usage, + ContextLimit: contextLimit, + ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata), + Runtime: time.Since(stepStart), }); err != nil { if errors.Is(err, ErrInterrupted) { persistInterruptedStep(ctx, opts, &result) @@ -382,14 +398,41 @@ func Run(ctx context.Context, opts RunOptions) error { lastUsage = result.usage lastProviderMetadata = result.providerMetadata - // Append the step's response messages so that both - // inline and post-loop compaction see the full - // conversation including the latest assistant reply. + // When chain mode is active (PreviousResponseID set), exit + // it after persisting the first chained step. Continuation + // steps include tool-result messages, which fantasy rejects + // when previous_response_id is set, so we must leave chain + // mode and reload the full history before the next call. stepMessages := result.toResponseMessages() - messages = append(messages, stepMessages...) + if hasPreviousResponseID(opts.ProviderOptions) { + clearPreviousResponseID(opts.ProviderOptions) + if opts.DisableChainMode != nil { + opts.DisableChainMode() + } + switch { + case opts.ReloadMessages != nil: + if err := reloadFullHistory("after chain mode exit"); err != nil { + return err + } + needsFullHistoryReload = false + default: + messages = append(messages, stepMessages...) + needsFullHistoryReload = false + } + } else { + messages = append(messages, stepMessages...) + } + + if needsFullHistoryReload && !result.shouldContinue && + opts.ReloadMessages != nil { + if err := reloadFullHistory("before final compaction after chain mode exit"); err != nil { + return err + } + needsFullHistoryReload = false + } // Inline compaction. - if opts.Compaction != nil && opts.ReloadMessages != nil { + if !needsFullHistoryReload && opts.Compaction != nil && opts.ReloadMessages != nil { did, compactErr := tryCompact( ctx, opts.Model, @@ -405,14 +448,11 @@ func Run(ctx context.Context, opts RunOptions) error { if did { alreadyCompacted = true compactedOnFinalStep = true - reloaded, reloadErr := opts.ReloadMessages(ctx) - if reloadErr != nil { - return xerrors.Errorf("reload messages after compaction: %w", reloadErr) + if err := reloadFullHistory("after compaction"); err != nil { + return err } - messages = reloaded } } - if !result.shouldContinue { stoppedByModel = true break @@ -423,9 +463,16 @@ func Run(ctx context.Context, opts RunOptions) error { compactedOnFinalStep = false } + if needsFullHistoryReload && stoppedByModel && opts.ReloadMessages != nil { + if err := reloadFullHistory("before post-run compaction after chain mode exit"); err != nil { + return err + } + needsFullHistoryReload = false + } + // Post-run compaction safety net: if we never compacted // during the loop, try once at the end. - if !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil { + if !needsFullHistoryReload && !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil { did, err := tryCompact( ctx, opts.Model, @@ -973,6 +1020,85 @@ func addAnthropicPromptCaching(messages []fantasy.Message) { } } +// hasPreviousResponseID checks whether the provider options contain +// an OpenAI Responses entry with a non-empty PreviousResponseID. +func hasPreviousResponseID(providerOptions fantasy.ProviderOptions) bool { + if providerOptions == nil { + return false + } + + for _, entry := range providerOptions { + if options, ok := entry.(*fantasyopenai.ResponsesProviderOptions); ok { + return options.PreviousResponseID != nil && + *options.PreviousResponseID != "" + } + } + + return false +} + +// clearPreviousResponseID removes PreviousResponseID from the OpenAI +// Responses provider options entry, if present. +func clearPreviousResponseID(providerOptions fantasy.ProviderOptions) { + if providerOptions == nil { + return + } + + for _, entry := range providerOptions { + if options, ok := entry.(*fantasyopenai.ResponsesProviderOptions); ok { + options.PreviousResponseID = nil + } + } +} + +// extractOpenAIResponseID extracts the OpenAI Responses API response +// ID from provider metadata. Returns an empty string if no OpenAI +// Responses metadata is present. +func extractOpenAIResponseID(metadata fantasy.ProviderMetadata) string { + if len(metadata) == 0 { + return "" + } + + for _, entry := range metadata { + if providerMetadata, ok := entry.(*fantasyopenai.ResponsesProviderMetadata); ok && providerMetadata != nil { + return providerMetadata.ResponseID + } + } + + return "" +} + +// extractOpenAIResponseIDIfStored returns the OpenAI response ID +// only when the provider options indicate store=true. Response IDs +// from store=false turns are not persisted server-side and cannot +// be used for chaining. +func extractOpenAIResponseIDIfStored( + providerOptions fantasy.ProviderOptions, + metadata fantasy.ProviderMetadata, +) string { + if !isResponsesStoreEnabled(providerOptions) { + return "" + } + + return extractOpenAIResponseID(metadata) +} + +// isResponsesStoreEnabled checks whether the OpenAI Responses +// provider options explicitly enable store=true. +func isResponsesStoreEnabled(providerOptions fantasy.ProviderOptions) bool { + if providerOptions == nil { + return false + } + + for _, entry := range providerOptions { + if options, ok := entry.(*fantasyopenai.ResponsesProviderOptions); ok { + return options.Store != nil && *options.Store + } + } + + return false +} + func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 { if len(metadata) == 0 { return sql.NullInt64{} diff --git a/coderd/x/chatd/chatprovider/chatprovider.go b/coderd/x/chatd/chatprovider/chatprovider.go index 892096c1a9..ee693e43ab 100644 --- a/coderd/x/chatd/chatprovider/chatprovider.go +++ b/coderd/x/chatd/chatprovider/chatprovider.go @@ -1063,6 +1063,46 @@ func ProviderOptionsFromChatModelConfig( return result } +// IsResponsesStoreEnabled checks if the OpenAI Responses provider +// options are present and have Store set to true. When true, the +// provider stores conversation history server-side, enabling +// follow-up chaining via PreviousResponseID. +func IsResponsesStoreEnabled(opts fantasy.ProviderOptions) bool { + if opts == nil { + return false + } + raw, ok := opts[fantasyopenai.Name] + if !ok { + return false + } + respOpts, ok := raw.(*fantasyopenai.ResponsesProviderOptions) + if !ok || respOpts == nil { + return false + } + return respOpts.Store != nil && *respOpts.Store +} + +// CloneWithPreviousResponseID shallow-clones the provider options +// map and the OpenAI Responses entry, setting PreviousResponseID +// on the clone. The original map and entry are not mutated. +func CloneWithPreviousResponseID( + opts fantasy.ProviderOptions, + previousResponseID string, +) fantasy.ProviderOptions { + cloned := make(fantasy.ProviderOptions, len(opts)) + for k, v := range opts { + cloned[k] = v + } + if raw, ok := cloned[fantasyopenai.Name]; ok { + if respOpts, ok := raw.(*fantasyopenai.ResponsesProviderOptions); ok && respOpts != nil { + clone := *respOpts + clone.PreviousResponseID = &previousResponseID + cloned[fantasyopenai.Name] = &clone + } + } + return cloned +} + func openAIProviderOptionsFromChatConfig( model fantasy.LanguageModel, options *codersdk.ChatModelOpenAIProviderOptions, diff --git a/coderd/x/chatd/chattest/openai.go b/coderd/x/chatd/chattest/openai.go index 6f19e08afe..1e99628dda 100644 --- a/coderd/x/chatd/chattest/openai.go +++ b/coderd/x/chatd/chattest/openai.go @@ -6,12 +6,12 @@ import ( "log" "net/http" "net/http/httptest" + "sort" "sync" "testing" "time" "github.com/google/uuid" - "github.com/openai/openai-go/v3/responses" ) // OpenAIHandler handles OpenAI API requests and returns a response. @@ -22,17 +22,37 @@ type OpenAIHandler func(req *OpenAIRequest) OpenAIResponse type OpenAIResponse struct { StreamingChunks <-chan OpenAIChunk Response *OpenAICompletion + Reasoning *OpenAIReasoningItem + WebSearch *OpenAIWebSearchCall + ResponseID string // If set, used as the response ID in streamed events; otherwise auto-generated. Error *ErrorResponse // If set, server returns this HTTP error instead of streaming/JSON. } +// OpenAIReasoningItem configures a streamed reasoning output item for the +// Responses API test server. +type OpenAIReasoningItem struct { + ID string `json:"id,omitempty"` + Summary string `json:"summary,omitempty"` + EncryptedContent string `json:"encrypted_content,omitempty"` +} + +// OpenAIWebSearchCall configures a streamed web_search_call output item for the +// Responses API test server. +type OpenAIWebSearchCall struct { + ID string `json:"id,omitempty"` + Query string `json:"query,omitempty"` +} + // OpenAIRequest represents an OpenAI chat completion request. type OpenAIRequest struct { *http.Request - Model string `json:"model"` - Messages []OpenAIMessage `json:"messages"` - Stream bool `json:"stream,omitempty"` - Tools []OpenAITool `json:"tools,omitempty"` - Prompt []interface{} `json:"prompt,omitempty"` // For responses API + Model string `json:"model"` + Messages []OpenAIMessage `json:"messages"` + Stream bool `json:"stream,omitempty"` + Tools []OpenAITool `json:"tools,omitempty"` + Prompt []interface{} `json:"prompt,omitempty"` // For responses API + Store *bool `json:"store,omitempty"` + PreviousResponseID *string `json:"previous_response_id,omitempty"` // TODO: encoding/json ignores inline tags. Add custom UnmarshalJSON to capture unknown keys. Options map[string]interface{} `json:",inline"` //nolint:revive } @@ -228,7 +248,7 @@ func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *Ope http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError) return case hasStreaming: - writeResponsesAPIStreaming(s.t, w, req.Request, resp.StreamingChunks) + writeResponsesAPIStreaming(s.t, w, req.Request, resp) default: s.writeResponsesAPINonStreaming(w, resp.Response) } @@ -309,18 +329,19 @@ func writeChatCompletionsStreaming(w http.ResponseWriter, r *http.Request, chunk } } -// writeSSEEvent marshals v as JSON and writes it as an SSE data -// frame. Returns any write error. -func writeSSEEvent(w http.ResponseWriter, v interface{}) error { +func writeNamedSSEEvent(w http.ResponseWriter, eventType string, v interface{}) error { data, err := json.Marshal(v) if err != nil { return err } + if _, err := fmt.Fprintf(w, "event: %s\n", eventType); err != nil { + return err + } _, err = fmt.Fprintf(w, "data: %s\n\n", data) return err } -func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) { +func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Request, resp OpenAIResponse) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") @@ -332,7 +353,170 @@ func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Req return } + responseID := resp.ResponseID + if responseID == "" { + responseID = fmt.Sprintf("resp_%s", uuid.New().String()[:8]) + } + responseModel := "gpt-4" + sequenceNumber := int64(0) + textOffset := 0 itemIDs := make(map[int]string) + itemTexts := make(map[int]string) + + writeEvent := func(eventType string, payload map[string]interface{}) bool { + payload["type"] = eventType + payload["sequence_number"] = sequenceNumber + sequenceNumber++ + if err := writeNamedSSEEvent(w, eventType, payload); err != nil { + t.Logf("writeResponsesAPIStreaming: failed to write %s: %v", eventType, err) + return false + } + flusher.Flush() + return true + } + + if !writeEvent("response.created", map[string]interface{}{ + "response": map[string]interface{}{ + "id": responseID, + "object": "response", + "model": responseModel, + "status": "in_progress", + "output": []interface{}{}, + }, + }) { + return + } + + if resp.Reasoning != nil { + outputIndex := textOffset + reasoningID := resp.Reasoning.ID + if reasoningID == "" { + reasoningID = fmt.Sprintf("rs_%s", uuid.New().String()[:8]) + } + summary := resp.Reasoning.Summary + encryptedContent := resp.Reasoning.EncryptedContent + if encryptedContent == "" { + encryptedContent = "encrypted_data_here" + } + + if !writeEvent("response.output_item.added", map[string]interface{}{ + "output_index": outputIndex, + "item": map[string]interface{}{ + "type": "reasoning", + "id": reasoningID, + "summary": []interface{}{}, + "encrypted_content": "", + }, + }) { + return + } + + if summary != "" { + if !writeEvent("response.reasoning_summary_part.added", map[string]interface{}{ + "item_id": reasoningID, + "output_index": outputIndex, + "summary_index": 0, + "part": map[string]interface{}{ + "type": "summary_text", + "text": "", + }, + }) { + return + } + if !writeEvent("response.reasoning_summary_text.added", map[string]interface{}{ + "item_id": reasoningID, + "output_index": outputIndex, + "summary_index": 0, + }) { + return + } + if !writeEvent("response.reasoning_summary_text.delta", map[string]interface{}{ + "item_id": reasoningID, + "output_index": outputIndex, + "summary_index": 0, + "delta": summary, + }) { + return + } + if !writeEvent("response.reasoning_summary_text.done", map[string]interface{}{ + "item_id": reasoningID, + "output_index": outputIndex, + "summary_index": 0, + "text": summary, + }) { + return + } + if !writeEvent("response.reasoning_summary_part.done", map[string]interface{}{ + "item_id": reasoningID, + "output_index": outputIndex, + "summary_index": 0, + "part": map[string]interface{}{ + "type": "summary_text", + "text": summary, + }, + }) { + return + } + } + + summaryItems := []interface{}{} + if summary != "" { + summaryItems = append(summaryItems, map[string]interface{}{ + "type": "summary_text", + "text": summary, + }) + } + if !writeEvent("response.output_item.done", map[string]interface{}{ + "output_index": outputIndex, + "item": map[string]interface{}{ + "type": "reasoning", + "id": reasoningID, + "summary": summaryItems, + "encrypted_content": encryptedContent, + }, + }) { + return + } + textOffset++ + } + + if resp.WebSearch != nil { + outputIndex := textOffset + itemID := resp.WebSearch.ID + if itemID == "" { + itemID = fmt.Sprintf("ws_%s", uuid.New().String()[:8]) + } + query := resp.WebSearch.Query + if query == "" { + query = "latest AI news" + } + + if !writeEvent("response.output_item.added", map[string]interface{}{ + "output_index": outputIndex, + "item": map[string]interface{}{ + "type": "web_search_call", + "id": itemID, + "status": "in_progress", + }, + }) { + return + } + if !writeEvent("response.output_item.done", map[string]interface{}{ + "output_index": outputIndex, + "item": map[string]interface{}{ + "type": "web_search_call", + "id": itemID, + "status": "completed", + "action": map[string]interface{}{ + "type": "search", + "query": query, + }, + }, + }) { + return + } + textOffset++ + } for { var chunk OpenAIChunk @@ -341,85 +525,117 @@ func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Req case <-r.Context().Done(): log.Printf("writeResponsesAPIStreaming: request context canceled, stopping stream") return - case chunk, ok = <-chunks: + case chunk, ok = <-resp.StreamingChunks: if !ok { - // Emit Responses API lifecycle events so - // the fantasy client closes open text - // blocks and persists the step content. - for outputIndex, itemID := range itemIDs { - if err := writeSSEEvent(w, responses.ResponseTextDoneEvent{ - ItemID: itemID, - OutputIndex: int64(outputIndex), - }); err != nil { - t.Logf("writeResponsesAPIStreaming: failed to write ResponseTextDoneEvent: %v", err) + indices := make([]int, 0, len(itemIDs)) + for outputIndex := range itemIDs { + indices = append(indices, outputIndex) + } + sort.Ints(indices) + for _, outputIndex := range indices { + itemID := itemIDs[outputIndex] + text := itemTexts[outputIndex] + if !writeEvent("response.output_text.done", map[string]interface{}{ + "item_id": itemID, + "output_index": outputIndex, + "content_index": 0, + "text": text, + "logprobs": []interface{}{}, + }) { return } - if err := writeSSEEvent(w, responses.ResponseOutputItemDoneEvent{ - OutputIndex: int64(outputIndex), - Item: responses.ResponseOutputItemUnion{ - ID: itemID, - Type: "message", + if !writeEvent("response.content_part.done", map[string]interface{}{ + "item_id": itemID, + "output_index": outputIndex, + "content_index": 0, + "part": map[string]interface{}{ + "type": "output_text", + "text": text, }, - }); err != nil { - t.Logf("writeResponsesAPIStreaming: failed to write ResponseOutputItemDoneEvent: %v", err) + }) { + return + } + if !writeEvent("response.output_item.done", map[string]interface{}{ + "output_index": outputIndex, + "item": map[string]interface{}{ + "type": "message", + "id": itemID, + "role": "assistant", + "status": "completed", + "content": []interface{}{ + map[string]interface{}{ + "type": "output_text", + "text": text, + }, + }, + }, + }) { return } } - if err := writeSSEEvent(w, responses.ResponseCompletedEvent{}); err != nil { - t.Logf("writeResponsesAPIStreaming: failed to write ResponseCompletedEvent: %v", err) + if !writeEvent("response.completed", map[string]interface{}{ + "response": map[string]interface{}{ + "id": responseID, + "object": "response", + "model": responseModel, + "status": "completed", + "output": []interface{}{}, + "usage": map[string]interface{}{}, + }, + }) { return } - flusher.Flush() return } } - // Responses API sends one event per choice + if chunk.Model != "" { + responseModel = chunk.Model + } + for outputIndex, choice := range chunk.Choices { if choice.Index != 0 { outputIndex = choice.Index } + outputIndex += textOffset itemID, found := itemIDs[outputIndex] if !found { itemID = fmt.Sprintf("msg_%s", uuid.New().String()[:8]) itemIDs[outputIndex] = itemID - - // Emit response.output_item.added so the - // fantasy client triggers TextStart. - if err := writeSSEEvent(w, responses.ResponseOutputItemAddedEvent{ - OutputIndex: int64(outputIndex), - Item: responses.ResponseOutputItemUnion{ - ID: itemID, - Type: "message", + if !writeEvent("response.output_item.added", map[string]interface{}{ + "output_index": outputIndex, + "item": map[string]interface{}{ + "type": "message", + "id": itemID, + "role": "assistant", + "status": "in_progress", + "content": []interface{}{}, }, - }); err != nil { - t.Logf("writeResponsesAPIStreaming: failed to write ResponseOutputItemAddedEvent: %v", err) + }) { + return + } + if !writeEvent("response.content_part.added", map[string]interface{}{ + "item_id": itemID, + "output_index": outputIndex, + "content_index": 0, + "part": map[string]interface{}{ + "type": "output_text", + "text": "", + }, + }) { return } - flusher.Flush() } - chunkData := map[string]interface{}{ - "type": "response.output_text.delta", + itemTexts[outputIndex] += choice.Delta + if !writeEvent("response.output_text.delta", map[string]interface{}{ "item_id": itemID, "output_index": outputIndex, - "created": chunk.Created, - "model": chunk.Model, "content_index": 0, "delta": choice.Delta, - } - - chunkBytes, err := json.Marshal(chunkData) - if err != nil { - t.Logf("writeResponsesAPIStreaming: failed to marshal chunk data: %v", err) + }) { return } - - if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil { - t.Logf("writeResponsesAPIStreaming: failed to write chunk data: %v", err) - return - } - flusher.Flush() } } } diff --git a/coderd/x/chatd/integration_test.go b/coderd/x/chatd/integration_test.go index 05a66e0712..01be17c4a6 100644 --- a/coderd/x/chatd/integration_test.go +++ b/coderd/x/chatd/integration_test.go @@ -1,14 +1,24 @@ package chatd_test import ( + "bytes" "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" "os" + "strconv" + "strings" + "sync" + "sync/atomic" "testing" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) @@ -587,3 +597,306 @@ func partTypeSet(parts []codersdk.ChatMessagePart) map[codersdk.ChatMessagePartT } return set } + +type openAIStoreMode string + +const ( + openAIStoreModeTrue openAIStoreMode = "store_true" + openAIStoreModeFalse openAIStoreMode = "store_false" +) + +func TestOpenAIReasoningWithWebSearchRoundTrip(t *testing.T) { + t.Parallel() + runOpenAIReasoningWithWebSearchRoundTripTest(t, openAIStoreModeTrue) +} + +func TestOpenAIReasoningWithWebSearchRoundTripStoreFalse(t *testing.T) { + t.Parallel() + runOpenAIReasoningWithWebSearchRoundTripTest(t, openAIStoreModeFalse) +} + +func runOpenAIReasoningWithWebSearchRoundTripTest(t *testing.T, storeMode openAIStoreMode) { + t.Helper() + + ctx := testutil.Context(t, testutil.WaitLong) + store := storeMode == openAIStoreModeTrue + + type capturedOpenAIRequest struct { + Stream bool `json:"stream,omitempty"` + Store *bool `json:"store,omitempty"` + PreviousResponseID *string `json:"previous_response_id,omitempty"` + Prompt []interface{} `json:"input,omitempty"` + } + + var ( + streamRequestCount atomic.Int32 + firstReq *capturedOpenAIRequest + secondReq *capturedOpenAIRequest + mu sync.Mutex + ) + upstreamOpenAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("reasoning + web search title") + } + + switch req.Header.Get("X-Request-Ordinal") { + case "1": + return chattest.OpenAIResponse{ + ResponseID: "resp_first_test", + StreamingChunks: chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("Here is what I found.")..., + ).StreamingChunks, + Reasoning: &chattest.OpenAIReasoningItem{ + Summary: "thinking about the question", + EncryptedContent: "encrypted_data_here", + }, + WebSearch: &chattest.OpenAIWebSearchCall{ + Query: "latest AI news", + }, + } + default: + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("Follow-up answer.")..., + ) + } + }) + captureServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read OpenAI request body: %v", err) + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + _ = r.Body.Close() + + if r.URL.Path == "/responses" { + var captured capturedOpenAIRequest + if err := json.Unmarshal(body, &captured); err != nil { + t.Errorf("decode OpenAI request body: %v", err) + http.Error(rw, err.Error(), http.StatusBadRequest) + return + } + if captured.Stream { + requestCount := streamRequestCount.Add(1) + r.Header.Set("X-Request-Ordinal", strconv.Itoa(int(requestCount))) + + mu.Lock() + switch requestCount { + case 1: + firstReq = &captured + default: + secondReq = &captured + } + mu.Unlock() + } + } + + upstreamReq, err := http.NewRequestWithContext( + r.Context(), + r.Method, + upstreamOpenAIURL+r.URL.RequestURI(), + bytes.NewReader(body), + ) + if err != nil { + t.Errorf("create upstream OpenAI request: %v", err) + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + upstreamReq.Header = r.Header.Clone() + + resp, err := http.DefaultClient.Do(upstreamReq) + if err != nil { + t.Errorf("forward OpenAI request: %v", err) + http.Error(rw, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + for key, values := range resp.Header { + for _, value := range values { + rw.Header().Add(key, value) + } + } + rw.WriteHeader(resp.StatusCode) + if _, err := io.Copy(rw, resp.Body); err != nil { + t.Errorf("copy OpenAI response body: %v", err) + } + })) + t.Cleanup(captureServer.Close) + openAIURL := captureServer.URL + + deploymentValues := coderdtest.DeploymentValues(t) + deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)} + client := coderdtest.New(t, &coderdtest.Options{ + DeploymentValues: deploymentValues, + }) + _ = coderdtest.CreateFirstUser(t, client) + expClient := codersdk.NewExperimentalClient(client) + + _, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai", + APIKey: "test-api-key", + BaseURL: openAIURL, + }) + require.NoError(t, err) + + contextLimit := int64(200000) + isDefault := true + reasoningEffort := "medium" + reasoningSummary := "auto" + _, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "openai", + Model: "o4-mini", + ContextLimit: &contextLimit, + IsDefault: &isDefault, + ModelConfig: &codersdk.ChatModelCallConfig{ + ProviderOptions: &codersdk.ChatModelProviderOptions{ + OpenAI: &codersdk.ChatModelOpenAIProviderOptions{ + Store: ptr.Ref(store), + ReasoningEffort: &reasoningEffort, + ReasoningSummary: &reasoningSummary, + WebSearchEnabled: ptr.Ref(true), + }, + }, + }, + }) + require.NoError(t, err) + + t.Logf("Creating chat with reasoning + web search query (store=%t)...", store) + chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "Search for the latest AI news and summarize it briefly.", + }}, + }) + require.NoError(t, err) + + events, closer, err := expClient.StreamChat(ctx, chat.ID, nil) + require.NoError(t, err) + defer closer.Close() + + waitForChatDone(ctx, t, events, "step 1") + + chatData, err := expClient.GetChat(ctx, chat.ID) + require.NoError(t, err) + chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil) + require.NoError(t, err) + require.Equal(t, codersdk.ChatStatusWaiting, chatData.Status, + "chat should be in waiting status after step 1") + + assistantMsg := findAssistantWithText(t, chatMsgs.Messages) + require.NotNil(t, assistantMsg, + "expected an assistant message with text content after step 1") + + partTypes := partTypeSet(assistantMsg.Content) + require.Contains(t, partTypes, codersdk.ChatMessagePartTypeReasoning, + "assistant message should contain reasoning parts") + require.Contains(t, partTypes, codersdk.ChatMessagePartTypeToolCall, + "assistant message should contain a provider-executed web search tool call") + require.Contains(t, partTypes, codersdk.ChatMessagePartTypeToolResult, + "assistant message should contain a provider-executed web search tool result") + require.Contains(t, partTypes, codersdk.ChatMessagePartTypeText, + "assistant message should contain a text part") + + var foundReasoning, foundWebSearchCall, foundText bool + for _, part := range assistantMsg.Content { + switch part.Type { + case codersdk.ChatMessagePartTypeReasoning: + // fantasy emits a leading newline when the reasoning summary part is + // added, so match the persisted summary text after trimming whitespace. + if strings.TrimSpace(part.Text) == "thinking about the question" { + foundReasoning = true + } + case codersdk.ChatMessagePartTypeToolCall: + if part.ToolName == "web_search" { + require.True(t, part.ProviderExecuted, + "web search tool-call should be marked provider-executed") + foundWebSearchCall = true + } + case codersdk.ChatMessagePartTypeText: + if part.Text == "Here is what I found." { + foundText = true + } + } + } + require.True(t, foundReasoning, "expected reasoning summary text to be persisted") + require.True(t, foundWebSearchCall, "expected persisted web_search tool call") + require.True(t, foundText, "expected streamed assistant text to be persisted") + + t.Log("Sending follow-up message...") + _, err = expClient.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "What is the follow-up takeaway?", + }}, + }) + if !store && err != nil { + require.NotContains(t, err.Error(), + "Items are not persisted when store is set to false.", + "follow-up should reconstruct store=false responses without stale provider item IDs") + } + require.NoError(t, err) + + events2, closer2, err := expClient.StreamChat(ctx, chat.ID, nil) + require.NoError(t, err) + defer closer2.Close() + + waitForChatDone(ctx, t, events2, "step 2") + + chatData2, err := expClient.GetChat(ctx, chat.ID) + require.NoError(t, err) + chatMsgs2, err := expClient.GetChatMessages(ctx, chat.ID, nil) + require.NoError(t, err) + require.Equal(t, codersdk.ChatStatusWaiting, chatData2.Status, + "chat should be in waiting status after step 2") + require.Greater(t, len(chatMsgs2.Messages), len(chatMsgs.Messages), + "follow-up should have added more messages") + require.NotNil(t, findLastAssistantWithText(t, chatMsgs2.Messages), + "expected an assistant message with text after the follow-up") + require.Equal(t, int32(2), streamRequestCount.Load(), + "expected exactly two streamed OpenAI responses") + + mu.Lock() + defer mu.Unlock() + + require.NotNil(t, firstReq, "expected first streaming request to be captured") + if store { + require.NotNil(t, firstReq.Store, "first request should have store field") + require.True(t, *firstReq.Store, "store should be true") + } else if firstReq.Store != nil { + require.False(t, *firstReq.Store, "store should be false") + } + + require.NotNil(t, secondReq, "expected second streaming request to be captured") + foundAssistantReplay := false + for _, item := range secondReq.Prompt { + m, ok := item.(map[string]interface{}) + if !ok { + continue + } + role, _ := m["role"].(string) + if role == "assistant" { + foundAssistantReplay = true + } + if store { + require.NotEqual(t, "assistant", role, + "store=true chain-mode prompt should not replay assistant messages") + require.NotEqual(t, "tool", role, + "store=true chain-mode prompt should not replay tool messages") + } + } + + if store { + require.NotNil(t, secondReq.PreviousResponseID, + "store=true follow-up should set previous_response_id") + require.Equal(t, "resp_first_test", *secondReq.PreviousResponseID, + "previous_response_id should match the first response's ID") + } else { + if secondReq.PreviousResponseID != nil { + require.Empty(t, *secondReq.PreviousResponseID, + "store=false follow-up should not set previous_response_id") + } + require.True(t, foundAssistantReplay, + "store=false follow-up should replay prior assistant history") + } +} diff --git a/go.mod b/go.mod index 5023fdd009..c82df7b4c9 100644 --- a/go.mod +++ b/go.mod @@ -493,7 +493,6 @@ require ( github.com/fsnotify/fsnotify v1.9.0 github.com/go-git/go-git/v5 v5.17.0 github.com/mark3labs/mcp-go v0.38.0 - github.com/openai/openai-go/v3 v3.28.0 github.com/shopspring/decimal v1.4.0 gonum.org/v1/gonum v0.17.0 ) @@ -590,6 +589,7 @@ require ( github.com/moby/sys/user v0.4.0 // indirect github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect github.com/openai/openai-go v1.12.0 // indirect + github.com/openai/openai-go/v3 v3.28.0 // indirect github.com/package-url/packageurl-go v0.1.3 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect