mirror of
https://github.com/coder/coder.git
synced 2026-06-03 04:58:23 +00:00
8b1705eb65
## Summary

Routes chatd model calls backed by concrete AI Provider rows through the in-process aibridge transport by default, with deployment options to use direct provider routing when AI Gateway is disabled or chat AI Gateway routing is disabled.

- Splits model routing into common, direct provider, and AI Gateway paths behind a single deployment-mode entry point.
- Builds chatd models through explicit request, route, and options data. Active API key attribution is passed explicitly instead of being hidden inside generic model construction.
- For AI Gateway BYOK routes, resolves the user's provider key in chatd, forwards it through provider-specific auth headers, and sets `X-Coder-AI-Governance-Token` to the `delegated` marker so aibridge preserves those headers while still stripping Coder-specific metadata.
- Keeps central provider credentials and deployment fallback credentials out of forwarded provider auth headers, so AI Gateway central policy remains authoritative.
- Redacts delegated provider auth from default string formatting to avoid accidental plaintext logging of user BYOK credentials.
- Covers selected chat models, advisor overrides, title and quickgen paths, subagent overrides, computer use model selection, and an integration-style chat turn through the aibridge transport path.
- Persists initiating API key IDs on chat and queued user messages, including subagent child messages, and fails closed for AI Gateway-routed model builds without an active key.
- Removes unused `api_key_id` indexes while keeping the persistence columns and foreign keys.
- Keeps the deployment option available through config and env parsing, but hides it from CLI help and generated docs.
- Stabilizes the subagent poll fallback test so background CreateChat processing cannot win the state transition under slower CI environments.
## Tests

- `go test ./coderd/x/chatd -run 'TestAIGatewayProviderAuthForUser|TestAIGatewayProviderAuthRedactsFormatting|TestResolveModelRouteForConfigAIGatewayProviderAuth|TestAIGatewayModelForwardsProviderAuth|TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey|TestAwaitSubagentCompletion' -count=1`
- `go test ./coderd/aibridged -run 'TestServeHTTP_DelegatedAPIKey|TestServeHTTP_StripCoderToken' -count=1`
- `git diff --check HEAD~1..HEAD`
- `make lint`

> Mux working on behalf of Mike.
12068 lines
395 KiB
Go
12068 lines
395 KiB
Go
package chatd_test
|
|
|
|
import (
|
|
"cmp"
|
|
"context"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
mcpgo "github.com/mark3labs/mcp-go/mcp"
|
|
mcpserver "github.com/mark3labs/mcp-go/server"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
|
"github.com/coder/coder/v2/agent/agenttest"
|
|
"github.com/coder/coder/v2/coderd/aibridge"
|
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/database/dbfake"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
|
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
|
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
|
"github.com/coder/coder/v2/coderd/rbac"
|
|
"github.com/coder/coder/v2/coderd/util/slice"
|
|
"github.com/coder/coder/v2/coderd/workspacestats"
|
|
"github.com/coder/coder/v2/coderd/x/chatd"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatadvisor"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
|
"github.com/coder/coder/v2/provisioner/echo"
|
|
proto "github.com/coder/coder/v2/provisionersdk/proto"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
type recordedOpenAIRequest struct {
|
|
Messages []chattest.OpenAIMessage
|
|
Tools []string
|
|
Store *bool
|
|
PreviousResponseID *string
|
|
ContentLength int64
|
|
}
|
|
|
|
type chatAIGatewayRecordedRequest struct {
|
|
ProviderName string
|
|
Source aibridge.Source
|
|
APIKeyID string
|
|
Path string
|
|
Authorization string
|
|
XAPIKey string
|
|
CoderToken string
|
|
}
|
|
|
|
type chatAIGatewayTestFactory struct {
|
|
target *url.URL
|
|
transport http.RoundTripper
|
|
mu sync.Mutex
|
|
requests []chatAIGatewayRecordedRequest
|
|
}
|
|
|
|
func newChatAIGatewayTestFactory(t testing.TB, targetBaseURL string) *chatAIGatewayTestFactory {
|
|
t.Helper()
|
|
|
|
target, err := url.Parse(targetBaseURL)
|
|
require.NoError(t, err)
|
|
return &chatAIGatewayTestFactory{target: target, transport: http.DefaultTransport}
|
|
}
|
|
|
|
func (f *chatAIGatewayTestFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) {
|
|
return chatAIGatewayRoundTripper{factory: f, providerName: providerName, source: source}, nil
|
|
}
|
|
|
|
func (f *chatAIGatewayTestFactory) requestsSnapshot() []chatAIGatewayRecordedRequest {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
return slices.Clone(f.requests)
|
|
}
|
|
|
|
type chatAIGatewayRoundTripper struct {
|
|
factory *chatAIGatewayTestFactory
|
|
providerName string
|
|
source aibridge.Source
|
|
}
|
|
|
|
func (t chatAIGatewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
|
|
t.factory.mu.Lock()
|
|
t.factory.requests = append(t.factory.requests, chatAIGatewayRecordedRequest{
|
|
ProviderName: t.providerName,
|
|
Source: t.source,
|
|
APIKeyID: apiKeyID,
|
|
Path: req.URL.Path,
|
|
Authorization: req.Header.Get("Authorization"),
|
|
XAPIKey: req.Header.Get("X-Api-Key"),
|
|
CoderToken: req.Header.Get(aibridge.HeaderCoderToken),
|
|
})
|
|
t.factory.mu.Unlock()
|
|
|
|
targetURL := *t.factory.target
|
|
targetURL.Path = strings.TrimPrefix(req.URL.Path, "/v1")
|
|
if targetURL.Path == "" {
|
|
targetURL.Path = "/"
|
|
}
|
|
targetURL.RawQuery = req.URL.RawQuery
|
|
|
|
cloned := req.Clone(req.Context())
|
|
cloned.URL = &targetURL
|
|
cloned.Host = t.factory.target.Host
|
|
return t.factory.transport.RoundTrip(cloned)
|
|
}
|
|
|
|
func chatAIGatewayTransportFactoryPointer(factory aibridge.TransportFactory) *atomic.Pointer[aibridge.TransportFactory] {
|
|
var ptr atomic.Pointer[aibridge.TransportFactory]
|
|
ptr.Store(&factory)
|
|
return &ptr
|
|
}
|
|
|
|
func directChatRoutingDeploymentValues(t testing.TB) *codersdk.DeploymentValues {
|
|
t.Helper()
|
|
|
|
values := coderdtest.DeploymentValues(t)
|
|
require.NoError(t, values.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
|
return values
|
|
}
|
|
|
|
func openAIToolName(tool chattest.OpenAITool) string {
|
|
return cmp.Or(tool.Function.Name, tool.Name, tool.Type)
|
|
}
|
|
|
|
func mustChatLastErrorRawMessage(t testing.TB, payload codersdk.ChatError) pqtype.NullRawMessage {
|
|
t.Helper()
|
|
|
|
encoded, err := json.Marshal(payload)
|
|
require.NoError(t, err)
|
|
return pqtype.NullRawMessage{RawMessage: encoded, Valid: true}
|
|
}
|
|
|
|
func requireChatLastErrorPayload(t testing.TB, raw pqtype.NullRawMessage) codersdk.ChatError {
|
|
t.Helper()
|
|
require.True(t, raw.Valid, "last error should be set")
|
|
|
|
var payload codersdk.ChatError
|
|
require.NoError(t, json.Unmarshal(raw.RawMessage, &payload))
|
|
return payload
|
|
}
|
|
|
|
func chatLastErrorMessage(raw pqtype.NullRawMessage) string {
|
|
if !raw.Valid {
|
|
return ""
|
|
}
|
|
|
|
var payload codersdk.ChatError
|
|
if err := json.Unmarshal(raw.RawMessage, &payload); err == nil && payload.Message != "" {
|
|
return payload.Message
|
|
}
|
|
return string(raw.RawMessage)
|
|
}
|
|
|
|
func recordOpenAIRequest(req *chattest.OpenAIRequest) recordedOpenAIRequest {
|
|
messages := append([]chattest.OpenAIMessage(nil), req.Messages...)
|
|
tools := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
tools = append(tools, openAIToolName(tool))
|
|
}
|
|
|
|
var store *bool
|
|
if req.Store != nil {
|
|
value := *req.Store
|
|
store = &value
|
|
}
|
|
|
|
var previousResponseID *string
|
|
if req.PreviousResponseID != nil {
|
|
value := *req.PreviousResponseID
|
|
previousResponseID = &value
|
|
}
|
|
|
|
var contentLength int64
|
|
if req.Request != nil {
|
|
contentLength = req.Request.ContentLength
|
|
}
|
|
|
|
return recordedOpenAIRequest{
|
|
Messages: messages,
|
|
Tools: tools,
|
|
Store: store,
|
|
PreviousResponseID: previousResponseID,
|
|
ContentLength: contentLength,
|
|
}
|
|
}
|
|
|
|
func requestHasSystemSubstring(req recordedOpenAIRequest, want string) bool {
|
|
for _, msg := range req.Messages {
|
|
if msg.Role == "system" && strings.Contains(msg.Content, want) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func newWorkspaceToolTestServer(
|
|
t *testing.T,
|
|
db database.Store,
|
|
ps dbpubsub.Pubsub,
|
|
agentID uuid.UUID,
|
|
planContent string,
|
|
) *chatd.Server {
|
|
t.Helper()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
|
mockConn.EXPECT().ListMCPTools(gomock.Any()).
|
|
Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes()
|
|
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{}, nil).AnyTimes()
|
|
mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(_ context.Context, path string, _, _ int64) (io.ReadCloser, string, error) {
|
|
if path == "/home/coder/PLAN.md" {
|
|
return io.NopCloser(strings.NewReader(planContent)), "", nil
|
|
}
|
|
return io.NopCloser(strings.NewReader("")), "", nil
|
|
}).AnyTimes()
|
|
|
|
return newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, gotAgentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, agentID, gotAgentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replicaA := newTestServer(t, db, ps, uuid.New())
|
|
replicaB := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replicaA.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "interrupt-me",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
runningWorker := uuid.New()
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: runningWorker, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
updated := replicaA.InterruptChat(ctx, chat)
|
|
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
|
require.False(t, updated.WorkerID.Valid)
|
|
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case event := <-events:
|
|
if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil {
|
|
return event.Status.Status == codersdk.ChatStatusWaiting
|
|
}
|
|
t.Logf("skipping unexpected event: type=%s", event.Type)
|
|
return false
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
deploymentValues := directChatRoutingDeploymentValues(t)
|
|
client := coderdtest.New(t, &coderdtest.Options{
|
|
DeploymentValues: deploymentValues,
|
|
IncludeProvisionerDaemon: true,
|
|
})
|
|
user := coderdtest.CreateFirstUser(t, client)
|
|
expClient := codersdk.NewExperimentalClient(client)
|
|
|
|
agentToken := uuid.NewString()
|
|
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
|
Parse: echo.ParseComplete,
|
|
ProvisionPlan: echo.PlanComplete,
|
|
ProvisionApply: echo.ApplyComplete,
|
|
ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken),
|
|
})
|
|
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
|
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
|
|
|
_ = agenttest.New(t, client.URL, agentToken)
|
|
|
|
// Track tools sent in LLM requests. The first call is for the
|
|
// root chat which spawns a subagent; the second call is for the
|
|
// subagent itself.
|
|
var toolsMu sync.Mutex
|
|
toolsByCall := make([][]string, 0, 2)
|
|
|
|
var callCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("ok")
|
|
}
|
|
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
toolsMu.Lock()
|
|
toolsByCall = append(toolsByCall, names)
|
|
toolsMu.Unlock()
|
|
|
|
if callCount.Add(1) == 1 {
|
|
// Root chat: model calls spawn_agent.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("spawn_agent", `{"type":"general","prompt":"do the thing","title":"sub"}`),
|
|
)
|
|
}
|
|
// Subsequent calls (including the subagent): just reply.
|
|
// Include literal \u0000 in the response text, which is
|
|
// what a real LLM writes when explaining binary output.
|
|
// json.Marshal encodes the backslash as \\, producing
|
|
// \\u0000 in the JSON bytes. The sanitizer must not
|
|
// corrupt this into invalid JSON.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("The file contains \\u0000 null bytes.")...,
|
|
)
|
|
})
|
|
|
|
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
|
|
|
// Create a root chat whose first model call will spawn a subagent.
|
|
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
|
OrganizationID: user.OrganizationID,
|
|
Content: []codersdk.ChatInputPart{
|
|
{
|
|
Type: codersdk.ChatInputPartTypeText,
|
|
Text: "Spawn a subagent to do the thing.",
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the root chat AND the subagent to finish.
|
|
// The root chat finishes first, then the chatd server
|
|
// picks up and runs the child (subagent) chat.
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := expClient.GetChat(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
if got.Status != codersdk.ChatStatusWaiting && got.Status != codersdk.ChatStatusError {
|
|
return false
|
|
}
|
|
// Also ensure the subagent LLM call has been made.
|
|
toolsMu.Lock()
|
|
n := len(toolsByCall)
|
|
toolsMu.Unlock()
|
|
// Expect at least 3 calls: root-1 (spawn_agent), child-1, root-2.
|
|
return n >= 3
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
// There should be at least two streamed calls: one for the root
|
|
// chat and one for the subagent child chat.
|
|
toolsMu.Lock()
|
|
recorded := append([][]string(nil), toolsByCall...)
|
|
toolsMu.Unlock()
|
|
|
|
require.GreaterOrEqual(t, len(recorded), 2,
|
|
"expected at least 2 streamed LLM calls (root + subagent)")
|
|
|
|
workspaceTools := []string{
|
|
"list_templates", "read_template", "create_workspace",
|
|
"start_workspace", "stop_workspace",
|
|
}
|
|
subagentTools := []string{"spawn_agent", "wait_agent", "message_agent", "close_agent"}
|
|
|
|
// Identify root and subagent calls. Root chat calls include
|
|
// spawn_agent; the subagent call does not. Because the root chat
|
|
// makes multiple LLM calls (before and after spawn_agent), we
|
|
// find exactly one call that lacks spawn_agent. That's the
|
|
// subagent.
|
|
var rootCalls, childCalls [][]string
|
|
for _, tools := range recorded {
|
|
hasSpawnAgent := slice.Contains(tools, "spawn_agent")
|
|
if hasSpawnAgent {
|
|
rootCalls = append(rootCalls, tools)
|
|
} else {
|
|
childCalls = append(childCalls, tools)
|
|
}
|
|
}
|
|
|
|
require.NotEmpty(t, rootCalls, "expected at least one root chat LLM call")
|
|
require.NotEmpty(t, childCalls, "expected at least one subagent LLM call")
|
|
|
|
// Root chat calls must include workspace and subagent tools.
|
|
for _, tool := range workspaceTools {
|
|
require.Contains(t, rootCalls[0], tool,
|
|
"root chat should have workspace tool %q", tool)
|
|
}
|
|
for _, tool := range subagentTools {
|
|
require.Contains(t, rootCalls[0], tool,
|
|
"root chat should have subagent tool %q", tool)
|
|
}
|
|
|
|
// Standard turns (no turn mode) hide plan-only tools until
|
|
// plan mode.
|
|
require.NotContains(t, rootCalls[0], "ask_user_question",
|
|
"standard-turn root chat should NOT have ask_user_question")
|
|
require.NotContains(t, rootCalls[0], "propose_plan",
|
|
"standard-turn root chat should NOT have propose_plan")
|
|
|
|
// Subagent calls must NOT include workspace or subagent tools.
|
|
for _, tool := range workspaceTools {
|
|
require.NotContains(t, childCalls[0], tool,
|
|
"subagent chat should NOT have workspace tool %q", tool)
|
|
}
|
|
for _, tool := range subagentTools {
|
|
require.NotContains(t, childCalls[0], tool,
|
|
"subagent chat should NOT have subagent tool %q", tool)
|
|
}
|
|
require.NotContains(t, childCalls[0], "ask_user_question",
|
|
"subagent chat should NOT have ask_user_question")
|
|
}
|
|
|
|
func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
deploymentValues := directChatRoutingDeploymentValues(t)
|
|
client := coderdtest.New(t, &coderdtest.Options{
|
|
DeploymentValues: deploymentValues,
|
|
IncludeProvisionerDaemon: true,
|
|
})
|
|
user := coderdtest.CreateFirstUser(t, client)
|
|
expClient := codersdk.NewExperimentalClient(client)
|
|
|
|
agentToken := uuid.NewString()
|
|
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
|
Parse: echo.ParseComplete,
|
|
ProvisionPlan: echo.PlanComplete,
|
|
ProvisionApply: echo.ApplyComplete,
|
|
ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken),
|
|
})
|
|
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
|
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
|
|
|
_ = agenttest.New(t, client.URL, agentToken)
|
|
|
|
// Start an external MCP server whose tools should remain available to the
|
|
// root plan-mode chat but stay hidden from plan-mode subagents.
|
|
mcpSrv := mcpserver.NewMCPServer("plan-root-mcp", "1.0.0")
|
|
mcpSrv.AddTools(mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("echo",
|
|
mcpgo.WithDescription("Echoes the input"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("echo: " + input), nil
|
|
},
|
|
})
|
|
mcpTS := httptest.NewServer(mcpserver.NewStreamableHTTPServer(mcpSrv))
|
|
t.Cleanup(mcpTS.Close)
|
|
|
|
mcpConfig, err := client.CreateMCPServerConfig(ctx, codersdk.CreateMCPServerConfigRequest{
|
|
DisplayName: "Plan Root MCP",
|
|
Slug: "plan-root-mcp",
|
|
Transport: "streamable_http",
|
|
URL: mcpTS.URL,
|
|
AuthType: "none",
|
|
Availability: "default_off",
|
|
Enabled: true,
|
|
AllowInPlanMode: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var toolsMu sync.Mutex
|
|
toolsByCall := make([][]string, 0, 2)
|
|
requestsByCall := make([]recordedOpenAIRequest, 0, 2)
|
|
|
|
var callCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("ok")
|
|
}
|
|
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
toolsMu.Lock()
|
|
toolsByCall = append(toolsByCall, names)
|
|
requestsByCall = append(requestsByCall, recordOpenAIRequest(req))
|
|
toolsMu.Unlock()
|
|
|
|
if callCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("spawn_agent", `{"type":"general","prompt":"inspect the codebase","title":"sub"}`),
|
|
)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
|
|
|
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
|
OrganizationID: user.OrganizationID,
|
|
PlanMode: codersdk.ChatPlanModePlan,
|
|
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
|
Content: []codersdk.ChatInputPart{
|
|
{
|
|
Type: codersdk.ChatInputPartTypeText,
|
|
Text: "Spawn a subagent to inspect the codebase.",
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := expClient.GetChat(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
if got.Status != codersdk.ChatStatusWaiting && got.Status != codersdk.ChatStatusError {
|
|
return false
|
|
}
|
|
toolsMu.Lock()
|
|
n := len(toolsByCall)
|
|
toolsMu.Unlock()
|
|
return n >= 3
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
toolsMu.Lock()
|
|
recorded := append([][]string(nil), toolsByCall...)
|
|
recordedRequests := append([]recordedOpenAIRequest(nil), requestsByCall...)
|
|
toolsMu.Unlock()
|
|
|
|
require.GreaterOrEqual(t, len(recorded), 2,
|
|
"expected at least 2 streamed LLM calls (root + subagent)")
|
|
require.Len(t, recordedRequests, len(recorded))
|
|
|
|
var rootCalls, childCalls [][]string
|
|
var rootRequests, childRequests []recordedOpenAIRequest
|
|
for i, tools := range recorded {
|
|
if slice.Contains(tools, "spawn_agent") {
|
|
rootCalls = append(rootCalls, tools)
|
|
rootRequests = append(rootRequests, recordedRequests[i])
|
|
continue
|
|
}
|
|
childCalls = append(childCalls, tools)
|
|
childRequests = append(childRequests, recordedRequests[i])
|
|
}
|
|
|
|
require.NotEmpty(t, rootCalls, "expected at least one root chat LLM call")
|
|
require.NotEmpty(t, childCalls, "expected at least one subagent LLM call")
|
|
require.NotEmpty(t, rootRequests, "expected at least one root prompt")
|
|
require.NotEmpty(t, childRequests, "expected at least one subagent prompt")
|
|
require.Contains(t, rootCalls[0], "ask_user_question",
|
|
"root plan-mode chat should have ask_user_question")
|
|
require.Contains(t, rootCalls[0], "write_file",
|
|
"root plan-mode chat should have write_file")
|
|
require.Contains(t, rootCalls[0], "edit_files",
|
|
"root plan-mode chat should have edit_files")
|
|
require.Contains(t, rootCalls[0], "execute",
|
|
"root plan-mode chat should have execute")
|
|
require.Contains(t, rootCalls[0], "process_output",
|
|
"root plan-mode chat should have process_output")
|
|
require.Contains(t, rootCalls[0], "plan-root-mcp__echo",
|
|
"root plan-mode chat should have approved external MCP tools")
|
|
require.NotContains(t, childCalls[0], "ask_user_question",
|
|
"plan-mode subagent should NOT have ask_user_question")
|
|
require.NotContains(t, childCalls[0], "write_file",
|
|
"plan-mode subagent should NOT have write_file")
|
|
require.NotContains(t, childCalls[0], "edit_files",
|
|
"plan-mode subagent should NOT have edit_files")
|
|
require.Contains(t, childCalls[0], "execute",
|
|
"plan-mode subagent should have execute")
|
|
require.Contains(t, childCalls[0], "process_output",
|
|
"plan-mode subagent should have process_output")
|
|
require.NotContains(t, childCalls[0], "plan-root-mcp__echo",
|
|
"plan-mode subagent should NOT have external MCP tools")
|
|
require.True(t, requestHasSystemSubstring(rootRequests[0], "You are in Plan Mode."))
|
|
require.True(t, requestHasSystemSubstring(childRequests[0], "You are in Plan Mode as a delegated sub-agent."))
|
|
require.False(t, requestHasSystemSubstring(childRequests[0], "When the plan is ready, call propose_plan"))
|
|
}
|
|
|
|
func TestExploreSubagentIsReadOnly(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
deploymentValues := directChatRoutingDeploymentValues(t)
|
|
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
|
DeploymentValues: deploymentValues,
|
|
IncludeProvisionerDaemon: true,
|
|
})
|
|
user := coderdtest.CreateFirstUser(t, client)
|
|
expClient := codersdk.NewExperimentalClient(client)
|
|
|
|
agentToken := uuid.NewString()
|
|
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
|
Parse: echo.ParseComplete,
|
|
ProvisionPlan: echo.PlanComplete,
|
|
ProvisionApply: echo.ApplyComplete,
|
|
ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken),
|
|
})
|
|
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
|
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
|
workspace := coderdtest.CreateWorkspace(t, client, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) {
|
|
cwr.AutomaticUpdates = codersdk.AutomaticUpdatesNever
|
|
})
|
|
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
|
_ = agenttest.New(t, client.URL, agentToken)
|
|
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
|
|
|
|
var toolsMu sync.Mutex
|
|
toolsByCall := make([][]string, 0, 2)
|
|
requestsByCall := make([]recordedOpenAIRequest, 0, 2)
|
|
|
|
var callCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("ok")
|
|
}
|
|
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
toolsMu.Lock()
|
|
toolsByCall = append(toolsByCall, names)
|
|
requestsByCall = append(requestsByCall, recordOpenAIRequest(req))
|
|
toolsMu.Unlock()
|
|
|
|
if callCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("spawn_agent", `{"type":"explore","prompt":"investigate the codebase","title":"sub"}`),
|
|
)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
|
|
|
_, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
|
OrganizationID: user.OrganizationID,
|
|
WorkspaceID: &workspace.ID,
|
|
Content: []codersdk.ChatInputPart{
|
|
{
|
|
Type: codersdk.ChatInputPartTypeText,
|
|
Text: "Spawn an Explore subagent to inspect the codebase.",
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
toolsMu.Lock()
|
|
defer toolsMu.Unlock()
|
|
|
|
sawRoot := false
|
|
sawChild := false
|
|
for _, tools := range toolsByCall {
|
|
if slice.Contains(tools, "spawn_agent") {
|
|
sawRoot = true
|
|
continue
|
|
}
|
|
sawChild = true
|
|
}
|
|
return sawRoot && sawChild
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
toolsMu.Lock()
|
|
recorded := append([][]string(nil), toolsByCall...)
|
|
recordedRequests := append([]recordedOpenAIRequest(nil), requestsByCall...)
|
|
toolsMu.Unlock()
|
|
|
|
require.GreaterOrEqual(t, len(recorded), 2,
|
|
"expected at least 2 streamed LLM calls (root + subagent)")
|
|
require.Len(t, recordedRequests, len(recorded))
|
|
|
|
var rootCalls, childCalls [][]string
|
|
var rootRequests, childRequests []recordedOpenAIRequest
|
|
for i, tools := range recorded {
|
|
if slice.Contains(tools, "spawn_agent") {
|
|
rootCalls = append(rootCalls, tools)
|
|
rootRequests = append(rootRequests, recordedRequests[i])
|
|
continue
|
|
}
|
|
childCalls = append(childCalls, tools)
|
|
childRequests = append(childRequests, recordedRequests[i])
|
|
}
|
|
|
|
require.NotEmpty(t, rootCalls, "expected at least one root chat LLM call")
|
|
require.NotEmpty(t, childCalls, "expected at least one subagent LLM call")
|
|
require.NotEmpty(t, rootRequests, "expected at least one root prompt")
|
|
require.NotEmpty(t, childRequests, "expected at least one subagent prompt")
|
|
require.Contains(t, rootCalls[0], "spawn_agent")
|
|
require.Contains(t, rootCalls[0], "write_file")
|
|
require.Contains(t, rootCalls[0], "edit_files")
|
|
require.NotContains(t, childCalls[0], "write_file")
|
|
require.NotContains(t, childCalls[0], "edit_files")
|
|
require.NotContains(t, childCalls[0], "spawn_agent")
|
|
require.NotContains(t, childCalls[0], "wait_agent")
|
|
require.Contains(t, childCalls[0], "read_file")
|
|
require.Contains(t, childCalls[0], "execute")
|
|
require.Contains(t, childCalls[0], "process_output")
|
|
require.True(t, requestHasSystemSubstring(childRequests[0], "You are in Explore Mode as a delegated sub-agent."))
|
|
require.False(t, requestHasSystemSubstring(rootRequests[0], "You are in Explore Mode as a delegated sub-agent."))
|
|
|
|
rootChats, err := db.GetChats(dbauthz.AsChatd(ctx), database.GetChatsParams{
|
|
OwnedOnly: true,
|
|
ViewerID: user.UserID,
|
|
})
|
|
require.NoError(t, err)
|
|
rootIDs := make([]uuid.UUID, 0, len(rootChats))
|
|
for _, root := range rootChats {
|
|
rootIDs = append(rootIDs, root.Chat.ID)
|
|
}
|
|
childRows, err := db.GetChildChatsByParentIDs(dbauthz.AsChatd(ctx), database.GetChildChatsByParentIDsParams{
|
|
ParentIds: rootIDs,
|
|
})
|
|
require.NoError(t, err)
|
|
var exploreChildren []database.Chat
|
|
for _, candidate := range childRows {
|
|
if candidate.Chat.Mode.Valid && candidate.Chat.Mode.ChatMode == database.ChatModeExplore {
|
|
exploreChildren = append(exploreChildren, candidate.Chat)
|
|
}
|
|
}
|
|
require.Len(t, exploreChildren, 1)
|
|
}
|
|
|
|
func TestExploreChatUsesPersistedMCPSnapshot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
externalMCP := mcpserver.NewMCPServer("external-snapshot-mcp", "1.0.0")
|
|
externalMCP.AddTools(mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("echo",
|
|
mcpgo.WithDescription("Echoes the input"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("echo: " + input), nil
|
|
},
|
|
})
|
|
externalMCPServer := httptest.NewServer(mcpserver.NewStreamableHTTPServer(externalMCP))
|
|
defer externalMCPServer.Close()
|
|
|
|
secondMCP := mcpserver.NewMCPServer("second-mcp", "1.0.0")
|
|
secondMCP.AddTools(mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("echo",
|
|
mcpgo.WithDescription("Echoes the input"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("echo: " + input), nil
|
|
},
|
|
})
|
|
secondMCPServer := httptest.NewServer(mcpserver.NewStreamableHTTPServer(secondMCP))
|
|
defer secondMCPServer.Close()
|
|
|
|
var (
|
|
requestsMu sync.Mutex
|
|
requests []recordedOpenAIRequest
|
|
)
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("ok")
|
|
}
|
|
|
|
requestsMu.Lock()
|
|
requests = append(requests, recordOpenAIRequest(req))
|
|
requestsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
user, org, _ := seedChatDependenciesWithProvider(t, db, "openai", openAIURL)
|
|
webSearchEnabled := true
|
|
storeEnabled := true
|
|
// OpenAI only serializes web_search through the Responses API.
|
|
// Store=true routes there only for supported Responses models.
|
|
webSearchModel := insertChatModelConfigWithCallConfig(
|
|
t,
|
|
db,
|
|
user.ID,
|
|
"openai",
|
|
"gpt-4o",
|
|
codersdk.ChatModelCallConfig{
|
|
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
|
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
|
Store: &storeEnabled,
|
|
WebSearchEnabled: &webSearchEnabled,
|
|
},
|
|
},
|
|
},
|
|
)
|
|
mcpConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "External Snapshot MCP",
|
|
Slug: "external-snapshot-mcp",
|
|
Url: externalMCPServer.URL,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Second MCP",
|
|
Slug: "second-mcp",
|
|
Url: secondMCPServer.URL,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
|
|
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
|
rootChat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true},
|
|
LastModelConfigID: webSearchModel.ID,
|
|
Title: "root",
|
|
ClientType: database.ChatClientTypeApi,
|
|
})
|
|
|
|
exploreChat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true},
|
|
ParentChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true},
|
|
RootChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true},
|
|
LastModelConfigID: webSearchModel.ID,
|
|
Title: "explore",
|
|
Mode: database.NullChatMode{
|
|
ChatMode: database.ChatModeExplore,
|
|
Valid: true,
|
|
},
|
|
Status: database.ChatStatusPending,
|
|
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
|
ClientType: database.ChatClientTypeApi,
|
|
})
|
|
|
|
dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: exploreChat.ID,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
ModelConfigID: uuid.NullUUID{UUID: webSearchModel.ID, Valid: true},
|
|
Role: database.ChatMessageRoleUser,
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: json.RawMessage(`[{"type":"text","text":"inspect the codebase"}]`),
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
|
workspaceToolName := "workspace-snapshot-mcp__echo"
|
|
mockConn.EXPECT().ListMCPTools(gomock.Any()).
|
|
Return(workspacesdk.ListMCPToolsResponse{Tools: []workspacesdk.MCPToolInfo{{
|
|
ServerName: "workspace-snapshot-mcp",
|
|
Name: workspaceToolName,
|
|
Description: "Workspace echo tool",
|
|
Schema: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
}}}, nil).
|
|
AnyTimes()
|
|
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{AbsolutePathString: "/home/coder"}, nil).AnyTimes()
|
|
mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes()
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
})
|
|
_ = server
|
|
|
|
chatResult := waitForTerminalChat(ctx, t, db, exploreChat.ID)
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "explore chat failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
requestsMu.Lock()
|
|
recorded := append([]recordedOpenAIRequest(nil), requests...)
|
|
requestsMu.Unlock()
|
|
require.Len(t, recorded, 1)
|
|
|
|
tools := recorded[0].Tools
|
|
require.Contains(t, tools, "read_file")
|
|
require.Contains(t, tools, "execute")
|
|
require.Contains(t, tools, "process_output")
|
|
require.Contains(t, tools, "external-snapshot-mcp__echo")
|
|
require.Contains(t, tools, "web_search", "Explore provider tool filter should let web_search through when the current model supports it")
|
|
require.NotContains(t, tools, "second-mcp__echo")
|
|
require.NotContains(t, tools, workspaceToolName)
|
|
require.NotContains(t, tools, "write_file")
|
|
require.NotContains(t, tools, "edit_files")
|
|
require.NotContains(t, tools, "spawn_agent")
|
|
}
|
|
|
|
func TestRootExploreChatStaysBuiltinOnlyAtRuntime(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
externalMCP := mcpserver.NewMCPServer("root-explore-runtime-mcp", "1.0.0")
|
|
externalMCP.AddTools(mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("echo",
|
|
mcpgo.WithDescription("Echoes the input"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("echo: " + input), nil
|
|
},
|
|
})
|
|
externalMCPServer := httptest.NewServer(mcpserver.NewStreamableHTTPServer(externalMCP))
|
|
defer externalMCPServer.Close()
|
|
|
|
var (
|
|
requestsMu sync.Mutex
|
|
requests []recordedOpenAIRequest
|
|
)
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("ok")
|
|
}
|
|
|
|
requestsMu.Lock()
|
|
requests = append(requests, recordOpenAIRequest(req))
|
|
requestsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
mcpConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Root Explore Runtime MCP",
|
|
Slug: "root-explore-runtime-mcp",
|
|
Url: externalMCPServer.URL,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
exploreChat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "root-explore-builtin-only",
|
|
ModelConfigID: model.ID,
|
|
ChatMode: database.NullChatMode{
|
|
ChatMode: database.ChatModeExplore,
|
|
Valid: true,
|
|
},
|
|
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Inspect the codebase."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
waitForChatProcessed(ctx, t, db, exploreChat.ID, server)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, exploreChat.ID)
|
|
require.NoError(t, err)
|
|
if storedChat.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "explore chat failed", "last_error=%q", chatLastErrorMessage(storedChat.LastError))
|
|
}
|
|
require.Equal(t, database.ChatStatusWaiting, storedChat.Status)
|
|
require.ElementsMatch(t, []uuid.UUID{mcpConfig.ID}, storedChat.MCPServerIDs)
|
|
|
|
requestsMu.Lock()
|
|
recorded := append([]recordedOpenAIRequest(nil), requests...)
|
|
requestsMu.Unlock()
|
|
require.Len(t, recorded, 1)
|
|
|
|
tools := recorded[0].Tools
|
|
require.Contains(t, tools, "read_file")
|
|
require.Contains(t, tools, "execute")
|
|
require.NotContains(t, tools, "write_file")
|
|
require.NotContains(t, tools, "root-explore-runtime-mcp__echo",
|
|
"root Explore chats should strip persisted external MCP tools at runtime")
|
|
}
|
|
|
|
func TestRootExploreChatExcludesWebSearchProviderToolAtRuntime(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var (
|
|
requestsMu sync.Mutex
|
|
requests []recordedOpenAIRequest
|
|
)
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("ok")
|
|
}
|
|
|
|
requestsMu.Lock()
|
|
requests = append(requests, recordOpenAIRequest(req))
|
|
requestsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
user, org, _ := seedChatDependenciesWithProvider(t, db, "openai", openAIURL)
|
|
webSearchEnabled := true
|
|
storeEnabled := true
|
|
// OpenAI only serializes web_search through the Responses API.
|
|
// Store=true routes there only for supported Responses models.
|
|
webSearchModel := insertChatModelConfigWithCallConfig(
|
|
t,
|
|
db,
|
|
user.ID,
|
|
"openai",
|
|
"gpt-4o",
|
|
codersdk.ChatModelCallConfig{
|
|
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
|
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
|
Store: &storeEnabled,
|
|
WebSearchEnabled: &webSearchEnabled,
|
|
},
|
|
},
|
|
},
|
|
)
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
exploreChat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "root-explore-no-provider-web-search",
|
|
ModelConfigID: webSearchModel.ID,
|
|
ChatMode: database.NullChatMode{
|
|
ChatMode: database.ChatModeExplore,
|
|
Valid: true,
|
|
},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Inspect the codebase."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
waitForChatProcessed(ctx, t, db, exploreChat.ID, server)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, exploreChat.ID)
|
|
require.NoError(t, err)
|
|
if storedChat.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "explore chat failed", "last_error=%q", chatLastErrorMessage(storedChat.LastError))
|
|
}
|
|
require.Equal(t, database.ChatStatusWaiting, storedChat.Status)
|
|
|
|
requestsMu.Lock()
|
|
recorded := append([]recordedOpenAIRequest(nil), requests...)
|
|
requestsMu.Unlock()
|
|
require.Len(t, recorded, 1)
|
|
|
|
tools := recorded[0].Tools
|
|
require.Contains(t, tools, "read_file")
|
|
require.Contains(t, tools, "execute")
|
|
require.NotContains(t, tools, "web_search",
|
|
"root Explore chats should stay builtin-only and must not inherit provider-native web_search at runtime")
|
|
require.NotContains(t, tools, "write_file")
|
|
}
|
|
|
|
func TestExploreChatSendMessageCannotMutateMCPSnapshot(t *testing.T) {
|
|
t.Parallel()
|
|
// TODO(CODAGT-353): Re-enable this test after the chatd notification flow
|
|
// refactor gives workers enough causal information to distinguish stale
|
|
// control NOTIFY messages from real interrupts. The current design reuses
|
|
// the same status notification shape for wake-only and interrupt intents,
|
|
// so a stale NOTIFY can cancel a new processChat run.
|
|
t.Skip("skipped until chatd notification flow refactor handles stale control notifications")
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
newEchoMCPServer := func(name string) *httptest.Server {
|
|
t.Helper()
|
|
|
|
mcpSrv := mcpserver.NewMCPServer(name, "1.0.0")
|
|
mcpSrv.AddTools(mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("echo",
|
|
mcpgo.WithDescription("Echoes the input"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("echo: " + input), nil
|
|
},
|
|
})
|
|
mcpTS := httptest.NewServer(mcpserver.NewStreamableHTTPServer(mcpSrv))
|
|
t.Cleanup(mcpTS.Close)
|
|
return mcpTS
|
|
}
|
|
|
|
parentTS := newEchoMCPServer("runtime-parent-mcp")
|
|
injectedTS := newEchoMCPServer("runtime-injected-mcp")
|
|
|
|
var (
|
|
requestsMu sync.Mutex
|
|
requests []recordedOpenAIRequest
|
|
)
|
|
childRequests := func() []recordedOpenAIRequest {
|
|
requestsMu.Lock()
|
|
defer requestsMu.Unlock()
|
|
|
|
filtered := make([]recordedOpenAIRequest, 0, len(requests))
|
|
for _, req := range requests {
|
|
if requestHasSystemSubstring(req, "You are in Explore Mode as a delegated sub-agent.") {
|
|
filtered = append(filtered, req)
|
|
}
|
|
}
|
|
return filtered
|
|
}
|
|
|
|
var streamCallCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("ok")
|
|
}
|
|
|
|
requestsMu.Lock()
|
|
requests = append(requests, recordOpenAIRequest(req))
|
|
requestsMu.Unlock()
|
|
|
|
if streamCallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("spawn_agent", `{"type":"explore","prompt":"inspect the codebase","title":"sub"}`),
|
|
)
|
|
}
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
parentConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Runtime Parent MCP",
|
|
Slug: "runtime-parent-mcp",
|
|
Url: parentTS.URL,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
injectedConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Runtime Injected MCP",
|
|
Slug: "runtime-injected-mcp",
|
|
Url: injectedTS.URL,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
rootChat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "runtime-parent",
|
|
ModelConfigID: model.ID,
|
|
MCPServerIDs: []uuid.UUID{parentConfig.ID},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Spawn an Explore subagent to inspect the codebase."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var exploreChat database.Chat
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
childRows, err := db.GetChildChatsByParentIDs(dbauthz.AsChatd(ctx), database.GetChildChatsByParentIDsParams{
|
|
ParentIds: []uuid.UUID{rootChat.ID},
|
|
})
|
|
if err != nil {
|
|
return false
|
|
}
|
|
for _, candidate := range childRows {
|
|
if candidate.Chat.Mode.Valid && candidate.Chat.Mode.ChatMode == database.ChatModeExplore {
|
|
exploreChat = candidate.Chat
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}, testutil.IntervalFast)
|
|
|
|
chatResult := waitForTerminalChat(ctx, t, db, exploreChat.ID)
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "explore chat failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
exploreChat, err = db.GetChatByID(ctx, exploreChat.ID)
|
|
require.NoError(t, err)
|
|
require.ElementsMatch(t, []uuid.UUID{parentConfig.ID}, exploreChat.MCPServerIDs)
|
|
|
|
initialChildRequestCount := len(childRequests())
|
|
require.GreaterOrEqual(t, initialChildRequestCount, 1)
|
|
|
|
updatedMCPServerIDs := []uuid.UUID{injectedConfig.ID}
|
|
_, err = server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: exploreChat.ID,
|
|
CreatedBy: user.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("inspect the codebase again")},
|
|
MCPServerIDs: &updatedMCPServerIDs,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
storedExploreChat, err := db.GetChatByID(ctx, exploreChat.ID)
|
|
require.NoError(t, err)
|
|
require.ElementsMatch(t, []uuid.UUID{parentConfig.ID}, storedExploreChat.MCPServerIDs)
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
return len(childRequests()) > initialChildRequestCount
|
|
}, testutil.IntervalFast)
|
|
|
|
chatResult = waitForTerminalChat(ctx, t, db, exploreChat.ID)
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "explore chat failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
recordedChildRequests := childRequests()
|
|
require.GreaterOrEqual(t, len(recordedChildRequests), initialChildRequestCount+1)
|
|
|
|
tools := recordedChildRequests[len(recordedChildRequests)-1].Tools
|
|
require.Contains(t, tools, "runtime-parent-mcp__echo")
|
|
require.NotContains(t, tools, "runtime-injected-mcp__echo",
|
|
"Explore child runtime should keep the spawn-time MCP snapshot after SendMessage")
|
|
}
|
|
|
|
func TestPlanModeRootChatAllowsApprovedExternalMCPTools(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
echoMCP := mcpserver.NewMCPServer("plan-visibility-echo", "1.0.0")
|
|
echoMCP.AddTools(mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("echo",
|
|
mcpgo.WithDescription("Echoes the input"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("echo: " + input), nil
|
|
},
|
|
})
|
|
echoTS := httptest.NewServer(mcpserver.NewStreamableHTTPServer(echoMCP))
|
|
t.Cleanup(echoTS.Close)
|
|
|
|
filteredMCP := mcpserver.NewMCPServer("plan-visibility-filtered", "1.0.0")
|
|
filteredMCP.AddTools(
|
|
mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("visible",
|
|
mcpgo.WithDescription("Visible tool"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("visible: " + input), nil
|
|
},
|
|
},
|
|
mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("hidden",
|
|
mcpgo.WithDescription("Hidden tool"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("hidden: " + input), nil
|
|
},
|
|
},
|
|
)
|
|
filteredTS := httptest.NewServer(mcpserver.NewStreamableHTTPServer(filteredMCP))
|
|
t.Cleanup(filteredTS.Close)
|
|
|
|
var (
|
|
requests []recordedOpenAIRequest
|
|
requestsMu sync.Mutex
|
|
)
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
requestsMu.Lock()
|
|
requests = append(requests, recordOpenAIRequest(req))
|
|
requestsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Done.")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
approvedConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Plan Approved MCP",
|
|
Slug: "plan-approved-mcp",
|
|
Url: echoTS.URL,
|
|
AllowInPlanMode: true,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
|
|
blockedConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Plan Blocked MCP",
|
|
Slug: "plan-blocked-mcp",
|
|
Url: echoTS.URL,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
|
|
filteredConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Plan Filtered MCP",
|
|
Slug: "plan-filtered-mcp",
|
|
Url: filteredTS.URL,
|
|
AllowInPlanMode: true,
|
|
ToolAllowList: []string{"visible"},
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
|
|
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
|
workspaceToolName := "workspace-plan-mcp__echo"
|
|
mockConn.EXPECT().ListMCPTools(gomock.Any()).
|
|
Return(workspacesdk.ListMCPToolsResponse{Tools: []workspacesdk.MCPToolInfo{{
|
|
ServerName: "workspace-plan-mcp",
|
|
Name: workspaceToolName,
|
|
Description: "Workspace echo tool",
|
|
Schema: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
}}}, nil).
|
|
Times(1)
|
|
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{AbsolutePathString: "/home/coder"}, nil).AnyTimes()
|
|
mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes()
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
})
|
|
|
|
planChat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "plan-mode-root-mcp-visibility",
|
|
ModelConfigID: model.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true},
|
|
MCPServerIDs: []uuid.UUID{approvedConfig.ID, blockedConfig.ID, filteredConfig.ID},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("List the available tools in plan mode."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
waitForChatProcessed(ctx, t, db, planChat.ID, server)
|
|
|
|
planChatResult, err := db.GetChatByID(ctx, planChat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, planChatResult.Status)
|
|
|
|
askChat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "ask-mode-root-mcp-visibility",
|
|
ModelConfigID: model.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
MCPServerIDs: []uuid.UUID{approvedConfig.ID, blockedConfig.ID, filteredConfig.ID},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("List the available tools outside plan mode."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
waitForChatProcessed(ctx, t, db, askChat.ID, server)
|
|
|
|
askChatResult, err := db.GetChatByID(ctx, askChat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, askChatResult.Status)
|
|
|
|
requestsMu.Lock()
|
|
recorded := append([]recordedOpenAIRequest(nil), requests...)
|
|
requestsMu.Unlock()
|
|
require.Len(t, recorded, 2, "expected exactly one streamed model call per chat")
|
|
|
|
planTools := recorded[0].Tools
|
|
askTools := recorded[1].Tools
|
|
|
|
require.Contains(t, planTools, "plan-approved-mcp__echo",
|
|
"root plan mode should expose approved external MCP tools")
|
|
require.NotContains(t, planTools, "plan-blocked-mcp__echo",
|
|
"root plan mode should hide unapproved external MCP tools")
|
|
require.Contains(t, planTools, "plan-filtered-mcp__visible",
|
|
"root plan mode should keep allowlisted tools from approved MCP servers")
|
|
require.NotContains(t, planTools, "plan-filtered-mcp__hidden",
|
|
"root plan mode should still respect MCP tool allowlists")
|
|
require.NotContains(t, planTools, workspaceToolName,
|
|
"root plan mode should exclude workspace MCP tools")
|
|
|
|
require.Contains(t, askTools, "plan-approved-mcp__echo",
|
|
"ask mode should keep approved external MCP tools")
|
|
require.Contains(t, askTools, "plan-blocked-mcp__echo",
|
|
"ask mode should keep unapproved-for-plan external MCP tools")
|
|
require.Contains(t, askTools, "plan-filtered-mcp__visible",
|
|
"ask mode should keep allowlisted tools from external MCP servers")
|
|
require.NotContains(t, askTools, "plan-filtered-mcp__hidden",
|
|
"ask mode should continue respecting MCP tool allowlists")
|
|
require.Contains(t, askTools, workspaceToolName,
|
|
"ask mode should continue exposing workspace MCP tools")
|
|
}
|
|
|
|
func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "db-transition",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
updated := replica.InterruptChat(ctx, chat)
|
|
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
|
require.False(t, updated.WorkerID.Valid)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
|
require.False(t, fromDB.WorkerID.Valid)
|
|
}
|
|
|
|
func TestArchiveChatMovesPendingChatToWaiting(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "archive-pending",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusPending,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: pqtype.NullRawMessage{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = replica.ArchiveChat(ctx, chat)
|
|
require.NoError(t, err)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
|
require.False(t, fromDB.WorkerID.Valid)
|
|
require.False(t, fromDB.StartedAt.Valid)
|
|
require.False(t, fromDB.HeartbeatAt.Valid)
|
|
require.True(t, fromDB.Archived)
|
|
require.Zero(t, fromDB.PinOrder)
|
|
}
|
|
|
|
// TestUnarchiveChildChat covers the deterministic branches of the
|
|
// Server.UnarchiveChat child path: happy path, archived-parent reject,
|
|
// and already-active no-op.
|
|
func TestUnarchiveChildChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("ChildWithActiveParentUnarchives", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
parent, child := insertParentWithArchivedChild(ctx, t, db, user, org, model)
|
|
|
|
require.NoError(t, replica.UnarchiveChat(ctx, child))
|
|
|
|
dbChild, err := db.GetChatByID(ctx, child.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, dbChild.Archived, "child should be unarchived")
|
|
|
|
dbParent, err := db.GetChatByID(ctx, parent.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, dbParent.Archived, "parent should stay active")
|
|
})
|
|
|
|
t.Run("ChildWithArchivedParentRejected", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
parent, child := insertParentWithArchivedChild(ctx, t, db, user, org, model)
|
|
_, err := db.ArchiveChatByID(ctx, parent.ID)
|
|
require.NoError(t, err)
|
|
|
|
err = replica.UnarchiveChat(ctx, child)
|
|
require.ErrorIs(t, err, chatd.ErrChildUnarchiveParentArchived)
|
|
|
|
dbChild, err := db.GetChatByID(ctx, child.ID)
|
|
require.NoError(t, err)
|
|
require.True(t, dbChild.Archived, "child should remain archived")
|
|
})
|
|
|
|
t.Run("AlreadyActiveChildNoOp", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
_, child := insertParentWithActiveChild(t, db, user, org, model)
|
|
|
|
require.NoError(t, replica.UnarchiveChat(ctx, child))
|
|
|
|
dbChild, err := db.GetChatByID(ctx, child.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, dbChild.Archived, "child should stay active")
|
|
})
|
|
}
|
|
|
|
// insertParentWithActiveChild creates a parent chat and an active
|
|
// child chat linked to it. Both are returned in their initial
|
|
// (active) state.
|
|
func insertParentWithActiveChild(
|
|
t *testing.T,
|
|
db database.Store,
|
|
user database.User,
|
|
org database.Organization,
|
|
model database.ChatModelConfig,
|
|
) (parent database.Chat, child database.Chat) {
|
|
t.Helper()
|
|
parent = dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: model.ID,
|
|
Title: "parent",
|
|
})
|
|
child = dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: model.ID,
|
|
Title: "child",
|
|
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
|
RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
|
})
|
|
return parent, child
|
|
}
|
|
|
|
// insertParentWithArchivedChild creates an active parent and an
|
|
// individually-archived child. The returned child reflects its
|
|
// current (archived) state in the DB.
|
|
func insertParentWithArchivedChild(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
user database.User,
|
|
org database.Organization,
|
|
model database.ChatModelConfig,
|
|
) (parent database.Chat, child database.Chat) {
|
|
t.Helper()
|
|
parent, child = insertParentWithActiveChild(t, db, user, org, model)
|
|
_, err := db.ArchiveChatByID(ctx, child.ID)
|
|
require.NoError(t, err)
|
|
child, err = db.GetChatByID(ctx, child.ID)
|
|
require.NoError(t, err)
|
|
return parent, child
|
|
}
|
|
|
|
func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
streamStarted := make(chan struct{})
|
|
streamCanceled := make(chan struct{})
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("partial")[0]
|
|
select {
|
|
case <-streamStarted:
|
|
default:
|
|
close(streamStarted)
|
|
}
|
|
<-req.Context().Done()
|
|
select {
|
|
case <-streamCanceled:
|
|
default:
|
|
close(streamCanceled)
|
|
}
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "archive-interrupt",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
|
}, testutil.IntervalFast)
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
select {
|
|
case <-streamStarted:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
_, events, cancel, ok := server.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, queuedResult.Queued)
|
|
require.NotNil(t, queuedResult.QueuedMessage)
|
|
|
|
err = server.ArchiveChat(ctx, chat)
|
|
require.NoError(t, err)
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
select {
|
|
case <-streamCanceled:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
gotWaitingStatus := false
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
for {
|
|
select {
|
|
case ev := <-events:
|
|
if ev.Type == codersdk.ChatStreamEventTypeStatus &&
|
|
ev.Status != nil &&
|
|
ev.Status.Status == codersdk.ChatStatusWaiting {
|
|
gotWaitingStatus = true
|
|
return true
|
|
}
|
|
default:
|
|
return gotWaitingStatus
|
|
}
|
|
}
|
|
}, testutil.IntervalFast)
|
|
require.True(t, gotWaitingStatus, "expected a waiting status event after archive")
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Archived &&
|
|
fromDB.Status == database.ChatStatusWaiting &&
|
|
!fromDB.WorkerID.Valid &&
|
|
!fromDB.StartedAt.Valid &&
|
|
!fromDB.HeartbeatAt.Valid
|
|
}, testutil.IntervalFast)
|
|
|
|
queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queuedMessages, 1)
|
|
require.Equal(t, queuedResult.QueuedMessage.ID, queuedMessages[0].ID)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
userMessages := 0
|
|
for _, msg := range messages {
|
|
if msg.Role == database.ChatMessageRoleUser {
|
|
userMessages++
|
|
}
|
|
}
|
|
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
|
|
}
|
|
|
|
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "heartbeat-ownership",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
workerID := uuid.New()
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wrong worker_id should return no IDs.
|
|
ids, err := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
|
IDs: []uuid.UUID{chat.ID},
|
|
WorkerID: uuid.New(),
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, ids)
|
|
|
|
// Correct worker_id should return the chat's ID.
|
|
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
|
IDs: []uuid.UUID{chat.ID},
|
|
WorkerID: workerID,
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, ids, 1)
|
|
require.Equal(t, chat.ID, ids[0])
|
|
}
|
|
|
|
func TestCreateChatPersistsAPIKeyIDOnInitialUserMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "create-chat-api-key-id",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
APIKeyID: apiKey.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 1)
|
|
require.Equal(t, database.ChatMessageRoleUser, messages[0].Role)
|
|
require.True(t, messages[0].APIKeyID.Valid)
|
|
require.Equal(t, apiKey.ID, messages[0].APIKeyID.String)
|
|
}
|
|
|
|
func TestSendMessagePersistsAPIKeyIDOnUserMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
|
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: model.ID,
|
|
Title: "send-message-api-key-id",
|
|
})
|
|
|
|
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
CreatedBy: user.ID,
|
|
Content: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("message with api key id"),
|
|
},
|
|
APIKeyID: apiKey.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, result.Queued)
|
|
require.True(t, result.Message.APIKeyID.Valid)
|
|
require.Equal(t, apiKey.ID, result.Message.APIKeyID.String)
|
|
|
|
stored, err := db.GetChatMessageByID(ctx, result.Message.ID)
|
|
require.NoError(t, err)
|
|
require.True(t, stored.APIKeyID.Valid)
|
|
require.Equal(t, apiKey.ID, stored.APIKeyID.String)
|
|
}
|
|
|
|
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "queue-when-busy",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
workerID := uuid.New()
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, result.Queued)
|
|
require.NotNil(t, result.QueuedMessage)
|
|
require.Equal(t, database.ChatStatusRunning, result.Chat.Status)
|
|
require.Equal(t, workerID, result.Chat.WorkerID.UUID)
|
|
require.True(t, result.Chat.WorkerID.Valid)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, 1)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 1)
|
|
}
|
|
|
|
func TestPlanTurnPromptContract(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
var (
|
|
requests []recordedOpenAIRequest
|
|
requestsMu sync.Mutex
|
|
)
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
requestsMu.Lock()
|
|
requests = append(requests, recordOpenAIRequest(req))
|
|
requestsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("plan acknowledged")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
planModeInstructions := "Ask about deployment sequencing before finalizing the plan."
|
|
err := db.UpsertChatPlanModeInstructions(dbauthz.AsSystemRestricted(ctx), planModeInstructions)
|
|
require.NoError(t, err)
|
|
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
|
server := newWorkspaceToolTestServer(t, db, ps, dbAgent.ID, "# Plan\n")
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "plan-turn-prompt-contract",
|
|
ModelConfigID: model.ID,
|
|
PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true},
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Plan the rollout."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
waitForChatProcessed(ctx, t, db, chat.ID, server)
|
|
|
|
requestsMu.Lock()
|
|
recorded := append([]recordedOpenAIRequest(nil), requests...)
|
|
requestsMu.Unlock()
|
|
|
|
require.Len(t, recorded, 1, "expected exactly 1 streamed model call")
|
|
require.True(t, requestHasSystemSubstring(recorded[0], "You are in Plan Mode."))
|
|
require.True(t, requestHasSystemSubstring(recorded[0], "The only intentional authored workspace artifact is the plan file"))
|
|
require.True(t, requestHasSystemSubstring(recorded[0], "You may use execute and process_output for exploration"))
|
|
require.True(t, requestHasSystemSubstring(recorded[0], "approved external MCP tools when available"))
|
|
require.True(t, requestHasSystemSubstring(recorded[0], "Workspace MCP tools are not available in root plan mode"))
|
|
require.True(t, requestHasSystemSubstring(recorded[0], "After a successful propose_plan call, stop immediately"))
|
|
require.True(t, requestHasSystemSubstring(recorded[0], planModeInstructions))
|
|
for _, msg := range recorded[0].Messages {
|
|
if msg.Role != "system" {
|
|
continue
|
|
}
|
|
// The overlay prompt includes a placeholder that is replaced at
|
|
// runtime, so strip only the stable body text before checking.
|
|
overlayBody := strings.TrimSuffix(
|
|
chatd.PlanningOverlayPrompt(),
|
|
"{{CODER_CHAT_PLAN_FILE_PATH_BLOCK}}",
|
|
)
|
|
sanitized := strings.ReplaceAll(msg.Content, overlayBody, "")
|
|
require.NotContains(t, sanitized, "propose_plan")
|
|
}
|
|
}
|
|
|
|
func TestSendMessageQueuesWhenWaitingWithQueuedBacklog(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "queue-when-waiting-with-backlog",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("older queued"),
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: pqtype.NullRawMessage{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("newer queued")},
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, result.Queued)
|
|
require.NotNil(t, result.QueuedMessage)
|
|
require.Equal(t, database.ChatStatusWaiting, result.Chat.Status)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, 2)
|
|
|
|
olderSDK := db2sdk.ChatQueuedMessage(queued[0])
|
|
require.Len(t, olderSDK.Content, 1)
|
|
require.Equal(t, "older queued", olderSDK.Content[0].Text)
|
|
|
|
newerSDK := db2sdk.ChatQueuedMessage(queued[1])
|
|
require.Len(t, newerSDK.Content, 1)
|
|
require.Equal(t, "newer queued", newerSDK.Content[0].Text)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 1)
|
|
}
|
|
|
|
func TestSendMessageRejectsInvalidQueuedModelConfigID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, modelConfig := seedChatDependencies(t, db)
|
|
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusPending,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelConfig.ID,
|
|
Title: "reject invalid queued model config",
|
|
})
|
|
|
|
invalidModelConfigID := uuid.New()
|
|
_, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
|
ModelConfigID: invalidModelConfigID,
|
|
})
|
|
require.ErrorIs(t, err, chatd.ErrInvalidModelConfigID)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, queued)
|
|
}
|
|
|
|
func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newStartedTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "interrupt-when-busy",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// CreateChat calls signalWake which triggers processOnce in
|
|
// the background. Wait for that processing to finish so it
|
|
// doesn't race with the manual status update below.
|
|
waitForChatProcessed(ctx, t, db, chat.ID, replica)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("interrupt")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// The message should be queued, not inserted directly.
|
|
require.True(t, result.Queued)
|
|
require.NotNil(t, result.QueuedMessage)
|
|
|
|
// The chat should transition to waiting (interrupt signal),
|
|
// not pending.
|
|
require.Equal(t, database.ChatStatusWaiting, result.Chat.Status)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
|
|
|
// The message should be in the queue, not in chat_messages.
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, 1)
|
|
|
|
// Only messages from the initial processing round should be in
|
|
// chat_messages (user + assistant). The "interrupt" message must
|
|
// be in the queue, not inserted directly.
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 2)
|
|
}
|
|
|
|
func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "edit-message",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
initialMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, initialMessages, 1)
|
|
editedMessageID := initialMessages[0].ID
|
|
|
|
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("follow-up")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("another")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("queued"),
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
|
apiKeyID := apiKey.ID
|
|
editResult, err := replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: editedMessageID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
|
APIKeyID: apiKeyID,
|
|
})
|
|
require.NoError(t, err)
|
|
// The edited message is soft-deleted and a new message is inserted,
|
|
// so the returned message ID will differ from the original.
|
|
require.NotEqual(t, editedMessageID, editResult.Message.ID)
|
|
require.True(t, editResult.Message.APIKeyID.Valid)
|
|
require.Equal(t, apiKeyID, editResult.Message.APIKeyID.String)
|
|
require.Equal(t, database.ChatStatusPending, editResult.Chat.Status)
|
|
require.False(t, editResult.Chat.WorkerID.Valid)
|
|
|
|
editedSDK := db2sdk.ChatMessage(editResult.Message)
|
|
require.Len(t, editedSDK.Content, 1)
|
|
require.Equal(t, "edited", editedSDK.Content[0].Text)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 1)
|
|
require.Equal(t, editResult.Message.ID, messages[0].ID)
|
|
require.True(t, messages[0].APIKeyID.Valid)
|
|
require.Equal(t, apiKeyID, messages[0].APIKeyID.String)
|
|
onlyMessage := db2sdk.ChatMessage(messages[0])
|
|
require.Len(t, onlyMessage.Content, 1)
|
|
require.Equal(t, "edited", onlyMessage.Content[0].Text)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, 0)
|
|
|
|
// WaitUntilIdleForTest drains the debug-cleanup goroutine
|
|
// from EditMessage. Must be called from the test goroutine
|
|
// (not inside require.Eventually) to avoid Add/Wait race.
|
|
chatd.WaitUntilIdleForTest(replica)
|
|
var chatFromDB database.Chat
|
|
require.Eventually(t, func() bool {
|
|
c, e := db.GetChatByID(ctx, chat.ID)
|
|
if e != nil {
|
|
return false
|
|
}
|
|
chatFromDB = c
|
|
return chatFromDB.Status != database.ChatStatusRunning
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
require.False(t, chatFromDB.WorkerID.Valid)
|
|
}
|
|
|
|
func TestCreateChatInsertsWorkspaceAwarenessMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("WithWorkspace", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
server := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user.ID,
|
|
OrganizationID: org.ID,
|
|
ActiveVersionID: tv.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
|
Title: "test-with-workspace",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
var workspaceMsg *database.ChatMessage
|
|
for _, msg := range messages {
|
|
if msg.Role == database.ChatMessageRoleSystem {
|
|
content := string(msg.Content.RawMessage)
|
|
if strings.Contains(content, "attached to a workspace") {
|
|
workspaceMsg = &msg
|
|
break
|
|
}
|
|
}
|
|
}
|
|
require.NotNil(t, workspaceMsg, "workspace awareness system message should exist")
|
|
require.Equal(t, database.ChatMessageRoleSystem, workspaceMsg.Role)
|
|
require.Equal(t, database.ChatMessageVisibilityModel, workspaceMsg.Visibility)
|
|
})
|
|
|
|
t.Run("WithoutWorkspace", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
server := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "test-without-workspace",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
var workspaceMsg *database.ChatMessage
|
|
for _, msg := range messages {
|
|
if msg.Role == database.ChatMessageRoleSystem {
|
|
content := string(msg.Content.RawMessage)
|
|
if strings.Contains(content, "No workspace is attached to this chat yet") {
|
|
workspaceMsg = &msg
|
|
break
|
|
}
|
|
}
|
|
}
|
|
require.NotNil(t, workspaceMsg, "workspace awareness system message should exist")
|
|
require.Equal(t, database.ChatMessageRoleSystem, workspaceMsg.Role)
|
|
require.Equal(t, database.ChatMessageVisibilityModel, workspaceMsg.Visibility)
|
|
workspaceContent := string(workspaceMsg.Content.RawMessage)
|
|
require.Contains(t, workspaceContent, "Do not create or start a workspace by default")
|
|
require.Contains(t, workspaceContent, "Only call create_workspace or start_workspace")
|
|
require.NotContains(t, workspaceContent, "Create one using the create_workspace tool before using workspace tools")
|
|
})
|
|
}
|
|
|
|
func TestCreateChatRejectsWhenUsageLimitReached(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
|
|
Enabled: true,
|
|
DefaultLimitMicros: 100,
|
|
Period: string(codersdk.ChatUsageLimitPeriodDay),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
existingChat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "existing-limit-chat",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
|
|
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("assistant"),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_ = dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: existingChat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: database.ChatMessageRoleAssistant,
|
|
ContentVersion: chatprompt.CurrentContentVersion,
|
|
Content: assistantContent,
|
|
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
|
|
})
|
|
|
|
beforeChats, err := db.GetChats(ctx, database.GetChatsParams{
|
|
OwnedOnly: true,
|
|
ViewerID: user.ID,
|
|
AfterID: uuid.Nil,
|
|
OffsetOpt: 0,
|
|
LimitOpt: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, beforeChats, 1)
|
|
|
|
_, err = replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "over-limit",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.Error(t, err)
|
|
|
|
var limitErr *chatd.UsageLimitExceededError
|
|
require.ErrorAs(t, err, &limitErr)
|
|
require.Equal(t, int64(100), limitErr.LimitMicros)
|
|
require.Equal(t, int64(100), limitErr.ConsumedMicros)
|
|
|
|
afterChats, err := db.GetChats(ctx, database.GetChatsParams{
|
|
OwnedOnly: true,
|
|
ViewerID: user.ID,
|
|
AfterID: uuid.Nil,
|
|
OffsetOpt: 0,
|
|
LimitOpt: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, afterChats, len(beforeChats))
|
|
}
|
|
|
|
func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newStartedTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
|
|
|
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
|
|
Enabled: true,
|
|
DefaultLimitMicros: 100,
|
|
Period: string(codersdk.ChatUsageLimitPeriodDay),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "queued-limit-reached",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// CreateChat calls signalWake which triggers processOnce in
|
|
// the background. Wait for that processing to finish so it
|
|
// doesn't race with the manual status update below.
|
|
waitForChatProcessed(ctx, t, db, chat.ID, replica)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
queuedResult, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
|
APIKeyID: apiKey.ID,
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, queuedResult.Queued)
|
|
require.NotNil(t, queuedResult.QueuedMessage)
|
|
require.True(t, queuedResult.QueuedMessage.APIKeyID.Valid)
|
|
require.Equal(t, apiKey.ID, queuedResult.QueuedMessage.APIKeyID.String)
|
|
|
|
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("assistant"),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_ = dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: database.ChatMessageRoleAssistant,
|
|
ContentVersion: chatprompt.CurrentContentVersion,
|
|
Content: assistantContent,
|
|
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
|
|
})
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: pqtype.NullRawMessage{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queuedResult.QueuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatMessageRoleUser, result.PromotedMessage.Role)
|
|
require.True(t, result.PromotedMessage.APIKeyID.Valid)
|
|
require.Equal(t, apiKey.ID, result.PromotedMessage.APIKeyID.String)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, queued)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 4)
|
|
require.Equal(t, database.ChatMessageRoleUser, messages[3].Role)
|
|
}
|
|
|
|
func TestPromoteQueuedMessageUsesQueuedModelConfigID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, modelConfigA := seedChatDependencies(t, db)
|
|
modelConfigB := insertChatModelConfigWithCallConfig(
|
|
t,
|
|
db,
|
|
user.ID,
|
|
"openai",
|
|
"gpt-4o-mini-promote-"+uuid.NewString(),
|
|
codersdk.ChatModelCallConfig{},
|
|
)
|
|
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelConfigA.ID,
|
|
Title: "promote queued uses stored model",
|
|
})
|
|
|
|
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("queued with model b")})
|
|
require.NoError(t, err)
|
|
queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent,
|
|
ModelConfigID: uuid.NullUUID{
|
|
UUID: modelConfigB.ID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, result.PromotedMessage.ModelConfigID.Valid)
|
|
require.Equal(t, modelConfigB.ID, result.PromotedMessage.ModelConfigID.UUID)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, modelConfigB.ID, storedChat.LastModelConfigID)
|
|
// The processor can pick up the pending chat immediately after
|
|
// promotion, so this test only requires that promotion moved it out of
|
|
// waiting and preserved the queued model configuration.
|
|
require.Contains(t, []database.ChatStatus{
|
|
database.ChatStatusPending,
|
|
database.ChatStatusRunning,
|
|
}, storedChat.Status)
|
|
}
|
|
|
|
func TestPromoteQueuedMessageReloadsChatWhenModelConfigChangesDuringPending(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, modelConfigA := seedChatDependencies(t, db)
|
|
modelConfigB := insertChatModelConfigWithCallConfig(
|
|
t,
|
|
db,
|
|
user.ID,
|
|
"openai",
|
|
"gpt-4o-mini-promote-pending-"+uuid.NewString(),
|
|
codersdk.ChatModelCallConfig{},
|
|
)
|
|
|
|
watchEvents := make(chan struct {
|
|
payload codersdk.ChatWatchEvent
|
|
err error
|
|
}, 1)
|
|
cancelWatch, err := ps.SubscribeWithErr(
|
|
coderdpubsub.ChatWatchEventChannel(user.ID),
|
|
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
|
select {
|
|
case watchEvents <- struct {
|
|
payload codersdk.ChatWatchEvent
|
|
err error
|
|
}{payload: payload, err: err}:
|
|
default:
|
|
}
|
|
}),
|
|
)
|
|
require.NoError(t, err)
|
|
defer cancelWatch()
|
|
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusPending,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelConfigA.ID,
|
|
Title: "promote queued reloads pending chat",
|
|
})
|
|
|
|
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("queued with new model")})
|
|
require.NoError(t, err)
|
|
queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent,
|
|
ModelConfigID: uuid.NullUUID{
|
|
UUID: modelConfigB.ID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, result.PromotedMessage.ModelConfigID.Valid)
|
|
require.Equal(t, modelConfigB.ID, result.PromotedMessage.ModelConfigID.UUID)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, storedChat.Status)
|
|
require.Equal(t, modelConfigB.ID, storedChat.LastModelConfigID)
|
|
|
|
select {
|
|
case event := <-watchEvents:
|
|
require.NoError(t, event.err)
|
|
require.Equal(t, codersdk.ChatWatchEventKindStatusChange, event.payload.Kind)
|
|
require.Equal(t, chat.ID, event.payload.Chat.ID)
|
|
require.Equal(t, codersdk.ChatStatusPending, event.payload.Chat.Status)
|
|
require.Equal(t, modelConfigB.ID, event.payload.Chat.LastModelConfigID)
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for status change watch event")
|
|
}
|
|
}
|
|
|
|
func TestAutoPromoteQueuedMessagesPreservesPerTurnModelOrder(t *testing.T) {
|
|
t.Parallel()
|
|
// TODO(CODAGT-353): Re-enable this test after the chatd notification flow
|
|
// refactor gives workers enough causal information to distinguish stale
|
|
// control NOTIFY messages from real interrupts. The current design reuses
|
|
// the same status notification shape for wake-only and interrupt intents,
|
|
// so a stale NOTIFY can cancel a new processChat run.
|
|
t.Skip("skipped until chatd notification flow refactor handles stale control notifications")
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
|
|
|
firstRunStarted := make(chan struct{})
|
|
secondRunStarted := make(chan struct{}, 1)
|
|
thirdRunStarted := make(chan struct{}, 1)
|
|
allowFirstRunFinish := make(chan struct{})
|
|
var requestCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
switch requestCount.Add(1) {
|
|
case 1:
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("first run partial")[0]
|
|
select {
|
|
case <-firstRunStarted:
|
|
default:
|
|
close(firstRunStarted)
|
|
}
|
|
<-allowFirstRunFinish
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
case 2:
|
|
select {
|
|
case secondRunStarted <- struct{}{}:
|
|
default:
|
|
}
|
|
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("second run done")...)
|
|
case 3:
|
|
select {
|
|
case thirdRunStarted <- struct{}{}:
|
|
default:
|
|
}
|
|
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("third run done")...)
|
|
default:
|
|
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("extra run done")...)
|
|
}
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
// Disable periodic polling so chained promotions must be driven by
|
|
// signalWake.
|
|
cfg.PendingChatAcquireInterval = time.Hour
|
|
})
|
|
user, org, modelConfigA := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
modelConfigB := insertChatModelConfigWithCallConfig(
|
|
t,
|
|
db,
|
|
user.ID,
|
|
"openai-compat",
|
|
"gpt-4o-mini-queue-b-"+uuid.NewString(),
|
|
codersdk.ChatModelCallConfig{},
|
|
)
|
|
modelConfigC := insertChatModelConfigWithCallConfig(
|
|
t,
|
|
db,
|
|
user.ID,
|
|
"openai-compat",
|
|
"gpt-4o-mini-queue-c-"+uuid.NewString(),
|
|
codersdk.ChatModelCallConfig{},
|
|
)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "auto-promote per-turn model order",
|
|
ModelConfigID: modelConfigA.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
testutil.TryReceive(ctx, t, firstRunStarted)
|
|
|
|
queuedB, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued b")},
|
|
ModelConfigID: modelConfigB.ID,
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, queuedB.Queued)
|
|
|
|
queuedC, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued c")},
|
|
ModelConfigID: modelConfigC.ID,
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, queuedC.Queued)
|
|
|
|
close(allowFirstRunFinish)
|
|
|
|
testutil.TryReceive(ctx, t, secondRunStarted)
|
|
testutil.TryReceive(ctx, t, thirdRunStarted)
|
|
require.GreaterOrEqual(t, requestCount.Load(), int32(3))
|
|
chatd.WaitUntilIdleForTest(server)
|
|
|
|
queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, queuedMessages)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, storedChat.Status)
|
|
require.Equal(t, modelConfigC.ID, storedChat.LastModelConfigID)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var userTexts []string
|
|
var userModelConfigIDs []uuid.UUID
|
|
for _, message := range messages {
|
|
if message.Role != database.ChatMessageRoleUser {
|
|
continue
|
|
}
|
|
sdkMessage := db2sdk.ChatMessage(message)
|
|
require.Len(t, sdkMessage.Content, 1)
|
|
userTexts = append(userTexts, sdkMessage.Content[0].Text)
|
|
require.True(t, message.ModelConfigID.Valid)
|
|
userModelConfigIDs = append(userModelConfigIDs, message.ModelConfigID.UUID)
|
|
}
|
|
require.Equal(t, []string{"hello", "queued b", "queued c"}, userTexts)
|
|
require.Equal(t, []uuid.UUID{modelConfigA.ID, modelConfigB.ID, modelConfigC.ID}, userModelConfigIDs)
|
|
}
|
|
|
|
func TestAutoPromoteQueuedMessageFallsBackForLegacyQueuedRows(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testAutoPromoteQueuedMessageFallback(t, uuid.NullUUID{})
|
|
}
|
|
|
|
func TestAutoPromoteQueuedMessageFallsBackForInvalidQueuedModelConfigID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testAutoPromoteQueuedMessageFallback(t, uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
})
|
|
}
|
|
|
|
func testAutoPromoteQueuedMessageFallback(t *testing.T, queuedModelConfigID uuid.NullUUID) {
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
|
|
|
firstRunStarted := make(chan struct{})
|
|
secondRunStarted := make(chan struct{}, 1)
|
|
allowFirstRunFinish := make(chan struct{})
|
|
var requestCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
switch requestCount.Add(1) {
|
|
case 1:
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("first run partial")[0]
|
|
select {
|
|
case <-firstRunStarted:
|
|
default:
|
|
close(firstRunStarted)
|
|
}
|
|
<-allowFirstRunFinish
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
default:
|
|
select {
|
|
case secondRunStarted <- struct{}{}:
|
|
default:
|
|
}
|
|
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("fallback run done")...)
|
|
}
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
// Disable periodic polling so only signalWake can
|
|
// trigger the next processing run.
|
|
cfg.PendingChatAcquireInterval = time.Hour
|
|
})
|
|
user, org, modelConfig := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "auto-promote queued fallback",
|
|
ModelConfigID: modelConfig.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
testutil.TryReceive(ctx, t, firstRunStarted)
|
|
|
|
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("legacy queued row")})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent,
|
|
ModelConfigID: queuedModelConfigID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
close(allowFirstRunFinish)
|
|
|
|
testutil.TryReceive(ctx, t, secondRunStarted)
|
|
require.GreaterOrEqual(t, requestCount.Load(), int32(2))
|
|
chatd.WaitUntilIdleForTest(server)
|
|
|
|
queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, queuedMessages)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, storedChat.Status)
|
|
require.Equal(t, modelConfig.ID, storedChat.LastModelConfigID)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var found bool
|
|
for _, message := range messages {
|
|
if message.Role != database.ChatMessageRoleUser {
|
|
continue
|
|
}
|
|
sdkMessage := db2sdk.ChatMessage(message)
|
|
require.Len(t, sdkMessage.Content, 1)
|
|
if sdkMessage.Content[0].Text != "legacy queued row" {
|
|
continue
|
|
}
|
|
require.True(t, message.ModelConfigID.Valid)
|
|
require.Equal(t, modelConfig.ID, message.ModelConfigID.UUID)
|
|
found = true
|
|
}
|
|
require.True(t, found)
|
|
}
|
|
|
|
func TestPromoteQueuedMessageFallsBackForLegacyQueuedRows(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, modelConfigA := seedChatDependencies(t, db)
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelConfigA.ID,
|
|
Title: "promote queued legacy fallback",
|
|
})
|
|
|
|
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("legacy queued row")})
|
|
require.NoError(t, err)
|
|
queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, result.PromotedMessage.ModelConfigID.Valid)
|
|
require.Equal(t, modelConfigA.ID, result.PromotedMessage.ModelConfigID.UUID)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, modelConfigA.ID, storedChat.LastModelConfigID)
|
|
}
|
|
|
|
func TestPromoteQueuedMessageFallsBackForInvalidQueuedModelConfigID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, modelConfig := seedChatDependencies(t, db)
|
|
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelConfig.ID,
|
|
Title: "promote queued invalid fallback",
|
|
})
|
|
|
|
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("invalid queued model")})
|
|
require.NoError(t, err)
|
|
queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent,
|
|
ModelConfigID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, result.PromotedMessage.ModelConfigID.Valid)
|
|
require.Equal(t, modelConfig.ID, result.PromotedMessage.ModelConfigID.UUID)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, modelConfig.ID, storedChat.LastModelConfigID)
|
|
}
|
|
|
|
func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
|
|
Enabled: true,
|
|
DefaultLimitMicros: 100,
|
|
Period: string(codersdk.ChatUsageLimitPeriodDay),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
clock := quartz.NewMock(t)
|
|
|
|
streamStarted := make(chan struct{})
|
|
interrupted := make(chan struct{})
|
|
secondRequestStarted := make(chan struct{}, 1)
|
|
thirdRequestStarted := make(chan struct{}, 1)
|
|
allowFinish := make(chan struct{})
|
|
allowSecondRequestFinish := make(chan struct{})
|
|
allowThirdRequestFinish := make(chan struct{})
|
|
var requestCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
switch requestCount.Add(1) {
|
|
case 1:
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("partial")[0]
|
|
select {
|
|
case <-streamStarted:
|
|
default:
|
|
close(streamStarted)
|
|
}
|
|
<-req.Context().Done()
|
|
select {
|
|
case <-interrupted:
|
|
default:
|
|
close(interrupted)
|
|
}
|
|
<-allowFinish
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
case 2:
|
|
select {
|
|
case secondRequestStarted <- struct{}{}:
|
|
default:
|
|
}
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("second run partial")[0]
|
|
select {
|
|
case <-allowSecondRequestFinish:
|
|
case <-req.Context().Done():
|
|
}
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
case 3:
|
|
select {
|
|
case thirdRequestStarted <- struct{}{}:
|
|
default:
|
|
}
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("third run partial")[0]
|
|
select {
|
|
case <-allowThirdRequestFinish:
|
|
case <-req.Context().Done():
|
|
}
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
}
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.Clock = clock
|
|
// Keep periodic polling frozen so request handoff is synchronized
|
|
// through explicit mock channels.
|
|
cfg.PendingChatAcquireInterval = time.Hour
|
|
cfg.InFlightChatStaleAfter = testutil.WaitSuperLong
|
|
})
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "interrupt-autopromote-limit",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
testutil.TryReceive(ctx, t, streamStarted)
|
|
|
|
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, queuedResult.Queued)
|
|
require.NotNil(t, queuedResult.QueuedMessage)
|
|
|
|
testutil.TryReceive(ctx, t, interrupted)
|
|
|
|
close(allowFinish)
|
|
testutil.TryReceive(ctx, t, secondRequestStarted)
|
|
|
|
laterQueuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("later queued")},
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, laterQueuedResult.Queued)
|
|
require.NotNil(t, laterQueuedResult.QueuedMessage)
|
|
|
|
spendChat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: model.ID,
|
|
Title: "other-spend",
|
|
})
|
|
|
|
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("spent elsewhere"),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_ = dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: spendChat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: database.ChatMessageRoleAssistant,
|
|
ContentVersion: chatprompt.CurrentContentVersion,
|
|
Content: assistantContent,
|
|
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
|
|
})
|
|
|
|
close(allowSecondRequestFinish)
|
|
testutil.TryReceive(ctx, t, thirdRequestStarted)
|
|
require.GreaterOrEqual(t, requestCount.Load(), int32(3))
|
|
|
|
close(allowThirdRequestFinish)
|
|
chatd.WaitUntilIdleForTest(server)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, queued)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
|
require.False(t, fromDB.WorkerID.Valid)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
userTexts := make([]string, 0, 3)
|
|
for _, message := range messages {
|
|
if message.Role != database.ChatMessageRoleUser {
|
|
continue
|
|
}
|
|
sdkMessage := db2sdk.ChatMessage(message)
|
|
if len(sdkMessage.Content) != 1 {
|
|
continue
|
|
}
|
|
userTexts = append(userTexts, sdkMessage.Content[0].Text)
|
|
}
|
|
require.Equal(t, []string{"hello", "queued", "later queued"}, userTexts)
|
|
}
|
|
|
|
func TestEditMessageRejectsWhenUsageLimitReached(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
|
|
Enabled: true,
|
|
DefaultLimitMicros: 100,
|
|
Period: string(codersdk.ChatUsageLimitPeriodDay),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "edit-limit-reached",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 1)
|
|
editedMessageID := messages[0].ID
|
|
|
|
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("assistant"),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_ = dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: database.ChatMessageRoleAssistant,
|
|
ContentVersion: chatprompt.CurrentContentVersion,
|
|
Content: assistantContent,
|
|
TotalCostMicros: sql.NullInt64{Int64: 100, Valid: true},
|
|
})
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: editedMessageID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
|
})
|
|
require.Error(t, err)
|
|
|
|
var limitErr *chatd.UsageLimitExceededError
|
|
require.ErrorAs(t, err, &limitErr)
|
|
require.Equal(t, int64(100), limitErr.LimitMicros)
|
|
require.Equal(t, int64(100), limitErr.ConsumedMicros)
|
|
|
|
messages, err = db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 2)
|
|
originalMessage := db2sdk.ChatMessage(messages[0])
|
|
require.Len(t, originalMessage.Content, 1)
|
|
require.Equal(t, "original", originalMessage.Content[0].Text)
|
|
}
|
|
|
|
func TestEditMessageRejectsMissingMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "missing-edited-message",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: 999999,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotFound))
|
|
}
|
|
|
|
func TestEditMessageRejectsNonUserMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "non-user-edited-message",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("assistant"),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
assistantMessage := dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: database.ChatMessageRoleAssistant,
|
|
ContentVersion: chatprompt.CurrentContentVersion,
|
|
Content: assistantContent,
|
|
})
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: assistantMessage.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotUser))
|
|
}
|
|
|
|
// TestEditMessageDebugCleanupDeletesPreEditRuns verifies that
|
|
// EditMessage schedules the chat debug cleanup goroutine when debug
|
|
// logging is enabled and that it deletes debug runs tied to the
|
|
// pre-edit conversation branch. This exercises the chatd wiring end
|
|
// to end: lazy debugService init, editCutoff sampling from the DB,
|
|
// and the scheduleDebugCleanup retry loop against a real Postgres
|
|
// store.
|
|
func TestEditMessageDebugCleanupDeletesPreEditRuns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newDebugEnabledTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "debug-edit-cleanup",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("first")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
msgs, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID, AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, msgs, 1)
|
|
editedMsgID := msgs[0].ID
|
|
|
|
// Stale debug run tied to the pre-edit message branch. Stamped
|
|
// well outside the clock-skew buffer so the fast retry path
|
|
// deletes it instead of deferring to the stale sweeper.
|
|
staleStart := time.Now().Add(-time.Hour).UTC().Truncate(time.Microsecond)
|
|
staleRun, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: editedMsgID, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: editedMsgID, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: "openai", Valid: true},
|
|
Model: sql.NullString{String: model.Model, Valid: true},
|
|
StartedAt: sql.NullTime{Time: staleStart, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: staleStart, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Run tied to an earlier message branch that the message-id
|
|
// filter should leave alone even though it predates the edit.
|
|
unrelatedRun, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: editedMsgID - 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: editedMsgID - 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "completed",
|
|
Provider: sql.NullString{String: "openai", Valid: true},
|
|
Model: sql.NullString{String: model.Model, Valid: true},
|
|
StartedAt: sql.NullTime{Time: staleStart, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: staleStart, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: editedMsgID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatd.WaitUntilIdleForTest(replica)
|
|
|
|
// ErrNoRows on staleRun proves the fast-retry path DELETED the
|
|
// row: FinalizeStale (the only other debug-row writer on the
|
|
// server) only UPDATEs finished_at in place, it never deletes,
|
|
// so the row can only disappear via DeleteAfterMessageID which
|
|
// is reached solely from scheduleDebugCleanup.
|
|
_, err = db.GetChatDebugRunByID(ctx, staleRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"pre-edit run matching the message-id filter should be deleted")
|
|
|
|
remaining, err := db.GetChatDebugRunByID(ctx, unrelatedRun.ID)
|
|
require.NoError(t, err,
|
|
"runs outside the edited message branch must survive cleanup")
|
|
require.Equal(t, unrelatedRun.ID, remaining.ID)
|
|
|
|
// Count the seeded rows that survive so the delete count is
|
|
// verified directly (not just by negative lookup). Scoped to
|
|
// seeded IDs because the processor may start a new chat_turn
|
|
// run in parallel when EditMessage transitions the chat back to
|
|
// pending.
|
|
remainingRuns, err := db.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
|
|
ChatID: chat.ID, LimitVal: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
seeded := map[uuid.UUID]bool{staleRun.ID: true, unrelatedRun.ID: true}
|
|
survivors := 0
|
|
for _, r := range remainingRuns {
|
|
if seeded[r.ID] {
|
|
survivors++
|
|
}
|
|
}
|
|
require.Equal(t, 1, survivors,
|
|
"exactly one of the two seeded runs should survive (the unrelated run)")
|
|
}
|
|
|
|
// TestEditMessageDebugCleanupPreservesRecentRuns verifies that the
|
|
// clock-skew buffer in the edit-cleanup cutoff prevents the fast
|
|
// retry from deleting debug runs that started within the buffer
|
|
// window. The stale sweep handles those leftovers later.
|
|
func TestEditMessageDebugCleanupPreservesRecentRuns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newDebugEnabledTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "debug-edit-buffer",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("first")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
msgs, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID, AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, msgs, 1)
|
|
editedMsgID := msgs[0].ID
|
|
|
|
// Within the 30s skew buffer, so the fast retry must leave it
|
|
// alone even though its message ID matches the delete filter.
|
|
recentStart := time.Now().Add(-time.Second).UTC().Truncate(time.Microsecond)
|
|
recentRun, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: editedMsgID, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: editedMsgID, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: "openai", Valid: true},
|
|
Model: sql.NullString{String: model.Model, Valid: true},
|
|
StartedAt: sql.NullTime{Time: recentStart, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: recentStart, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: editedMsgID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatd.WaitUntilIdleForTest(replica)
|
|
|
|
remaining, err := db.GetChatDebugRunByID(ctx, recentRun.ID)
|
|
require.NoError(t, err,
|
|
"runs inside the clock-skew buffer must survive the fast retry")
|
|
require.Equal(t, recentRun.ID, remaining.ID)
|
|
|
|
// If the clock-skew buffer were removed the fast retry would
|
|
// have deleted recentRun. Verify the count of seeded survivors
|
|
// directly, ignoring any new chat_turn run the processor may
|
|
// create after the pending status transition.
|
|
remainingRuns, err := db.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
|
|
ChatID: chat.ID, LimitVal: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
survivors := 0
|
|
for _, r := range remainingRuns {
|
|
if r.ID == recentRun.ID {
|
|
survivors++
|
|
}
|
|
}
|
|
require.Equal(t, 1, survivors,
|
|
"the buffered run must survive the fast retry")
|
|
}
|
|
|
|
// TestArchiveChatDebugCleanupDeletesPreArchiveRuns verifies that
|
|
// ArchiveChat schedules cleanup that deletes pre-archive debug runs
|
|
// for the archived chat. Covers the archiveCutoff sampled from
|
|
// ArchiveChatByID's DB-stamped updated_at and the DeleteByChatID
|
|
// delete path.
|
|
func TestArchiveChatDebugCleanupDeletesPreArchiveRuns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newDebugEnabledTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "debug-archive-cleanup",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
staleStart := time.Now().Add(-time.Hour).UTC().Truncate(time.Microsecond)
|
|
staleRun, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: "openai", Valid: true},
|
|
Model: sql.NullString{String: model.Model, Valid: true},
|
|
StartedAt: sql.NullTime{Time: staleStart, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: staleStart, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Freshly-inserted run inside the skew buffer must survive the
|
|
// fast retry for the same reason as the edit-cleanup buffer test.
|
|
recentStart := time.Now().Add(-time.Second).UTC().Truncate(time.Microsecond)
|
|
recentRun, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: "openai", Valid: true},
|
|
Model: sql.NullString{String: model.Model, Valid: true},
|
|
StartedAt: sql.NullTime{Time: recentStart, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: recentStart, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = replica.ArchiveChat(ctx, chat)
|
|
require.NoError(t, err)
|
|
|
|
chatd.WaitUntilIdleForTest(replica)
|
|
|
|
// ErrNoRows proves the fast-retry path DELETED the row:
|
|
// FinalizeStale only UPDATEs in place, never deletes.
|
|
_, err = db.GetChatDebugRunByID(ctx, staleRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"pre-archive run outside the buffer should be deleted")
|
|
|
|
remaining, err := db.GetChatDebugRunByID(ctx, recentRun.ID)
|
|
require.NoError(t, err,
|
|
"runs inside the clock-skew buffer must survive the fast retry")
|
|
require.Equal(t, recentRun.ID, remaining.ID)
|
|
|
|
// Count the seeded survivors directly so the delete is verified
|
|
// not just by absence of a specific row. Scoped to seeded IDs
|
|
// because the archive transition may still race with other
|
|
// background debug writes.
|
|
remainingRuns, err := db.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
|
|
ChatID: chat.ID, LimitVal: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
seeded := map[uuid.UUID]bool{staleRun.ID: true, recentRun.ID: true}
|
|
survivors := 0
|
|
for _, r := range remainingRuns {
|
|
if seeded[r.ID] {
|
|
survivors++
|
|
}
|
|
}
|
|
require.Equal(t, 1, survivors,
|
|
"only the recent (buffered) seeded run should survive")
|
|
}
|
|
|
|
func TestRecoverStaleChatsPeriodically(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
// Use a very short stale threshold so the periodic recovery
|
|
// kicks in quickly during the test.
|
|
staleAfter := 500 * time.Millisecond
|
|
|
|
// Create a chat and simulate a dead worker by setting the chat
|
|
// to running with a heartbeat in the past.
|
|
deadWorkerID := uuid.New()
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "stale-recovery-periodic",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
|
|
_, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: deadWorkerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Start a new replica. Its startup recovery will reset the
|
|
// chat (since the heartbeat is old), but the key point is that
|
|
// the periodic loop also recovers newly-stale chats.
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
InFlightChatStaleAfter: staleAfter,
|
|
})
|
|
server.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
// The startup recovery should have already reset our stale
|
|
// chat.
|
|
require.Eventually(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
// Now simulate a second stale chat appearing AFTER startup.
|
|
// This tests the periodic recovery, not just the startup one.
|
|
deadWorkerID2 := uuid.New()
|
|
chat2 := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "stale-recovery-periodic-2",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat2.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: deadWorkerID2, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// The periodic stale recovery loop (running at staleAfter/5 =
|
|
// 100ms intervals) should pick this up without a restart.
|
|
require.Eventually(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat2.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestRecoverStaleRequiresActionChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
// Use a very short stale threshold so the periodic recovery
|
|
// kicks in quickly during the test.
|
|
staleAfter := 500 * time.Millisecond
|
|
|
|
// Create a chat and set it to requires_action to simulate a
|
|
// client that disappeared while the chat was waiting for
|
|
// dynamic tool results.
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "stale-requires-action",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
|
|
_, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRequiresAction,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Backdate updated_at so the chat appears stale to the
|
|
// recovery loop without needing time.Sleep.
|
|
_, err = rawDB.ExecContext(ctx,
|
|
"UPDATE chats SET updated_at = $1 WHERE id = $2",
|
|
time.Now().Add(-time.Hour), chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
InFlightChatStaleAfter: staleAfter,
|
|
})
|
|
server.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
// The stale recovery should transition the requires_action
|
|
// chat to error with the timeout message.
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
chatResult, err = db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return chatResult.Status == database.ChatStatusError
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
persistedError := requireChatLastErrorPayload(t, chatResult.LastError)
|
|
require.Equal(t, codersdk.ChatError{
|
|
Message: "Dynamic tool execution timed out",
|
|
Kind: codersdk.ChatErrorKindGeneric,
|
|
}, persistedError)
|
|
require.False(t, chatResult.WorkerID.Valid)
|
|
}
|
|
|
|
func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
// Simulate a chat left running by a dead replica with a stale
|
|
// heartbeat (well beyond the stale threshold).
|
|
deadReplicaID := uuid.New()
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "orphaned-chat",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
|
|
// Set the heartbeat far in the past so it's definitely stale.
|
|
_, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: deadReplicaID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Start a new replica. It should recover the stale chat on
|
|
// startup.
|
|
newReplica := newTestServer(t, db, ps, uuid.New())
|
|
_ = newReplica
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending &&
|
|
!fromDB.WorkerID.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestWaitingChatsAreNotRecoveredAsStale(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
// Create a chat in waiting status. This should NOT be touched
|
|
// by stale recovery.
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "waiting-chat",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
|
|
// Start a replica with a short stale threshold.
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
InFlightChatStaleAfter: 500 * time.Millisecond,
|
|
})
|
|
server.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
// Wait long enough for multiple periodic recovery cycles to
|
|
// run (staleAfter/5 = 100ms intervals).
|
|
require.Never(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status != database.ChatStatusWaiting
|
|
}, time.Second, testutil.IntervalFast,
|
|
"waiting chat should not be modified by stale recovery")
|
|
}
|
|
|
|
func TestUpdateChatStatusPersistsLastError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
_ = newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "error-persisted",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
|
|
// Write a minimal structured last_error payload through the
|
|
// query layer, then verify it round-trips through storage.
|
|
errorMessage := "stream response: status 500: internal server error"
|
|
wantPayload := codersdk.ChatError{
|
|
Message: errorMessage,
|
|
Kind: codersdk.ChatErrorKindGeneric,
|
|
}
|
|
chat, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusError,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: mustChatLastErrorRawMessage(t, wantPayload),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusError, chat.Status)
|
|
require.Equal(t, wantPayload, requireChatLastErrorPayload(t, chat.LastError))
|
|
|
|
// Verify the error is persisted when re-read from the database.
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusError, fromDB.Status)
|
|
require.Equal(t, wantPayload, requireChatLastErrorPayload(t, fromDB.LastError))
|
|
|
|
// Verify the error is cleared when the chat transitions to a
|
|
// non-error status (e.g. pending after a retry).
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusPending,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: pqtype.NullRawMessage{},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, chat.Status)
|
|
require.False(t, chat.LastError.Valid)
|
|
|
|
fromDB, err = db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, fromDB.LastError.Valid)
|
|
}
|
|
|
|
func TestSubscribeSnapshotIncludesStatusEvent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "status-snapshot",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
snapshot, _, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
// Passive server: status is always Pending.
|
|
require.NotEmpty(t, snapshot)
|
|
require.Equal(t, codersdk.ChatStreamEventTypeStatus, snapshot[0].Type)
|
|
require.NotNil(t, snapshot[0].Status)
|
|
}
|
|
|
|
func TestPersistToolResultWithBinaryData(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
const binaryOutputBase64 = "SEVBREVSAAAAc29tZSBkYXRhAABtb3JlIGRhdGEARU5E"
|
|
binaryOutput, err := io.ReadAll(base64.NewDecoder(
|
|
base64.StdEncoding,
|
|
strings.NewReader(binaryOutputBase64),
|
|
))
|
|
require.NoError(t, err)
|
|
|
|
var streamedCallCount atomic.Int32
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Binary tool result test")
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
|
|
streamedCallsMu.Unlock()
|
|
|
|
if streamedCallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk(
|
|
"execute",
|
|
`{"command":"cat /home/coder/binary_file.bin"}`,
|
|
),
|
|
)
|
|
}
|
|
// Include literal \u0000 in the response text, which is
|
|
// what a real LLM writes when explaining binary output.
|
|
// json.Marshal encodes the backslash as \\, producing
|
|
// \\u0000 in the JSON bytes. The sanitizer must not
|
|
// corrupt this into invalid JSON.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("The file contains \\u0000 null bytes.")...,
|
|
)
|
|
})
|
|
|
|
// Use "openai-compat" provider so the chatd framework uses the
|
|
// /chat/completions endpoint, where the mock server supports
|
|
// streaming tool calls. The default "openai" provider routes to
|
|
// /responses which only handles text deltas in the mock.
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().
|
|
SetExtraHeaders(gomock.Any()).
|
|
AnyTimes()
|
|
mockConn.EXPECT().
|
|
ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).
|
|
AnyTimes()
|
|
mockConn.EXPECT().
|
|
ListMCPTools(gomock.Any()).
|
|
Return(workspacesdk.ListMCPToolsResponse{}, nil).
|
|
AnyTimes()
|
|
mockConn.EXPECT().
|
|
LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{}, nil).
|
|
AnyTimes()
|
|
mockConn.EXPECT().
|
|
ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(io.NopCloser(strings.NewReader("")), "", nil).
|
|
AnyTimes()
|
|
mockConn.EXPECT().
|
|
StartProcess(gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
|
|
require.Equal(t, "cat /home/coder/binary_file.bin", req.Command)
|
|
return workspacesdk.StartProcessResponse{ID: "proc-binary", Started: true}, nil
|
|
}).
|
|
Times(1)
|
|
mockConn.EXPECT().
|
|
ProcessOutput(gomock.Any(), "proc-binary", gomock.Any()).
|
|
Return(workspacesdk.ProcessOutputResponse{
|
|
Output: string(binaryOutput),
|
|
Running: false,
|
|
ExitCode: ptrRef(0),
|
|
}, nil).
|
|
AnyTimes()
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
})
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "binary-tool-result",
|
|
ModelConfigID: model.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Read /home/coder/binary_file.bin."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat run failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
var toolMessage *database.ChatMessage
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for i := range messages {
|
|
if messages[i].Role == database.ChatMessageRoleTool {
|
|
toolMessage = &messages[i]
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}, testutil.IntervalFast)
|
|
require.NotNil(t, toolMessage)
|
|
|
|
parts, err := chatprompt.ParseContent(*toolMessage)
|
|
require.NoError(t, err)
|
|
require.Len(t, parts, 1)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolResult, parts[0].Type)
|
|
require.Equal(t, "execute", parts[0].ToolName)
|
|
|
|
var result chattool.ExecuteResult
|
|
require.NoError(t, json.Unmarshal(parts[0].Result, &result))
|
|
require.True(t, result.Success)
|
|
require.Equal(t, string(binaryOutput), result.Output)
|
|
require.Equal(t, 0, result.ExitCode)
|
|
|
|
require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
|
|
streamedCallsMu.Lock()
|
|
recordedStreamCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
|
|
streamedCallsMu.Unlock()
|
|
require.GreaterOrEqual(t, len(recordedStreamCalls), 2)
|
|
|
|
var foundToolResultInSecondCall bool
|
|
for _, message := range recordedStreamCalls[1] {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
if !json.Valid([]byte(message.Content)) {
|
|
continue
|
|
}
|
|
var result chattool.ExecuteResult
|
|
if err := json.Unmarshal([]byte(message.Content), &result); err != nil {
|
|
continue
|
|
}
|
|
if result.Output == string(binaryOutput) {
|
|
foundToolResultInSecondCall = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include execute tool output")
|
|
}
|
|
|
|
func TestRequiresActionChatPersistsWaitingStatusLabel(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Dynamic tool test")
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk(
|
|
"my_dynamic_tool",
|
|
`{"input":"hello world"}`,
|
|
),
|
|
)
|
|
})
|
|
|
|
mockPush := &mockWebpushDispatcher{}
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
WebpushDispatcher: mockPush,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "my_dynamic_tool",
|
|
Description: "A test dynamic tool.",
|
|
InputSchema: mcpgo.ToolInputSchema{
|
|
Type: "object",
|
|
Properties: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "requires-action-status-label",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Please call the dynamic tool."),
|
|
},
|
|
DynamicTools: dynamicToolsJSON,
|
|
})
|
|
require.NoError(t, err)
|
|
seedLastTurnSummary(ctx, t, db, chat, "previous summary")
|
|
|
|
server.Start()
|
|
|
|
var fromDB database.Chat
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
got, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
fromDB = got
|
|
if got.Status == database.ChatStatusError {
|
|
return true
|
|
}
|
|
return got.Status == database.ChatStatusRequiresAction &&
|
|
got.LastTurnSummary.Valid &&
|
|
got.LastTurnSummary.String == "Waiting for user input"
|
|
}, testutil.IntervalFast)
|
|
chatd.WaitUntilIdleForTest(server)
|
|
|
|
require.Equal(t, database.ChatStatusRequiresAction, fromDB.Status,
|
|
"expected requires_action, got %s (last_error=%q)",
|
|
fromDB.Status, string(fromDB.LastError.RawMessage))
|
|
require.Equal(t, sql.NullString{String: "Waiting for user input", Valid: true}, fromDB.LastTurnSummary,
|
|
"requires action chats should persist a waiting status label")
|
|
require.Equal(t, int32(0), mockPush.dispatchCount.Load(),
|
|
"expected no web push dispatch for a requires_action chat")
|
|
}
|
|
|
|
func TestDynamicToolCallPausesAndResumes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Track streaming calls to the mock LLM.
|
|
var streamedCallCount atomic.Int32
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
// Non-streaming requests are title generation. Return a
|
|
// simple title.
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Dynamic tool test")
|
|
}
|
|
|
|
// Capture the full request for later assertions.
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
|
|
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
|
|
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
|
|
Stream: req.Stream,
|
|
})
|
|
streamedCallsMu.Unlock()
|
|
|
|
if streamedCallCount.Add(1) == 1 {
|
|
// First call: the LLM invokes our dynamic tool.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk(
|
|
"my_dynamic_tool",
|
|
`{"input":"hello world"}`,
|
|
),
|
|
)
|
|
}
|
|
// Second call: the LLM returns a normal text response.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Dynamic tool result received.")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
// Dynamic tools do not need a workspace connection, but the
|
|
// chatd server always builds workspace tools. Use an active
|
|
// server without an agent connection, so the built-in tools
|
|
// are never invoked because the only tool call targets our
|
|
// dynamic tool.
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
// Create a chat with a dynamic tool.
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "my_dynamic_tool",
|
|
Description: "A test dynamic tool.",
|
|
InputSchema: mcpgo.ToolInputSchema{
|
|
Type: "object",
|
|
Properties: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "dynamic-tool-pause-resume",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Please call the dynamic tool."),
|
|
},
|
|
DynamicTools: dynamicToolsJSON,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// 1. Wait for the chat to reach requires_action status.
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusRequiresAction ||
|
|
got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
|
|
"expected requires_action, got %s (last_error=%q)",
|
|
chatResult.Status, chatLastErrorMessage(chatResult.LastError))
|
|
|
|
// 2. Read the assistant message to find the tool-call ID.
|
|
var toolCallID string
|
|
var toolCallFound bool
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for _, msg := range messages {
|
|
if msg.Role != database.ChatMessageRoleAssistant {
|
|
continue
|
|
}
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
|
toolCallID = part.ToolCallID
|
|
toolCallFound = true
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}, testutil.IntervalFast)
|
|
require.True(t, toolCallFound, "expected to find tool call for my_dynamic_tool")
|
|
require.NotEmpty(t, toolCallID)
|
|
|
|
// 3. Submit tool results via SubmitToolResults.
|
|
toolResultOutput := json.RawMessage(`{"result":"dynamic tool output"}`)
|
|
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
|
ChatID: chat.ID,
|
|
UserID: user.ID,
|
|
ModelConfigID: chatResult.LastModelConfigID,
|
|
Results: []codersdk.ToolResult{{
|
|
ToolCallID: toolCallID,
|
|
Output: toolResultOutput,
|
|
}},
|
|
DynamicTools: dynamicToolsJSON,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// 4. Wait for the chat to reach a terminal status.
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
// 5. Verify the chat completed successfully.
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat run failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
// 6. Verify the mock received exactly 2 streaming calls.
|
|
require.Equal(t, int32(2), streamedCallCount.Load(),
|
|
"expected exactly 2 streaming calls to the LLM")
|
|
|
|
streamedCallsMu.Lock()
|
|
recordedCalls := append([]chattest.OpenAIRequest(nil), streamedCalls...)
|
|
streamedCallsMu.Unlock()
|
|
require.Len(t, recordedCalls, 2)
|
|
|
|
// 7. Verify the dynamic tool appeared in the first call's tool list.
|
|
var foundDynamicTool bool
|
|
for _, tool := range recordedCalls[0].Tools {
|
|
if tool.Function.Name == "my_dynamic_tool" {
|
|
foundDynamicTool = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, foundDynamicTool,
|
|
"expected 'my_dynamic_tool' in the first LLM call's tool list")
|
|
|
|
// 8. Verify the second call's messages contain the tool result.
|
|
var foundToolResultInSecondCall bool
|
|
for _, message := range recordedCalls[1].Messages {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
if strings.Contains(message.Content, "dynamic tool output") {
|
|
foundToolResultInSecondCall = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, foundToolResultInSecondCall,
|
|
"expected second LLM call to include the submitted dynamic tool result")
|
|
}
|
|
|
|
func TestDynamicToolNamedProposePlanRemainsAvailableOutsidePlanMode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([]chattest.OpenAIRequest, 0, 1)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Dynamic tool collision test")
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
|
|
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
|
|
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
|
|
Stream: req.Stream,
|
|
})
|
|
streamedCallsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Dynamic tool list captured.")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "propose_plan",
|
|
Description: "A dynamic tool whose name collides with the hidden built-in.",
|
|
InputSchema: mcpgo.ToolInputSchema{
|
|
Type: "object",
|
|
Properties: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "dynamic-propose-plan-collision",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("List the available tools."),
|
|
},
|
|
DynamicTools: dynamicToolsJSON,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat run failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
recordedCalls := append([]chattest.OpenAIRequest(nil), streamedCalls...)
|
|
streamedCallsMu.Unlock()
|
|
require.NotEmpty(t, recordedCalls)
|
|
|
|
var foundDynamicTool bool
|
|
for _, tool := range recordedCalls[0].Tools {
|
|
if tool.Function.Name == "propose_plan" {
|
|
foundDynamicTool = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, foundDynamicTool,
|
|
"expected the dynamic propose_plan tool to remain visible outside plan mode")
|
|
}
|
|
|
|
func TestDynamicToolCallMixedWithBuiltIn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Track streaming calls to the mock LLM.
|
|
var streamedCallCount atomic.Int32
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Mixed tool test")
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
|
|
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
|
|
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
|
|
Stream: req.Stream,
|
|
})
|
|
streamedCallsMu.Unlock()
|
|
|
|
if streamedCallCount.Add(1) == 1 {
|
|
// First call: return TWO tool calls in one
|
|
// response: a built-in tool (read_file) and a
|
|
// dynamic tool (my_dynamic_tool).
|
|
builtinChunk := chattest.OpenAIToolCallChunk(
|
|
"read_file",
|
|
`{"path":"/tmp/test.txt"}`,
|
|
)
|
|
dynamicChunk := chattest.OpenAIToolCallChunk(
|
|
"my_dynamic_tool",
|
|
`{"input":"hello world"}`,
|
|
)
|
|
// Merge both tool calls into one chunk with
|
|
// separate indices so the LLM appears to have
|
|
// requested both tools simultaneously.
|
|
mergedChunk := builtinChunk
|
|
dynCall := dynamicChunk.Choices[0].ToolCalls[0]
|
|
dynCall.Index = 1
|
|
mergedChunk.Choices[0].ToolCalls = append(
|
|
mergedChunk.Choices[0].ToolCalls,
|
|
dynCall,
|
|
)
|
|
return chattest.OpenAIStreamingResponse(mergedChunk)
|
|
}
|
|
// Second call (after tool results): normal text
|
|
// response.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("All done.")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
// Create a chat with a dynamic tool.
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "my_dynamic_tool",
|
|
Description: "A test dynamic tool.",
|
|
InputSchema: mcpgo.ToolInputSchema{
|
|
Type: "object",
|
|
Properties: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "mixed-builtin-dynamic",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Call both tools."),
|
|
},
|
|
DynamicTools: dynamicToolsJSON,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// 1. Wait for the chat to reach requires_action status.
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusRequiresAction ||
|
|
got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
|
|
"expected requires_action, got %s (last_error=%q)",
|
|
chatResult.Status, chatLastErrorMessage(chatResult.LastError))
|
|
|
|
// 2. Verify the built-in tool (read_file) was already
|
|
// executed by checking that a tool result message
|
|
// exists for it in the database.
|
|
var builtinToolResultFound bool
|
|
var toolCallID string
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for _, msg := range messages {
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
// Check for the built-in tool result.
|
|
if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == "read_file" {
|
|
builtinToolResultFound = true
|
|
}
|
|
// Find the dynamic tool call ID.
|
|
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
|
toolCallID = part.ToolCallID
|
|
}
|
|
}
|
|
}
|
|
return builtinToolResultFound && toolCallID != ""
|
|
}, testutil.IntervalFast)
|
|
|
|
require.True(t, builtinToolResultFound,
|
|
"expected read_file tool result in the DB before dynamic tool resolution")
|
|
require.NotEmpty(t, toolCallID)
|
|
|
|
// 3. Submit dynamic tool results.
|
|
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
|
ChatID: chat.ID,
|
|
UserID: user.ID,
|
|
ModelConfigID: chatResult.LastModelConfigID,
|
|
Results: []codersdk.ToolResult{{
|
|
ToolCallID: toolCallID,
|
|
Output: json.RawMessage(`{"result":"dynamic output"}`),
|
|
}},
|
|
DynamicTools: dynamicToolsJSON,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// 4. Wait for the chat to complete.
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat run failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
// 5. Verify the LLM received exactly 2 streaming calls.
|
|
require.Equal(t, int32(2), streamedCallCount.Load(),
|
|
"expected exactly 2 streaming calls to the LLM")
|
|
}
|
|
|
|
func TestSubmitToolResultsConcurrency(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// The mock LLM returns a dynamic tool call on the first streaming
|
|
// request, then a plain text reply on the second.
|
|
var streamedCallCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Concurrency test")
|
|
}
|
|
if streamedCallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk(
|
|
"my_dynamic_tool",
|
|
`{"input":"hello"}`,
|
|
),
|
|
)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Done.")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
// Create a chat with a dynamic tool.
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "my_dynamic_tool",
|
|
Description: "A test dynamic tool.",
|
|
InputSchema: mcpgo.ToolInputSchema{
|
|
Type: "object",
|
|
Properties: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "concurrency-tool-results",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Please call the dynamic tool."),
|
|
},
|
|
DynamicTools: dynamicToolsJSON,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the chat to reach requires_action status.
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusRequiresAction ||
|
|
got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
|
|
"expected requires_action, got %s (last_error=%q)",
|
|
chatResult.Status, chatLastErrorMessage(chatResult.LastError))
|
|
|
|
// Find the tool call ID from the assistant message.
|
|
var toolCallID string
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for _, msg := range messages {
|
|
if msg.Role != database.ChatMessageRoleAssistant {
|
|
continue
|
|
}
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
|
toolCallID = part.ToolCallID
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}, testutil.IntervalFast)
|
|
require.NotEmpty(t, toolCallID)
|
|
|
|
// Spawn N goroutines that all try to submit tool results at the
|
|
// same time. Exactly one should succeed; the rest must get a
|
|
// ToolResultStatusConflictError.
|
|
const numGoroutines = 10
|
|
var (
|
|
wg sync.WaitGroup
|
|
ready = make(chan struct{})
|
|
successes atomic.Int32
|
|
conflicts atomic.Int32
|
|
unexpectedErrors = make(chan error, numGoroutines)
|
|
)
|
|
|
|
for range numGoroutines {
|
|
wg.Go(func() {
|
|
// Wait for all goroutines to be ready.
|
|
<-ready
|
|
|
|
submitErr := server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
|
ChatID: chat.ID,
|
|
UserID: user.ID,
|
|
ModelConfigID: chatResult.LastModelConfigID,
|
|
Results: []codersdk.ToolResult{{
|
|
ToolCallID: toolCallID,
|
|
Output: json.RawMessage(`{"result":"concurrent output"}`),
|
|
}},
|
|
DynamicTools: dynamicToolsJSON,
|
|
})
|
|
|
|
if submitErr == nil {
|
|
successes.Add(1)
|
|
return
|
|
}
|
|
var conflict *chatd.ToolResultStatusConflictError
|
|
if errors.As(submitErr, &conflict) {
|
|
conflicts.Add(1)
|
|
return
|
|
}
|
|
// Collect unexpected errors for assertion
|
|
// outside the goroutine (require.NoError
|
|
// calls t.FailNow which is illegal here).
|
|
unexpectedErrors <- submitErr
|
|
})
|
|
}
|
|
// Release all goroutines at once.
|
|
close(ready)
|
|
|
|
wg.Wait()
|
|
close(unexpectedErrors)
|
|
|
|
for ue := range unexpectedErrors {
|
|
require.NoError(t, ue, "unexpected error from SubmitToolResults")
|
|
}
|
|
|
|
require.Equal(t, int32(1), successes.Load(),
|
|
"expected exactly 1 goroutine to succeed")
|
|
require.Equal(t, int32(numGoroutines-1), conflicts.Load(),
|
|
"expected %d conflict errors", numGoroutines-1)
|
|
}
|
|
|
|
// ptrRef returns a pointer to a fresh copy of v. Each call yields a distinct
// pointer, so writing through the result never affects the caller's value.
func ptrRef[T any](v T) *T {
	out := new(T)
	*out = v
	return out
}
|
|
|
|
func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Use nil pubsub to force the no-pubsub path.
|
|
db, _ := dbtestutil.NewDB(t)
|
|
replica := newStartedTestServer(t, db, nil, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "no-dup-parts",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for any wake-triggered processing to settle before
|
|
// subscribing, so the snapshot captures the final state.
|
|
// The wake signal may trigger processOnce which will fail
|
|
// (no LLM configured) and set the chat to error status.
|
|
// Poll until the chat reaches a terminal state (not pending
|
|
// and not running), then wait for the goroutine to finish.
|
|
waitForChatProcessed(ctx, t, db, chat.ID, replica)
|
|
|
|
snapshot, events, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
// Snapshot should have events (at minimum: status + message).
|
|
require.NotEmpty(t, snapshot)
|
|
|
|
// The events channel should NOT immediately produce any
|
|
// events. The snapshot already contained everything. Before
|
|
// the fix, localSnapshot was replayed into the channel,
|
|
// causing duplicates.
|
|
require.Never(t, func() bool {
|
|
select {
|
|
case <-events:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, 200*time.Millisecond, testutil.IntervalFast,
|
|
"expected no duplicate events after snapshot")
|
|
}
|
|
|
|
func TestSubscribeAfterMessageID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: model.ID,
|
|
Title: "after-id-test",
|
|
Status: database.ChatStatusWaiting,
|
|
})
|
|
|
|
// Seed all messages directly so this subscription test is independent
|
|
// of chat processing lifecycle behavior.
|
|
firstContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("first"),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_ = dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: chat.ID,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: database.ChatMessageRoleUser,
|
|
ContentVersion: chatprompt.CurrentContentVersion,
|
|
Content: firstContent,
|
|
})
|
|
|
|
secondContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("second"),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
msg2 := dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: database.ChatMessageRoleAssistant,
|
|
ContentVersion: chatprompt.CurrentContentVersion,
|
|
Content: secondContent,
|
|
})
|
|
|
|
thirdContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("third"),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_ = dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: chat.ID,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: database.ChatMessageRoleUser,
|
|
ContentVersion: chatprompt.CurrentContentVersion,
|
|
Content: thirdContent,
|
|
})
|
|
|
|
// Control: Subscribe with afterMessageID=0 returns ALL messages.
|
|
allSnapshot, _, cancelAll, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
cancelAll()
|
|
|
|
allMessages := filterMessageEvents(allSnapshot)
|
|
require.Len(t, allMessages, 3, "afterMessageID=0 should return all three messages")
|
|
|
|
// Subscribe with afterMessageID set to the second message's ID.
|
|
// Only the third message (inserted after msg2) should appear.
|
|
partialSnapshot, _, cancelPartial, ok := replica.Subscribe(ctx, chat.ID, nil, msg2.ID)
|
|
require.True(t, ok)
|
|
cancelPartial()
|
|
|
|
partialMessages := filterMessageEvents(partialSnapshot)
|
|
require.Len(t, partialMessages, 1, "afterMessageID=msg2.ID should return only messages after msg2")
|
|
require.Equal(t, codersdk.ChatMessageRoleUser, partialMessages[0].Message.Role)
|
|
}
|
|
|
|
// filterMessageEvents returns only the Message-type events from a
|
|
// snapshot slice, which is useful for ignoring status / queue events.
|
|
func filterMessageEvents(events []codersdk.ChatStreamEvent) []codersdk.ChatStreamEvent {
|
|
return slice.Filter(events, func(e codersdk.ChatStreamEvent) bool {
|
|
return e.Type == codersdk.ChatStreamEventTypeMessage
|
|
})
|
|
}
|
|
|
|
func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
deploymentValues := directChatRoutingDeploymentValues(t)
|
|
client := coderdtest.New(t, &coderdtest.Options{
|
|
DeploymentValues: deploymentValues,
|
|
IncludeProvisionerDaemon: true,
|
|
})
|
|
user := coderdtest.CreateFirstUser(t, client)
|
|
expClient := codersdk.NewExperimentalClient(client)
|
|
|
|
agentToken := uuid.NewString()
|
|
// Add a startup script so the agent spends time in the
|
|
// "starting" lifecycle state. This lets us verify that
|
|
// create_workspace waits for scripts to finish.
|
|
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
|
Parse: echo.ParseComplete,
|
|
ProvisionPlan: echo.PlanComplete,
|
|
ProvisionApply: echo.ApplyComplete,
|
|
ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken, func(g *proto.GraphComplete) {
|
|
g.Resources[0].Agents[0].Scripts = []*proto.Script{{
|
|
DisplayName: "setup",
|
|
Script: "sleep 5",
|
|
RunOnStart: true,
|
|
}}
|
|
}),
|
|
})
|
|
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
|
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
|
|
|
// Start the test workspace agent so create_workspace can wait for
|
|
// the agent to become reachable before returning.
|
|
_ = agenttest.New(t, client.URL, agentToken)
|
|
|
|
workspaceName := "chat-ws-" + strings.ReplaceAll(uuid.NewString(), "-", "")[:8]
|
|
createWorkspaceArgs := fmt.Sprintf(
|
|
`{"template_id":%q,"name":%q}`,
|
|
template.ID.String(),
|
|
workspaceName,
|
|
)
|
|
|
|
var streamedCallCount atomic.Int32
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Create workspace test")
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
|
|
streamedCallsMu.Unlock()
|
|
|
|
if streamedCallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("create_workspace", createWorkspaceArgs),
|
|
)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Workspace created and ready.")...,
|
|
)
|
|
})
|
|
|
|
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
|
|
|
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
|
OrganizationID: user.OrganizationID,
|
|
Content: []codersdk.ChatInputPart{
|
|
{
|
|
Type: codersdk.ChatInputPartTypeText,
|
|
Text: "Create a workspace from the template and continue.",
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var chatResult codersdk.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := expClient.GetChat(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == codersdk.ChatStatusWaiting || got.Status == codersdk.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatResult.Status == codersdk.ChatStatusError {
|
|
lastError := ""
|
|
if chatResult.LastError != nil {
|
|
lastError = chatResult.LastError.Message
|
|
}
|
|
require.FailNowf(t, "chat run failed", "last_error=%q", lastError)
|
|
}
|
|
|
|
require.NotNil(t, chatResult.WorkspaceID)
|
|
workspaceID := *chatResult.WorkspaceID
|
|
workspace, err := client.Workspace(ctx, workspaceID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, workspaceName, workspace.Name)
|
|
|
|
chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil)
|
|
require.NoError(t, err)
|
|
|
|
var foundCreateWorkspaceResult bool
|
|
for _, message := range chatMsgs.Messages {
|
|
if message.Role != codersdk.ChatMessageRoleTool {
|
|
continue
|
|
}
|
|
for _, part := range message.Content {
|
|
if part.Type != codersdk.ChatMessagePartTypeToolResult || part.ToolName != "create_workspace" {
|
|
continue
|
|
}
|
|
var result map[string]any
|
|
require.NoError(t, json.Unmarshal(part.Result, &result))
|
|
created, ok := result["created"].(bool)
|
|
require.True(t, ok)
|
|
require.True(t, created)
|
|
foundCreateWorkspaceResult = true
|
|
}
|
|
}
|
|
require.True(t, foundCreateWorkspaceResult, "expected create_workspace tool result message")
|
|
|
|
// Verify that the tool waited for startup scripts to
|
|
// complete. The agent should be in "ready" state by the
|
|
// time create_workspace returns its result.
|
|
workspace, err = client.Workspace(ctx, workspaceID)
|
|
require.NoError(t, err)
|
|
var agentLifecycle codersdk.WorkspaceAgentLifecycle
|
|
for _, res := range workspace.LatestBuild.Resources {
|
|
for _, agt := range res.Agents {
|
|
agentLifecycle = agt.LifecycleState
|
|
}
|
|
}
|
|
require.Equal(t, codersdk.WorkspaceAgentLifecycleReady, agentLifecycle,
|
|
"agent should be ready after create_workspace returns; startup scripts were not awaited")
|
|
|
|
require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
|
|
streamedCallsMu.Lock()
|
|
recordedStreamCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
|
|
streamedCallsMu.Unlock()
|
|
require.GreaterOrEqual(t, len(recordedStreamCalls), 2)
|
|
|
|
var foundToolResultInSecondCall bool
|
|
for _, message := range recordedStreamCalls[1] {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
if !json.Valid([]byte(message.Content)) {
|
|
continue
|
|
}
|
|
var result map[string]any
|
|
if err := json.Unmarshal([]byte(message.Content), &result); err != nil {
|
|
continue
|
|
}
|
|
created, ok := result["created"].(bool)
|
|
if ok && created {
|
|
foundToolResultInSecondCall = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include create_workspace tool output")
|
|
}
|
|
|
|
func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
|
deploymentValues := directChatRoutingDeploymentValues(t)
|
|
client := coderdtest.New(t, &coderdtest.Options{
|
|
DeploymentValues: deploymentValues,
|
|
IncludeProvisionerDaemon: true,
|
|
})
|
|
user := coderdtest.CreateFirstUser(t, client)
|
|
expClient := codersdk.NewExperimentalClient(client)
|
|
|
|
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
|
Parse: echo.ParseComplete,
|
|
ProvisionPlan: echo.PlanComplete,
|
|
ProvisionApply: echo.ApplyComplete,
|
|
})
|
|
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
|
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
|
|
|
// Create a workspace, then stop it so start_workspace has
|
|
// something to start. We intentionally skip starting a test
|
|
// agent. The echo provisioner creates new agent rows for each
|
|
// build, so an agent started for build 1 cannot serve build 3.
|
|
// The tool handles the no-agent case gracefully.
|
|
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
|
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
|
workspace = coderdtest.MustTransitionWorkspace(
|
|
t, client, workspace.ID,
|
|
codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop,
|
|
)
|
|
|
|
var streamedCallCount atomic.Int32
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Start workspace test")
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
|
|
streamedCallsMu.Unlock()
|
|
|
|
if streamedCallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("start_workspace", "{}"),
|
|
)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Workspace started and ready.")...,
|
|
)
|
|
})
|
|
|
|
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
|
|
|
// Create a chat with the stopped workspace pre-associated.
|
|
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
|
OrganizationID: user.OrganizationID,
|
|
Content: []codersdk.ChatInputPart{
|
|
{
|
|
Type: codersdk.ChatInputPartTypeText,
|
|
Text: "Start the workspace.",
|
|
},
|
|
},
|
|
WorkspaceID: &workspace.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var chatResult codersdk.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := expClient.GetChat(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == codersdk.ChatStatusWaiting || got.Status == codersdk.ChatStatusError
|
|
}, testutil.WaitSuperLong, testutil.IntervalFast)
|
|
|
|
if chatResult.Status == codersdk.ChatStatusError {
|
|
lastError := ""
|
|
if chatResult.LastError != nil {
|
|
lastError = chatResult.LastError.Message
|
|
}
|
|
require.FailNowf(t, "chat run failed", "last_error=%q", lastError)
|
|
}
|
|
|
|
// Verify the workspace was started.
|
|
require.NotNil(t, chatResult.WorkspaceID)
|
|
updatedWorkspace, err := client.Workspace(ctx, workspace.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, codersdk.WorkspaceTransitionStart, updatedWorkspace.LatestBuild.Transition)
|
|
|
|
chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Verify start_workspace tool result exists in the chat messages.
|
|
var foundStartWorkspaceResult bool
|
|
for _, message := range chatMsgs.Messages {
|
|
if message.Role != codersdk.ChatMessageRoleTool {
|
|
continue
|
|
}
|
|
for _, part := range message.Content {
|
|
if part.Type != codersdk.ChatMessagePartTypeToolResult || part.ToolName != "start_workspace" {
|
|
continue
|
|
}
|
|
var result map[string]any
|
|
require.NoError(t, json.Unmarshal(part.Result, &result))
|
|
started, ok := result["started"].(bool)
|
|
require.True(t, ok)
|
|
require.True(t, started)
|
|
foundStartWorkspaceResult = true
|
|
}
|
|
}
|
|
require.True(t, foundStartWorkspaceResult, "expected start_workspace tool result message")
|
|
|
|
// Verify the LLM received the tool result in its second call.
|
|
require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
|
|
streamedCallsMu.Lock()
|
|
recordedStreamCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
|
|
streamedCallsMu.Unlock()
|
|
require.GreaterOrEqual(t, len(recordedStreamCalls), 2)
|
|
|
|
var foundToolResultInSecondCall bool
|
|
for _, message := range recordedStreamCalls[1] {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
if !json.Valid([]byte(message.Content)) {
|
|
continue
|
|
}
|
|
var result map[string]any
|
|
if err := json.Unmarshal([]byte(message.Content), &result); err != nil {
|
|
continue
|
|
}
|
|
started, ok := result["started"].(bool)
|
|
if ok && started {
|
|
foundToolResultInSecondCall = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include start_workspace tool output")
|
|
}
|
|
|
|
func TestStoppedWorkspaceWithPersistedAgentBindingDoesNotBlockChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var streamedCallCount atomic.Int32
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
|
|
toolsByCall := make([][]string, 0, 2)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Stopped workspace regression")
|
|
}
|
|
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
|
|
toolsByCall = append(toolsByCall, names)
|
|
streamedCallsMu.Unlock()
|
|
|
|
if streamedCallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("execute", `{"command":"echo hi"}`),
|
|
)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("The workspace is unavailable. Start it before retrying workspace tools.")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
|
|
|
inactive := newTestServer(t, db, ps, uuid.New())
|
|
chat, err := inactive.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "stopped-workspace-regression",
|
|
ModelConfigID: model.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Run echo hi in the workspace."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Close the inactive server so its wake-triggered processing
|
|
// stops and releases the chat. Then reset to pending so the
|
|
// active server (created below) can acquire it cleanly.
|
|
require.NoError(t, inactive.Close())
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusPending,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: pqtype.NullRawMessage{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
|
|
require.NoError(t, err)
|
|
chat, err = db.UpdateChatBuildAgentBinding(ctx, database.UpdateChatBuildAgentBindingParams{
|
|
ID: chat.ID,
|
|
BuildID: uuid.NullUUID{UUID: build.ID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
|
|
Transition: database.WorkspaceTransitionStop,
|
|
BuildNumber: 2,
|
|
}).Do()
|
|
|
|
var dialCalls atomic.Int32
|
|
_ = newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
dialCalls.Add(1)
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
<-ctx.Done()
|
|
return nil, nil, ctx.Err()
|
|
}
|
|
})
|
|
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
require.EqualValues(t, 1, dialCalls.Load())
|
|
require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
|
|
|
|
streamedCallsMu.Lock()
|
|
recordedCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
|
|
recordedTools := append([][]string(nil), toolsByCall...)
|
|
streamedCallsMu.Unlock()
|
|
require.GreaterOrEqual(t, len(recordedCalls), 2)
|
|
require.NotEmpty(t, recordedTools)
|
|
require.Contains(t, recordedTools[0], "execute")
|
|
require.Contains(t, recordedTools[0], "start_workspace")
|
|
|
|
var foundUnavailableToolResult bool
|
|
for _, message := range recordedCalls[1] {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
if strings.Contains(message.Content, "workspace has no running agent") {
|
|
foundUnavailableToolResult = true
|
|
break
|
|
}
|
|
if !json.Valid([]byte(message.Content)) {
|
|
continue
|
|
}
|
|
var toolResult map[string]any
|
|
if err := json.Unmarshal([]byte(message.Content), &toolResult); err != nil {
|
|
continue
|
|
}
|
|
errMsg, _ := toolResult["error"].(string)
|
|
outputMsg, _ := toolResult["output"].(string)
|
|
if strings.Contains(errMsg, "workspace has no running agent") ||
|
|
strings.Contains(outputMsg, "workspace has no running agent") {
|
|
foundUnavailableToolResult = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, foundUnavailableToolResult,
|
|
"expected the second streamed model call to include the unavailable workspace tool result")
|
|
|
|
var toolMessage *database.ChatMessage
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for i := range messages {
|
|
if messages[i].Role == database.ChatMessageRoleTool {
|
|
toolMessage = &messages[i]
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}, testutil.IntervalFast)
|
|
require.NotNil(t, toolMessage)
|
|
|
|
parts, err := chatprompt.ParseContent(*toolMessage)
|
|
require.NoError(t, err)
|
|
require.Len(t, parts, 1)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolResult, parts[0].Type)
|
|
require.Equal(t, "execute", parts[0].ToolName)
|
|
require.True(t, parts[0].IsError)
|
|
require.Contains(t, string(parts[0].Result), "workspace has no running agent")
|
|
}
|
|
|
|
func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("ok")
|
|
}
|
|
// Block until the request context is canceled so the chat
|
|
// stays in a processing state long enough for heartbeats
|
|
// to fire.
|
|
chunks := make(chan chattest.OpenAIChunk)
|
|
go func() {
|
|
defer close(chunks)
|
|
<-req.Context().Done()
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
}))
|
|
|
|
// Create a workspace with a full build chain so we can verify
|
|
// both last_used_at (dormancy) and deadline (autostop) bumps.
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
tmpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
ActiveVersionID: tv.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{
|
|
ID: tmpl.ID,
|
|
UpdatedAt: dbtime.Now(),
|
|
AllowUserAutostop: true,
|
|
ActivityBump: int64(time.Hour),
|
|
}))
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tmpl.ID,
|
|
Ttl: sql.NullInt64{Valid: true, Int64: int64(8 * time.Hour)},
|
|
})
|
|
pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
CompletedAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-30 * time.Minute),
|
|
},
|
|
})
|
|
// Build deadline is 30 minutes in the past, close enough to
|
|
// be bumped by the default 1-hour activity bump.
|
|
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: ws.ID,
|
|
TemplateVersionID: tv.ID,
|
|
JobID: pj.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
Deadline: dbtime.Now().Add(-30 * time.Minute),
|
|
})
|
|
originalDeadline := build.Deadline
|
|
|
|
// Set up a short heartbeat interval and a UsageTracker that
|
|
// flushes frequently so last_used_at gets updated in the DB.
|
|
flushTick := make(chan time.Time)
|
|
flushDone := make(chan int, 1)
|
|
tracker := workspacestats.NewTracker(db,
|
|
workspacestats.TrackerWithTickFlush(flushTick, flushDone),
|
|
workspacestats.TrackerWithLogger(slogtest.Make(t, nil)),
|
|
)
|
|
t.Cleanup(func() { tracker.Close() })
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
// Wrap the database with dbauthz so the chatd server's
|
|
// AsChatd context is enforced on every query, matching
|
|
// production behavior.
|
|
authzDB := dbauthz.New(db, rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()), slogtest.Make(t, nil), coderdtest.AccessControlStorePointer())
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: authzDB,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitLong,
|
|
ChatHeartbeatInterval: 100 * time.Millisecond,
|
|
UsageTracker: tracker,
|
|
})
|
|
server.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
// Create a chat WITHOUT a workspace, the normal starting state.
|
|
// In production, CreateChat is called from the HTTP handler with
|
|
// the authenticated user's context. Here we use AsChatd since
|
|
// the chatd server processes everything under that role.
|
|
chatCtx := dbauthz.AsChatd(ctx)
|
|
chat, err := server.CreateChat(chatCtx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "usage-tracking-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the chat to start processing and at least one
|
|
// heartbeat to fire.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, listErr := db.GetChatByID(ctx, chat.ID)
|
|
if listErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusRunning &&
|
|
fromDB.HeartbeatAt.Valid &&
|
|
fromDB.HeartbeatAt.Time.After(fromDB.CreatedAt)
|
|
}, testutil.IntervalFast,
|
|
"chat should be running with at least one heartbeat")
|
|
|
|
// Flush the tracker and verify nothing was tracked yet
|
|
// (no workspace linked).
|
|
testutil.RequireSend(ctx, t, flushTick, time.Now())
|
|
count := testutil.RequireReceive(ctx, t, flushDone)
|
|
require.Equal(t, 0, count,
|
|
"expected no workspaces to be flushed before association")
|
|
|
|
// Link the workspace to the chat in the DB, simulating what
|
|
// the create_workspace tool does mid-conversation.
|
|
_, err = db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
ID: chat.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// The heartbeat re-reads the workspace association from the DB
|
|
// on each tick. Wait for the tracker to pick it up.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
select {
|
|
case flushTick <- time.Now():
|
|
case <-ctx.Done():
|
|
return false
|
|
}
|
|
select {
|
|
case c := <-flushDone:
|
|
return c > 0
|
|
case <-ctx.Done():
|
|
return false
|
|
}
|
|
}, testutil.IntervalMedium,
|
|
"expected usage tracker to flush the late-associated workspace")
|
|
|
|
// Verify the workspace's last_used_at was actually updated.
|
|
updatedWs, err := db.GetWorkspaceByID(ctx, ws.ID)
|
|
require.NoError(t, err)
|
|
require.True(t, updatedWs.LastUsedAt.After(ws.LastUsedAt),
|
|
"workspace last_used_at should have been bumped")
|
|
|
|
// Verify the workspace build deadline was also extended.
|
|
// The SQL only writes when 5% of the deadline has elapsed,
|
|
// most calls perform a read-only CTE lookup. Wider ±2
|
|
// minute tolerance than activitybump_test.go because the bump
|
|
// happens asynchronously via the heartbeat goroutine.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
updatedBuild, buildErr := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
|
|
if buildErr != nil || !updatedBuild.Deadline.After(originalDeadline) {
|
|
return false
|
|
}
|
|
now := dbtime.Now()
|
|
return updatedBuild.Deadline.After(now.Add(time.Hour-2*time.Minute)) &&
|
|
updatedBuild.Deadline.Before(now.Add(time.Hour+2*time.Minute))
|
|
}, testutil.IntervalFast,
|
|
"workspace build deadline should have been bumped to ~now+1h")
|
|
}
|
|
|
|
func TestHeartbeatNoWorkspaceNoBump(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("ok")
|
|
}
|
|
chunks := make(chan chattest.OpenAIChunk)
|
|
go func() {
|
|
defer close(chunks)
|
|
<-req.Context().Done()
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
}))
|
|
|
|
// Set up UsageTracker with manual tick/flush.
|
|
usageTickCh := make(chan time.Time)
|
|
flushCh := make(chan int, 1)
|
|
tracker := workspacestats.NewTracker(db,
|
|
workspacestats.TrackerWithTickFlush(usageTickCh, flushCh),
|
|
workspacestats.TrackerWithLogger(slogtest.Make(t, nil)),
|
|
)
|
|
t.Cleanup(func() { tracker.Close() })
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitLong,
|
|
ChatHeartbeatInterval: 100 * time.Millisecond,
|
|
})
|
|
server.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
// Create a chat WITHOUT linking a workspace.
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "no-workspace-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the chat to be acquired and at least one heartbeat
|
|
// to fire.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, listErr := db.GetChatByID(ctx, chat.ID)
|
|
if listErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusRunning &&
|
|
fromDB.HeartbeatAt.Valid &&
|
|
fromDB.HeartbeatAt.Time.After(fromDB.CreatedAt)
|
|
}, testutil.IntervalFast,
|
|
"chat should be running with at least one heartbeat")
|
|
|
|
// Flush the tracker. Since no workspace was linked, count
|
|
// should be 0.
|
|
testutil.RequireSend(ctx, t, usageTickCh, time.Now())
|
|
count := testutil.RequireReceive(ctx, t, flushCh)
|
|
require.Equal(t, 0, count, "expected no workspaces to be flushed when chat has no workspace")
|
|
}
|
|
|
|
// waitForChatProcessed waits for a wake-triggered processOnce to
|
|
// fully complete for the given chat. It polls until the chat leaves
|
|
// both pending and running states (meaning processChat has finished
|
|
// its cleanup and updated the DB), then calls WaitUntilIdleForTest.
|
|
//
|
|
// Waiting for a terminal state (not just "not pending") avoids a
|
|
// WaitGroup Add/Wait race: AcquireChats changes the DB status to
|
|
// running before processOnce calls inflight.Add(1). If we only
|
|
// waited for status != pending, we could call Wait() while Add(1)
|
|
// hasn't happened yet.
|
|
func waitForChatProcessed(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
chatID uuid.UUID,
|
|
server *chatd.Server,
|
|
) {
|
|
t.Helper()
|
|
require.Eventually(t, func() bool {
|
|
c, err := db.GetChatByID(ctx, chatID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
// Wait until the chat reaches a terminal state. Neither
|
|
// pending (waiting to be acquired) nor running (being
|
|
// processed). This guarantees that inflight.Add(1) has
|
|
// already been called by processOnce.
|
|
return c.Status != database.ChatStatusPending &&
|
|
c.Status != database.ChatStatusRunning
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
chatd.WaitUntilIdleForTest(server)
|
|
}
|
|
|
|
// newTestServer creates a passive server that never calls
|
|
// processOnce on its own.
|
|
func newTestServer(
|
|
t *testing.T,
|
|
db database.Store,
|
|
ps dbpubsub.Pubsub,
|
|
replicaID uuid.UUID,
|
|
) *chatd.Server {
|
|
t.Helper()
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: replicaID,
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
return server
|
|
}
|
|
|
|
func TestPassiveServerDoesNotProcess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, ps := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure())
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
server := newTestServer(t, db, ps, uuid.New())
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "should-stay-pending",
|
|
InitialUserContent: []codersdk.ChatMessagePart{{Type: codersdk.ChatMessagePartTypeText, Text: "hello"}},
|
|
ModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatd.WaitUntilIdleForTest(server)
|
|
|
|
// Re-read from DB to catch any unexpected state transition.
|
|
stored, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, stored.Status)
|
|
}
|
|
|
|
// newStartedTestServer creates a server with Start() called.
|
|
// Uses a long acquire interval so processing is triggered by
|
|
// wake signals, not polling.
|
|
func newStartedTestServer(
|
|
t *testing.T,
|
|
db database.Store,
|
|
ps dbpubsub.Pubsub,
|
|
replicaID uuid.UUID,
|
|
) *chatd.Server {
|
|
t.Helper()
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: replicaID,
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
})
|
|
server.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
return server
|
|
}
|
|
|
|
// newDebugEnabledTestServer creates a passive test server with
|
|
// AlwaysEnableDebugLogs=true so that IsEnabled(ctx, chatID, ownerID)
|
|
// always returns true regardless of runtime admin config. This lets
|
|
// chatd-level integration tests exercise the debug cleanup wiring
|
|
// without seeding the admin/user opt-in settings tables.
|
|
func newDebugEnabledTestServer(
|
|
t *testing.T,
|
|
db database.Store,
|
|
ps dbpubsub.Pubsub,
|
|
replicaID uuid.UUID,
|
|
) *chatd.Server {
|
|
t.Helper()
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: replicaID,
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
AlwaysEnableDebugLogs: true,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
return server
|
|
}
|
|
|
|
// newActiveTestServer creates a chatd server that actively polls for
|
|
// and processes pending chats. Use this instead of newTestServer when
|
|
// the test needs the chat loop to actually run. Optional config
|
|
// overrides are applied after the defaults.
|
|
func newActiveTestServer(
|
|
t *testing.T,
|
|
db database.Store,
|
|
ps dbpubsub.Pubsub,
|
|
overrides ...func(*chatd.Config),
|
|
) *chatd.Server {
|
|
t.Helper()
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
cfg := chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
}
|
|
for _, o := range overrides {
|
|
o(&cfg)
|
|
}
|
|
server := chatd.New(cfg)
|
|
server.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
return server
|
|
}
|
|
|
|
func TestProposeChatTitle_DebugRun(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
wantTitle := "Debug proposal title"
|
|
tests := []struct {
|
|
name string
|
|
alwaysEnableDebugLogs bool
|
|
response func() chattest.OpenAIResponse
|
|
wantErr bool
|
|
wantTitle string
|
|
wantTitleGenerationRuns int
|
|
wantDebugStatus codersdk.ChatDebugStatus
|
|
}{
|
|
{
|
|
name: "Enabled",
|
|
alwaysEnableDebugLogs: true,
|
|
response: func() chattest.OpenAIResponse {
|
|
return chattest.OpenAINonStreamingResponse(
|
|
"{\"title\":\"" + wantTitle + "\"}",
|
|
)
|
|
},
|
|
wantTitle: wantTitle,
|
|
wantTitleGenerationRuns: 1,
|
|
wantDebugStatus: codersdk.ChatDebugStatusCompleted,
|
|
},
|
|
{
|
|
name: "Disabled",
|
|
alwaysEnableDebugLogs: false,
|
|
response: func() chattest.OpenAIResponse {
|
|
return chattest.OpenAINonStreamingResponse(
|
|
"{\"title\":\"" + wantTitle + "\"}",
|
|
)
|
|
},
|
|
wantTitle: wantTitle,
|
|
},
|
|
{
|
|
name: "GenerationErrorFinalizesDebugRun",
|
|
alwaysEnableDebugLogs: true,
|
|
response: func() chattest.OpenAIResponse {
|
|
return chattest.OpenAINonStreamingResponse("not json")
|
|
},
|
|
wantErr: true,
|
|
wantTitleGenerationRuns: 1,
|
|
wantDebugStatus: codersdk.ChatDebugStatusError,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
require.False(t, req.Stream)
|
|
return tt.response()
|
|
})
|
|
user, org, model := seedChatDependenciesWithProvider(
|
|
t,
|
|
db,
|
|
"openai",
|
|
openAIURL,
|
|
)
|
|
server := chatd.New(chatd.Config{
|
|
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
AlwaysEnableDebugLogs: tt.alwaysEnableDebugLogs,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusCompleted,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
Title: "original title",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
message := insertUserTextMessage(
|
|
t,
|
|
db,
|
|
chat.ID,
|
|
user.ID,
|
|
model.ID,
|
|
"summarize debug title generation",
|
|
model.ContextLimit,
|
|
)
|
|
require.NotEqual(t, uuid.Nil, message.ID)
|
|
|
|
gotTitle, err := server.ProposeChatTitle(ctx, chat)
|
|
if tt.wantErr {
|
|
require.Error(t, err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.Equal(t, tt.wantTitle, gotTitle)
|
|
}
|
|
|
|
runs, err := db.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
|
|
ChatID: chat.ID,
|
|
LimitVal: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, runs, tt.wantTitleGenerationRuns)
|
|
if tt.wantTitleGenerationRuns > 0 {
|
|
require.Equal(t, string(codersdk.ChatDebugRunKindTitleGeneration), runs[0].Kind)
|
|
require.Equal(t, string(tt.wantDebugStatus), runs[0].Status)
|
|
require.True(t, runs[0].FinishedAt.Valid)
|
|
require.True(t, runs[0].HistoryTipMessageID.Valid)
|
|
require.Equal(t, message.ID, runs[0].HistoryTipMessageID.Int64)
|
|
}
|
|
if !tt.wantErr {
|
|
var usageMessages int
|
|
err = rawDB.QueryRowContext(
|
|
ctx,
|
|
`SELECT count(*) FROM chat_messages WHERE chat_id = $1 AND visibility = 'model' AND deleted = true`,
|
|
chat.ID,
|
|
).Scan(&usageMessages)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, usageMessages)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func seedChatDependencies(
|
|
t *testing.T,
|
|
db database.Store,
|
|
) (database.User, database.Organization, database.ChatModelConfig) {
|
|
t.Helper()
|
|
openAIURL := chattest.OpenAI(t)
|
|
return seedChatDependenciesWithProvider(t, db, "openai", openAIURL)
|
|
}
|
|
|
|
// seedChatDependenciesWithProvider creates a user, organization,
|
|
// chat provider, and model config for the given provider type and
|
|
// base URL.
|
|
func seedChatDependenciesWithProvider(
|
|
t *testing.T,
|
|
db database.Store,
|
|
provider string,
|
|
baseURL string,
|
|
) (database.User, database.Organization, database.ChatModelConfig) {
|
|
t.Helper()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
dbgen.ChatProvider(t, db, database.ChatProvider{
|
|
Provider: provider,
|
|
DisplayName: provider,
|
|
BaseUrl: baseURL,
|
|
})
|
|
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
|
Provider: provider,
|
|
IsDefault: true,
|
|
})
|
|
return user, org, model
|
|
}
|
|
|
|
func seedChatDependenciesWithProviderPolicy(
|
|
t *testing.T,
|
|
db database.Store,
|
|
provider string,
|
|
baseURL string,
|
|
apiKey string,
|
|
centralAPIKeyEnabled bool,
|
|
allowUserAPIKey bool,
|
|
allowCentralAPIKeyFallback bool,
|
|
) (database.User, database.Organization, database.ChatProvider, database.ChatModelConfig) {
|
|
t.Helper()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
providerConfig := dbgen.ChatProvider(t, db, database.ChatProvider{
|
|
Provider: provider,
|
|
DisplayName: provider,
|
|
BaseUrl: baseURL,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
}, func(p *database.InsertChatProviderParams) {
|
|
p.APIKey = apiKey
|
|
p.CentralApiKeyEnabled = centralAPIKeyEnabled
|
|
p.AllowUserApiKey = allowUserAPIKey
|
|
p.AllowCentralApiKeyFallback = allowCentralAPIKeyFallback
|
|
})
|
|
|
|
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
|
Provider: provider,
|
|
IsDefault: true,
|
|
})
|
|
|
|
return user, org, providerConfig, model
|
|
}
|
|
|
|
func seedLastTurnSummary(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
chat database.Chat,
|
|
summary string,
|
|
) {
|
|
t.Helper()
|
|
|
|
affected, err := db.UpdateChatLastTurnSummary(ctx, database.UpdateChatLastTurnSummaryParams{
|
|
ID: chat.ID,
|
|
ExpectedUpdatedAt: chat.UpdatedAt,
|
|
LastTurnSummary: sql.NullString{String: summary, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), affected)
|
|
}
|
|
|
|
func waitForTerminalChatStatusEvent(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
events <-chan codersdk.ChatStreamEvent,
|
|
) codersdk.ChatStatus {
|
|
t.Helper()
|
|
|
|
var terminalStatus codersdk.ChatStatus
|
|
testutil.Eventually(ctx, t, func(context.Context) bool {
|
|
for {
|
|
select {
|
|
case event, ok := <-events:
|
|
if !ok {
|
|
return false
|
|
}
|
|
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
|
continue
|
|
}
|
|
if event.Status.Status == codersdk.ChatStatusWaiting || event.Status.Status == codersdk.ChatStatusError {
|
|
terminalStatus = event.Status.Status
|
|
return true
|
|
}
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
return terminalStatus
|
|
}
|
|
|
|
func waitForTerminalChat(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
chatID uuid.UUID,
|
|
) database.Chat {
|
|
t.Helper()
|
|
|
|
var chatResult database.Chat
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
got, err := db.GetChatByID(ctx, chatID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.IntervalFast)
|
|
|
|
return chatResult
|
|
}
|
|
|
|
func insertChatModelConfigWithCallConfig(
|
|
t *testing.T,
|
|
db database.Store,
|
|
userID uuid.UUID,
|
|
provider string,
|
|
model string,
|
|
callConfig codersdk.ChatModelCallConfig,
|
|
) database.ChatModelConfig {
|
|
t.Helper()
|
|
|
|
options, err := json.Marshal(callConfig)
|
|
require.NoError(t, err)
|
|
|
|
return dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
|
Provider: provider,
|
|
Model: model,
|
|
DisplayName: model,
|
|
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
Options: options,
|
|
})
|
|
}
|
|
|
|
func insertUserTextMessage(
|
|
t *testing.T,
|
|
db database.Store,
|
|
chatID uuid.UUID,
|
|
userID uuid.UUID,
|
|
modelConfigID uuid.UUID,
|
|
text string,
|
|
contextLimit ...int64,
|
|
) database.ChatMessage {
|
|
t.Helper()
|
|
require.LessOrEqual(t, len(contextLimit), 1)
|
|
|
|
contextLimitValue := int64(0)
|
|
if len(contextLimit) == 1 {
|
|
contextLimitValue = contextLimit[0]
|
|
}
|
|
content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(text)})
|
|
require.NoError(t, err)
|
|
|
|
return dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: chatID,
|
|
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
|
Role: database.ChatMessageRoleUser,
|
|
Content: pqtype.NullRawMessage{RawMessage: content.RawMessage, Valid: true},
|
|
ContextLimit: sql.NullInt64{Int64: contextLimitValue, Valid: contextLimitValue != 0},
|
|
})
|
|
}
|
|
|
|
// seedWorkspaceWithAgent creates a full workspace chain with a connected
|
|
// agent. This is the common setup needed by tests that exercise tool
|
|
// execution against a workspace.
|
|
func seedWorkspaceWithAgent(
|
|
t *testing.T,
|
|
db database.Store,
|
|
userID uuid.UUID,
|
|
) (database.WorkspaceTable, database.WorkspaceAgent) {
|
|
t.Helper()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: userID,
|
|
})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: userID,
|
|
OrganizationID: org.ID,
|
|
ActiveVersionID: tv.ID,
|
|
})
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
TemplateID: tpl.ID,
|
|
OwnerID: userID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
InitiatorID: userID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
TemplateVersionID: tv.ID,
|
|
WorkspaceID: ws.ID,
|
|
JobID: pj.ID,
|
|
})
|
|
res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
Transition: database.WorkspaceTransitionStart,
|
|
JobID: pj.ID,
|
|
})
|
|
dbAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: res.ID,
|
|
})
|
|
return ws, dbAgent
|
|
}
|
|
|
|
func setOpenAIProviderBaseURL(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
baseURL string,
|
|
) {
|
|
t.Helper()
|
|
|
|
providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true})
|
|
require.NoError(t, err)
|
|
for _, provider := range providers {
|
|
if provider.Type != database.AiProviderTypeOpenai {
|
|
continue
|
|
}
|
|
_, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{
|
|
ID: provider.ID,
|
|
DisplayName: provider.DisplayName,
|
|
Enabled: provider.Enabled,
|
|
BaseUrl: baseURL,
|
|
Settings: provider.Settings,
|
|
SettingsKeyID: provider.SettingsKeyID,
|
|
})
|
|
require.NoError(t, err)
|
|
return
|
|
}
|
|
require.Fail(t, "openai provider not found")
|
|
}
|
|
|
|
func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Set up a mock OpenAI that blocks until the request context is
|
|
// canceled (i.e. until the chat is interrupted).
|
|
streamStarted := make(chan struct{})
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("partial")[0]
|
|
select {
|
|
case <-streamStarted:
|
|
default:
|
|
close(streamStarted)
|
|
}
|
|
// Block until the chat context is canceled by the interrupt.
|
|
<-req.Context().Done()
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
})
|
|
|
|
// Mock webpush dispatcher that records calls.
|
|
mockPush := &mockWebpushDispatcher{}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
WebpushDispatcher: mockPush,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "interrupt-no-push",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
seedLastTurnSummary(ctx, t, db, chat, "previous summary")
|
|
|
|
server.Start()
|
|
|
|
// Wait for the chat to be picked up and start streaming.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
|
}, testutil.IntervalFast)
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
select {
|
|
case <-streamStarted:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
// Interrupt the chat.
|
|
updated := server.InterruptChat(ctx, chat)
|
|
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
|
|
|
// Wait for the chat to finish processing and return to waiting.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusWaiting && !fromDB.WorkerID.Valid
|
|
}, testutil.IntervalFast)
|
|
chatd.WaitUntilIdleForTest(server)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, fromDB.LastTurnSummary.Valid,
|
|
"interrupted chats should clear cached turn summaries")
|
|
|
|
// Verify no web push notification was dispatched.
|
|
require.Equal(t, int32(0), mockPush.dispatchCount.Load(),
|
|
"expected no web push dispatch for an interrupted chat")
|
|
}
|
|
|
|
// mockWebpushDispatcher implements webpush.Dispatcher and records Dispatch calls.
|
|
type mockWebpushDispatcher struct {
|
|
dispatchCount atomic.Int32
|
|
mu sync.Mutex
|
|
lastMessage codersdk.WebpushMessage
|
|
lastUserID uuid.UUID
|
|
}
|
|
|
|
func (m *mockWebpushDispatcher) Dispatch(_ context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
|
|
m.dispatchCount.Add(1)
|
|
m.mu.Lock()
|
|
m.lastMessage = msg
|
|
m.lastUserID = userID
|
|
m.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (m *mockWebpushDispatcher) getLastMessage() codersdk.WebpushMessage {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.lastMessage
|
|
}
|
|
|
|
func (*mockWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error {
|
|
return nil
|
|
}
|
|
|
|
func (*mockWebpushDispatcher) PublicKey() string {
|
|
return "test-vapid-public-key"
|
|
}
|
|
|
|
func TestSuccessfulChatSendsWebPushWithNavigationData(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Set up a mock OpenAI that returns a simple successful response.
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
// Mock webpush dispatcher that captures the dispatched message.
|
|
mockPush := &mockWebpushDispatcher{}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
WebpushDispatcher: mockPush,
|
|
})
|
|
server.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "push-nav-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the chat to complete and return to waiting status.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusWaiting && !fromDB.WorkerID.Valid && mockPush.dispatchCount.Load() == 1
|
|
}, testutil.IntervalFast)
|
|
|
|
// Verify a web push notification was dispatched exactly once.
|
|
require.Equal(t, int32(1), mockPush.dispatchCount.Load(),
|
|
"expected exactly one web push dispatch for a completed chat")
|
|
|
|
// Verify the notification was sent to the correct user.
|
|
mockPush.mu.Lock()
|
|
capturedMsg := mockPush.lastMessage
|
|
capturedUserID := mockPush.lastUserID
|
|
mockPush.mu.Unlock()
|
|
|
|
require.Equal(t, user.ID, capturedUserID,
|
|
"web push should be dispatched to the chat owner")
|
|
|
|
// Verify the Data field contains the correct navigation URL.
|
|
expectedURL := fmt.Sprintf("/agents/%s", chat.ID)
|
|
require.Equal(t, expectedURL, capturedMsg.Data["url"],
|
|
"web push Data should contain the chat navigation URL")
|
|
}
|
|
|
|
func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var requestCount atomic.Int32
|
|
streamStarted := make(chan struct{})
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
// Ignore non-streaming requests (e.g. title generation) so
|
|
// they don't interfere with the request counter used to
|
|
// coordinate the streaming chat flow.
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("shutdown-retry")
|
|
}
|
|
if requestCount.Add(1) == 1 {
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("partial")[0]
|
|
select {
|
|
case <-streamStarted:
|
|
default:
|
|
close(streamStarted)
|
|
}
|
|
<-req.Context().Done()
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
}
|
|
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("retry", " complete")...)
|
|
})
|
|
|
|
loggerA := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
serverA := chatd.New(chatd.Config{
|
|
Logger: loggerA,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitLong,
|
|
})
|
|
serverA.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, serverA.Close())
|
|
})
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := serverA.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "shutdown-retry",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case <-streamStarted:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
require.NoError(t, serverA.Close())
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending &&
|
|
!fromDB.WorkerID.Valid &&
|
|
!fromDB.LastError.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
loggerB := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
serverB := chatd.New(chatd.Config{
|
|
Logger: loggerB,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitLong,
|
|
})
|
|
serverB.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, serverB.Close())
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
return requestCount.Load() >= 2
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusWaiting &&
|
|
!fromDB.WorkerID.Valid &&
|
|
!fromDB.LastError.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
const assistantText = "I have completed the task successfully and all tests are passing now."
|
|
const summaryText = "Finished unit tests"
|
|
|
|
var nonStreamingRequests atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
if strings.Contains(string(req.RawBody), "propose_turn_status_label") {
|
|
nonStreamingRequests.Add(1)
|
|
return chattest.OpenAINonStreamingResponse(fmt.Sprintf(`{"label":%q}`, summaryText))
|
|
}
|
|
return chattest.OpenAINonStreamingResponse(`{"title":"Summary push test"}`)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks(assistantText)...,
|
|
)
|
|
})
|
|
|
|
mockPush := &mockWebpushDispatcher{}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
WebpushDispatcher: mockPush,
|
|
})
|
|
server.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "summary-push-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// The push notification is dispatched asynchronously after the
|
|
// chat finishes, so we poll for it rather than checking
|
|
// immediately after the status transitions to waiting.
|
|
var fromDB database.Chat
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
var dbErr error
|
|
fromDB, dbErr = db.GetChatByID(ctx, chat.ID)
|
|
return dbErr == nil && mockPush.dispatchCount.Load() >= 1 && fromDB.LastTurnSummary.Valid
|
|
}, testutil.IntervalFast)
|
|
|
|
msg := mockPush.getLastMessage()
|
|
require.Equal(t, summaryText, fromDB.LastTurnSummary.String,
|
|
"last turn summary should be the LLM-generated status label")
|
|
require.Equal(t, fromDB.LastTurnSummary.String, msg.Body,
|
|
"push body should reuse the persisted generated status label")
|
|
require.Equal(t, int32(1), nonStreamingRequests.Load(),
|
|
"expected exactly one non-streaming request for status label generation")
|
|
}
|
|
|
|
func TestSuccessfulChatPersistsTurnSummaryWithoutWebPush(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
const assistantText = "I fixed the bug and added regression coverage."
|
|
const summaryText = "Fixed regression bug"
|
|
|
|
var nonStreamingRequests atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
if strings.Contains(string(req.RawBody), "propose_turn_status_label") {
|
|
nonStreamingRequests.Add(1)
|
|
return chattest.OpenAINonStreamingResponse(fmt.Sprintf(`{"label":%q}`, summaryText))
|
|
}
|
|
return chattest.OpenAINonStreamingResponse(`{"title":"Summary push test"}`)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks(assistantText)...,
|
|
)
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "summary-no-webpush-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var fromDB database.Chat
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
var dbErr error
|
|
fromDB, dbErr = db.GetChatByID(ctx, chat.ID)
|
|
return dbErr == nil && fromDB.LastTurnSummary.Valid
|
|
}, testutil.IntervalFast)
|
|
|
|
require.Equal(t, summaryText, fromDB.LastTurnSummary.String,
|
|
"status label should persist even when web push is unavailable")
|
|
require.Equal(t, int32(1), nonStreamingRequests.Load(),
|
|
"expected exactly one non-streaming request for status label generation")
|
|
}
|
|
|
|
func TestSuccessfulChatSendsWebPushFallbackWithoutSummaryForEmptyAssistantText(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var nonStreamingRequests atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
if strings.Contains(string(req.RawBody), "propose_turn_status_label") {
|
|
nonStreamingRequests.Add(1)
|
|
return chattest.OpenAINonStreamingResponse(`{"label":"Unexpected label"}`)
|
|
}
|
|
return chattest.OpenAINonStreamingResponse(`{"title":"Empty summary push test"}`)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks(" ")...,
|
|
)
|
|
})
|
|
|
|
mockPush := &mockWebpushDispatcher{}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
WebpushDispatcher: mockPush,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "empty-summary-push-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
|
|
})
|
|
require.NoError(t, err)
|
|
seedLastTurnSummary(ctx, t, db, chat, "previous summary")
|
|
|
|
server.Start()
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
return mockPush.dispatchCount.Load() >= 1
|
|
}, testutil.IntervalFast)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, sql.NullString{String: "Finished latest turn", Valid: true}, fromDB.LastTurnSummary,
|
|
"fallback status label should be persisted")
|
|
|
|
msg := mockPush.getLastMessage()
|
|
require.Equal(t, "Finished latest turn", msg.Body,
|
|
"push body should fall back when the final assistant text is empty")
|
|
require.Equal(t, int32(0), nonStreamingRequests.Load(),
|
|
"status label model should not run when final assistant text has no usable text")
|
|
}
|
|
|
|
func TestErroredChatClearsLastTurnSummaryAndSendsWebPush(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
return chattest.OpenAIErrorResponse(http.StatusBadRequest, "invalid_request_error", "Bad request")
|
|
})
|
|
|
|
mockPush := &mockWebpushDispatcher{}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
WebpushDispatcher: mockPush,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "error-summary-clear-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
|
|
})
|
|
require.NoError(t, err)
|
|
seedLastTurnSummary(ctx, t, db, chat, "previous summary")
|
|
|
|
server.Start()
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
return dbErr == nil &&
|
|
fromDB.Status == database.ChatStatusError &&
|
|
mockPush.dispatchCount.Load() >= 1
|
|
}, testutil.IntervalFast)
|
|
chatd.WaitUntilIdleForTest(server)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, fromDB.LastTurnSummary.Valid,
|
|
"errored chats should clear cached turn summaries")
|
|
|
|
msg := mockPush.getLastMessage()
|
|
require.NotEqual(t, "Hit an error", msg.Body)
|
|
require.Contains(t, msg.Body, "OpenAI returned an unexpected error")
|
|
}
|
|
|
|
func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
computerUseModelProvider, computerUseModelName, ok := chattool.DefaultComputerUseModel(chattool.ComputerUseProviderAnthropic)
|
|
require.True(t, ok)
|
|
require.Equal(t, chattool.ComputerUseProviderAnthropic, computerUseModelProvider)
|
|
|
|
// Track tools and model from the Anthropic LLM calls (the
|
|
// computer use child chat). We use a raw HTTP handler because
|
|
// the chattest AnthropicRequest struct does not capture tools.
|
|
type anthropicCall struct {
|
|
Model string
|
|
Tools []string
|
|
Stream bool
|
|
}
|
|
var anthropicMu sync.Mutex
|
|
var anthropicCalls []anthropicCall
|
|
|
|
anthropicSrv := httptest.NewServer(http.HandlerFunc(
|
|
func(w http.ResponseWriter, r *http.Request) {
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var req struct {
|
|
Model string `json:"model"`
|
|
Stream bool `json:"stream"`
|
|
Tools []struct {
|
|
Name string `json:"name"`
|
|
} `json:"tools"`
|
|
}
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
names := make([]string, len(req.Tools))
|
|
for i, tool := range req.Tools {
|
|
names[i] = tool.Name
|
|
}
|
|
anthropicMu.Lock()
|
|
anthropicCalls = append(anthropicCalls, anthropicCall{
|
|
Model: req.Model,
|
|
Tools: names,
|
|
Stream: req.Stream,
|
|
})
|
|
anthropicMu.Unlock()
|
|
|
|
if !req.Stream {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"id": "msg-test",
|
|
"type": "message",
|
|
"role": "assistant",
|
|
"model": computerUseModelName,
|
|
"content": []map[string]any{{"type": "text", "text": "Done."}},
|
|
"stop_reason": "end_turn",
|
|
"usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
|
|
})
|
|
return
|
|
}
|
|
|
|
// Stream a minimal Anthropic SSE response.
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
flusher, _ := w.(http.Flusher)
|
|
|
|
chunks := []map[string]any{
|
|
{
|
|
"type": "message_start",
|
|
"message": map[string]any{
|
|
"id": "msg-test",
|
|
"type": "message",
|
|
"role": "assistant",
|
|
"model": computerUseModelName,
|
|
},
|
|
},
|
|
{
|
|
"type": "content_block_start",
|
|
"index": 0,
|
|
"content_block": map[string]any{
|
|
"type": "text",
|
|
"text": "",
|
|
},
|
|
},
|
|
{
|
|
"type": "content_block_delta",
|
|
"index": 0,
|
|
"delta": map[string]any{
|
|
"type": "text_delta",
|
|
"text": "Done.",
|
|
},
|
|
},
|
|
{"type": "content_block_stop", "index": 0},
|
|
{
|
|
"type": "message_delta",
|
|
"delta": map[string]any{"stop_reason": "end_turn"},
|
|
"usage": map[string]any{"output_tokens": 5},
|
|
},
|
|
{"type": "message_stop"},
|
|
}
|
|
|
|
for _, chunk := range chunks {
|
|
chunkBytes, _ := json.Marshal(chunk)
|
|
eventType, _ := chunk["type"].(string)
|
|
_, _ = fmt.Fprintf(w, "event: %s\ndata: %s\n\n",
|
|
eventType, chunkBytes)
|
|
flusher.Flush()
|
|
}
|
|
},
|
|
))
|
|
t.Cleanup(anthropicSrv.Close)
|
|
|
|
// OpenAI mock for the root chat. The first streaming call
|
|
// triggers spawn_agent; subsequent calls reply
|
|
// with text.
|
|
var openAICallCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
if openAICallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk(
|
|
"spawn_agent",
|
|
`{"type":"computer_use","prompt":"do the desktop thing","title":"cu-sub"}`,
|
|
),
|
|
)
|
|
}
|
|
// Include literal \u0000 in the response text, which is
|
|
// what a real LLM writes when explaining binary output.
|
|
// json.Marshal encodes the backslash as \\, producing
|
|
// \\u0000 in the JSON bytes. The sanitizer must not
|
|
// corrupt this into invalid JSON.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("The file contains \\u0000 null bytes.")...,
|
|
)
|
|
})
|
|
|
|
// Seed the DB: user, openai-compat provider, model config.
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
// Add an Anthropic provider pointing to our mock server.
|
|
dbgen.ChatProvider(t, db, database.ChatProvider{
|
|
Provider: "anthropic",
|
|
DisplayName: "Anthropic",
|
|
APIKey: "test-anthropic-key",
|
|
BaseUrl: anthropicSrv.URL,
|
|
})
|
|
|
|
err := db.UpsertChatDesktopEnabled(ctx, true)
|
|
require.NoError(t, err)
|
|
|
|
// Build workspace + agent records so getWorkspaceConn can
|
|
// resolve the agent for the computer use child.
|
|
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
|
|
|
// Mock agent connection that returns valid display dimensions
|
|
// for the initial screenshot check in the computer use path.
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().
|
|
ListMCPTools(gomock.Any()).
|
|
Return(workspacesdk.ListMCPToolsResponse{}, nil).
|
|
AnyTimes()
|
|
mockConn.EXPECT().
|
|
ExecuteDesktopAction(gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.DesktopActionResponse{
|
|
ScreenshotWidth: 1920,
|
|
ScreenshotHeight: 1080,
|
|
ScreenshotData: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4n539HwAHFwLVF8kc1wAAAABJRU5ErkJggg==",
|
|
}, nil).
|
|
AnyTimes()
|
|
mockConn.EXPECT().
|
|
SetExtraHeaders(gomock.Any()).
|
|
AnyTimes()
|
|
mockConn.EXPECT().
|
|
ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).
|
|
AnyTimes()
|
|
mockConn.EXPECT().
|
|
LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{}, xerrors.New("not found")).
|
|
AnyTimes()
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
})
|
|
|
|
// Create a root chat with a workspace so the child inherits it.
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "computer-use-detection",
|
|
ModelConfigID: model.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Use the desktop to check the UI"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the root chat AND the computer use child to finish.
|
|
// The root chat spawns the child, then the chatd server picks
|
|
// up and runs the child (which hits the Anthropic mock).
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
if got.Status != database.ChatStatusWaiting &&
|
|
got.Status != database.ChatStatusError {
|
|
return false
|
|
}
|
|
// Ensure the Anthropic mock received the child streaming call.
|
|
anthropicMu.Lock()
|
|
defer anthropicMu.Unlock()
|
|
for _, call := range anthropicCalls {
|
|
if call.Stream {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
anthropicMu.Lock()
|
|
calls := append([]anthropicCall(nil), anthropicCalls...)
|
|
anthropicMu.Unlock()
|
|
|
|
require.NotEmpty(t, calls,
|
|
"expected at least one Anthropic LLM call")
|
|
|
|
var childCall anthropicCall
|
|
for _, call := range calls {
|
|
if call.Stream {
|
|
childCall = call
|
|
break
|
|
}
|
|
}
|
|
require.True(t, childCall.Stream,
|
|
"expected at least one streaming Anthropic child LLM call")
|
|
|
|
childModel := childCall.Model
|
|
childTools := childCall.Tools
|
|
|
|
// 1. Verify the model is the computer use model.
|
|
require.Equal(t, computerUseModelName, childModel,
|
|
"computer use subagent should use %s",
|
|
computerUseModelName)
|
|
|
|
// 2. Verify the computer tool is present.
|
|
require.Contains(t, childTools, "computer",
|
|
"computer use subagent should have the computer tool")
|
|
|
|
// 3. Verify standard workspace tools are present (the same
|
|
// set a regular subagent gets).
|
|
standardTools := []string{
|
|
"read_file", "write_file", "edit_files", "execute",
|
|
"process_output", "process_list", "process_signal",
|
|
}
|
|
for _, tool := range standardTools {
|
|
require.Contains(t, childTools, tool,
|
|
"computer use subagent should have standard tool %q",
|
|
tool)
|
|
}
|
|
|
|
// 4. Verify workspace provisioning tools are NOT present.
|
|
workspaceProvisioningTools := []string{
|
|
"list_templates", "read_template",
|
|
"create_workspace", "start_workspace", "stop_workspace",
|
|
}
|
|
for _, tool := range workspaceProvisioningTools {
|
|
require.NotContains(t, childTools, tool,
|
|
"computer use subagent should NOT have workspace "+
|
|
"provisioning tool %q", tool)
|
|
}
|
|
|
|
// 5. Verify subagent tools are NOT present.
|
|
subagentTools := []string{
|
|
"spawn_agent",
|
|
"wait_agent", "message_agent", "close_agent",
|
|
}
|
|
for _, tool := range subagentTools {
|
|
require.NotContains(t, childTools, tool,
|
|
"computer use subagent should NOT have subagent "+
|
|
"tool %q", tool)
|
|
}
|
|
|
|
// 6. Verify the child chat has Mode = computer_use in
|
|
// the DB.
|
|
childRows, err := db.GetChildChatsByParentIDs(ctx, database.GetChildChatsByParentIDsParams{
|
|
ParentIds: []uuid.UUID{chat.ID},
|
|
})
|
|
require.NoError(t, err)
|
|
children := make([]database.Chat, 0, len(childRows))
|
|
for _, row := range childRows {
|
|
children = append(children, row.Chat)
|
|
}
|
|
require.Len(t, children, 1)
|
|
require.True(t, children[0].Mode.Valid)
|
|
require.Equal(t, database.ChatModeComputerUse,
|
|
children[0].Mode.ChatMode)
|
|
}
|
|
|
|
func TestInterruptChatPersistsPartialResponse(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Set up a mock OpenAI that streams a partial response and then
|
|
// blocks until the request context is canceled (simulating an
|
|
// interrupt mid-stream).
|
|
chunksDelivered := make(chan struct{})
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
// Send two partial text chunks so there is meaningful
|
|
// content to persist.
|
|
for _, c := range chattest.OpenAITextChunks("hello world") {
|
|
chunks <- c
|
|
}
|
|
// Signal that chunks have been written to the HTTP response.
|
|
select {
|
|
case <-chunksDelivered:
|
|
default:
|
|
close(chunksDelivered)
|
|
}
|
|
// Block until interrupt cancels the context.
|
|
<-req.Context().Done()
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
})
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
})
|
|
server.Start()
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "interrupt-persist-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Subscribe to the chat's event stream so we can observe
|
|
// message_part events. This proves the chatloop has actually
|
|
// processed the streamed chunks.
|
|
_, events, subCancel, ok := server.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
defer subCancel()
|
|
|
|
// Wait for the mock to finish sending chunks.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
select {
|
|
case <-chunksDelivered:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
// Drain the event channel until we see a message_part event,
|
|
// which means the chatloop has consumed and published the chunk.
|
|
gotMessagePart := false
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
for {
|
|
select {
|
|
case ev := <-events:
|
|
if ev.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
gotMessagePart = true
|
|
return true
|
|
}
|
|
default:
|
|
return gotMessagePart
|
|
}
|
|
}
|
|
}, testutil.IntervalFast)
|
|
require.True(t, gotMessagePart, "should have received at least one message_part event")
|
|
|
|
// Now interrupt the chat. The chatloop has processed content.
|
|
updated := server.InterruptChat(ctx, chat)
|
|
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
|
|
|
// Wait for the partial assistant message to be persisted.
|
|
// After the interrupt, the chatloop runs persistInterruptedStep
|
|
// which inserts the message and publishes a "message" event.
|
|
// We poll the DB directly for the assistant message rather than
|
|
// relying on the chat status (which transitions to "waiting"
|
|
// before the persist completes).
|
|
var assistantMsg *database.ChatMessage
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
msgs, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for i := range msgs {
|
|
if msgs[i].Role == database.ChatMessageRoleAssistant {
|
|
assistantMsg = &msgs[i]
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}, testutil.IntervalFast)
|
|
require.NotNilf(t, assistantMsg, "expected a persisted assistant message after interrupt")
|
|
|
|
// Parse the content and verify it contains the partial text.
|
|
parts, err := chatprompt.ParseContent(*assistantMsg)
|
|
require.NoError(t, err)
|
|
|
|
var foundText string
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeText {
|
|
foundText += part.Text
|
|
}
|
|
}
|
|
require.Contains(t, foundText, "hello world",
|
|
"partial assistant response should contain the streamed text")
|
|
}
|
|
|
|
func TestProcessChat_UserProviderKey_Success(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
const userAPIKey = "user-test-key"
|
|
|
|
var authHeadersMu sync.Mutex
|
|
authHeaders := make([]string, 0, 1)
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
authHeadersMu.Lock()
|
|
authHeaders = append(authHeaders, req.Header.Get("Authorization"))
|
|
authHeadersMu.Unlock()
|
|
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("user provider key success")
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("hello from the saved user key")...,
|
|
)
|
|
})
|
|
|
|
user, org, provider, model := seedChatDependenciesWithProviderPolicy(
|
|
t,
|
|
db,
|
|
"openai-compat",
|
|
openAIURL,
|
|
"",
|
|
false,
|
|
true,
|
|
false,
|
|
)
|
|
_, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{
|
|
ID: uuid.New(),
|
|
UserID: user.ID,
|
|
AIProviderID: provider.ID,
|
|
APIKey: userAPIKey,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
creator := newTestServer(t, db, ps, uuid.New())
|
|
chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "user-provider-key-success",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("say hello"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
_ = newActiveTestServer(t, db, ps)
|
|
|
|
terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
|
|
require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus)
|
|
|
|
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
|
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
|
|
require.False(t, chatResult.LastError.Valid)
|
|
|
|
authHeadersMu.Lock()
|
|
recordedAuthHeaders := append([]string(nil), authHeaders...)
|
|
authHeadersMu.Unlock()
|
|
require.Contains(t, recordedAuthHeaders, "Bearer "+userAPIKey)
|
|
}
|
|
|
|
func TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if req.Stream {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("hello through AI Gateway")...,
|
|
)
|
|
}
|
|
return chattest.OpenAINonStreamingResponse(`{"title":"AI Gateway Chat"}`)
|
|
})
|
|
factory := newChatAIGatewayTestFactory(t, openAIURL)
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
provider := dbgen.AIProvider(t, db, database.AIProvider{
|
|
Type: database.AiProviderTypeOpenai,
|
|
Name: "primary-openai-" + uuid.NewString(),
|
|
BaseUrl: openAIURL,
|
|
})
|
|
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
|
Provider: string(database.AiProviderTypeOpenai),
|
|
Model: "gpt-4o-mini",
|
|
IsDefault: true,
|
|
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
|
|
})
|
|
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
|
_, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{
|
|
ID: uuid.New(),
|
|
UserID: user.ID,
|
|
AIProviderID: provider.ID,
|
|
APIKey: "sk-user-aibridge",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
creator := newTestServer(t, db, ps, uuid.New())
|
|
chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "aigateway-routing",
|
|
ModelConfigID: model.ID,
|
|
APIKeyID: apiKey.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("say hello"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
_ = newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AIBridgeTransportFactory = chatAIGatewayTransportFactoryPointer(factory)
|
|
cfg.AIGatewayRoutingEnabled = true
|
|
cfg.AllowBYOK = true
|
|
cfg.AllowBYOKSet = true
|
|
})
|
|
|
|
terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
|
|
require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus)
|
|
|
|
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
|
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
|
|
require.False(t, chatResult.LastError.Valid)
|
|
|
|
requests := factory.requestsSnapshot()
|
|
require.NotEmpty(t, requests)
|
|
require.Contains(t, requests, chatAIGatewayRecordedRequest{
|
|
ProviderName: provider.Name,
|
|
Source: aibridge.SourceAgents,
|
|
APIKeyID: apiKey.ID,
|
|
Path: "/v1/responses",
|
|
Authorization: "Bearer sk-user-aibridge",
|
|
CoderToken: "delegated",
|
|
})
|
|
for _, req := range requests {
|
|
require.Equal(t, provider.Name, req.ProviderName)
|
|
require.Equal(t, aibridge.SourceAgents, req.Source)
|
|
require.Equal(t, apiKey.ID, req.APIKeyID)
|
|
require.Equal(t, "Bearer sk-user-aibridge", req.Authorization)
|
|
require.Empty(t, req.XAPIKey)
|
|
require.Equal(t, "delegated", req.CoderToken)
|
|
require.True(t, strings.HasPrefix(req.Path, "/v1/"), "unexpected aibridge path %q", req.Path)
|
|
}
|
|
}
|
|
|
|
func TestProcessChat_UserProviderKey_MissingKeyError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var llmCalls atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
llmCalls.Add(1)
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("unexpected non-streaming request")
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("unexpected streaming request")...,
|
|
)
|
|
})
|
|
|
|
user, org, _, model := seedChatDependenciesWithProviderPolicy(
|
|
t,
|
|
db,
|
|
"openai-compat",
|
|
openAIURL,
|
|
"",
|
|
false,
|
|
true,
|
|
false,
|
|
)
|
|
|
|
creator := newTestServer(t, db, ps, uuid.New())
|
|
chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "user-provider-key-missing",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("say hello"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
_ = newActiveTestServer(t, db, ps)
|
|
|
|
terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
|
|
require.Equal(t, codersdk.ChatStatusError, terminalStatus)
|
|
|
|
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
|
require.Equal(t, database.ChatStatusError, chatResult.Status)
|
|
persistedError := requireChatLastErrorPayload(t, chatResult.LastError)
|
|
require.NotEmpty(t, persistedError.Message)
|
|
require.NotContains(t, persistedError.Message, "panicked")
|
|
require.Equal(t, codersdk.ChatErrorKindGeneric, persistedError.Kind)
|
|
require.NotEqual(t, database.ChatStatusRunning, chatResult.Status)
|
|
require.Zero(t, llmCalls.Load(), "missing user key should fail before any LLM request")
|
|
}
|
|
|
|
func TestProcessChatPanicRecovery(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
// Wrap the database so we can trigger a panic on the main
|
|
// goroutine of processChat. The chatloop's executeTools has
|
|
// its own recover, so panicking inside a tool goroutine won't
|
|
// reach the processChat-level recovery. Instead, we panic
|
|
// during PersistStep's InTx call, which runs synchronously on
|
|
// the processChat goroutine.
|
|
panicWrapper := &panicOnInTxDB{Store: db}
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Panic recovery test")
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("hello")...,
|
|
)
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
// Pass the panic wrapper to the server, but use the real
|
|
// database for seeding so those operations don't panic.
|
|
server := newActiveTestServer(t, panicWrapper, ps)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "panic-recovery",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("hello"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Enable the panic now that CreateChat's InTx has completed.
|
|
// The next InTx call is PersistStep inside the chatloop,
|
|
// running synchronously on the processChat goroutine.
|
|
panicWrapper.enablePanic()
|
|
|
|
// Wait for the panic to be recovered and the chat to
|
|
// transition to error status.
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
persistedError := requireChatLastErrorPayload(t, chatResult.LastError)
|
|
require.Contains(t, persistedError.Message, "chat processing panicked")
|
|
require.Contains(t, persistedError.Message, "intentional test panic")
|
|
require.Equal(t, codersdk.ChatErrorKindGeneric, persistedError.Kind)
|
|
}
|
|
|
|
// panicOnInTxDB wraps a database.Store and panics on the first InTx
|
|
// call after enablePanic is called. Subsequent calls pass through
|
|
// so the processChat cleanup defer can update the chat status.
|
|
type panicOnInTxDB struct {
|
|
database.Store
|
|
active atomic.Bool
|
|
panicked atomic.Bool
|
|
}
|
|
|
|
// enablePanic arms the wrapper by setting its active flag.
func (d *panicOnInTxDB) enablePanic() {
	d.active.Store(true)
}
|
|
|
|
func (d *panicOnInTxDB) InTx(f func(database.Store) error, opts *database.TxOptions) error {
|
|
if d.active.Load() && !d.panicked.Load() {
|
|
d.panicked.Store(true)
|
|
panic("intentional test panic")
|
|
}
|
|
return d.Store.InTx(f, opts)
|
|
}
|
|
|
|
// TestMCPServerToolInvocation verifies that when a chat has
|
|
// mcp_server_ids set, the chat loop connects to those MCP servers,
|
|
// discovers their tools, and the LLM can invoke them.
|
|
//
|
|
// NOTE: This test uses a raw database.Store (no dbauthz wrapper).
|
|
// The chatd RBAC authorization of GetMCPServerConfigsByIDs (which
|
|
// requires ActionRead on ResourceDeploymentConfig) is covered by
|
|
// the chatd role definition tests, not here.
|
|
func TestMCPServerToolInvocation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Start a real MCP server that exposes an "echo" tool.
|
|
mcpSrv := mcpserver.NewMCPServer("test-mcp", "1.0.0")
|
|
mcpSrv.AddTools(mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("echo",
|
|
mcpgo.WithDescription("Echoes the input"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("echo: " + input), nil
|
|
},
|
|
})
|
|
mcpHTTP := mcpserver.NewStreamableHTTPServer(mcpSrv)
|
|
mcpTS := httptest.NewServer(mcpHTTP)
|
|
t.Cleanup(mcpTS.Close)
|
|
|
|
// Track which tool names are sent to the LLM and capture
|
|
// whether the MCP tool result appears in the second call.
|
|
var (
|
|
callCount atomic.Int32
|
|
llmToolNames []string
|
|
llmToolsMu sync.Mutex
|
|
foundMCPResult atomic.Bool
|
|
)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
// Record tool names from the first streamed call.
|
|
if callCount.Add(1) == 1 {
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
llmToolsMu.Lock()
|
|
llmToolNames = names
|
|
llmToolsMu.Unlock()
|
|
|
|
// Ask the LLM to call the MCP echo tool.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk(
|
|
"test-mcp__echo",
|
|
`{"input":"hello from LLM"}`,
|
|
),
|
|
)
|
|
}
|
|
|
|
// Second call: verify the tool result was fed back.
|
|
for _, msg := range req.Messages {
|
|
if msg.Role == "tool" && strings.Contains(msg.Content, "echo: hello from LLM") {
|
|
foundMCPResult.Store(true)
|
|
}
|
|
}
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Got it!")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
// Seed the MCP server config in the database. This must
|
|
// happen after seedChatDependencies so user.ID exists for
|
|
// the foreign key.
|
|
mcpConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Test MCP",
|
|
Slug: "test-mcp",
|
|
Url: mcpTS.URL,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
|
|
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
|
mockConn.EXPECT().ListMCPTools(gomock.Any()).
|
|
Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes()
|
|
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{}, nil).AnyTimes()
|
|
mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes()
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
})
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "mcp-tool-test",
|
|
ModelConfigID: model.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Echo something via MCP."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Verify MCPServerIDs were persisted on the chat record.
|
|
dbChat, getErr := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, getErr)
|
|
require.Equal(t, []uuid.UUID{mcpConfig.ID}, dbChat.MCPServerIDs)
|
|
|
|
// Wait for the chat to finish processing.
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
// The MCP tool (test-mcp__echo) should appear in the tool
|
|
// list sent to the LLM.
|
|
llmToolsMu.Lock()
|
|
recordedNames := append([]string(nil), llmToolNames...)
|
|
llmToolsMu.Unlock()
|
|
require.Contains(t, recordedNames, "test-mcp__echo",
|
|
"MCP tool should be in the tool list sent to the LLM")
|
|
|
|
// The tool result from the MCP server ("echo: hello from
|
|
// LLM") should have been fed back to the LLM as a tool
|
|
// message in the second call.
|
|
require.True(t, foundMCPResult.Load(),
|
|
"MCP tool result should appear in the second LLM call")
|
|
|
|
// Verify the tool result was persisted in the database.
|
|
var foundToolMessage bool
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for _, msg := range messages {
|
|
if msg.Role != database.ChatMessageRoleTool {
|
|
continue
|
|
}
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil || len(parts) == 0 {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeToolResult &&
|
|
part.ToolName == "test-mcp__echo" &&
|
|
strings.Contains(string(part.Result), "echo: hello from LLM") {
|
|
foundToolMessage = true
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}, testutil.IntervalFast)
|
|
require.True(t, foundToolMessage,
|
|
"MCP tool result should be persisted as a tool message in the database")
|
|
}
|
|
|
|
func TestPlanModeRootChatApprovedExternalMCPToolInvocation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
mcpSrv := mcpserver.NewMCPServer("plan-mode-mcp", "1.0.0")
|
|
mcpSrv.AddTools(mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("echo",
|
|
mcpgo.WithDescription("Echoes the input"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("echo: " + input), nil
|
|
},
|
|
})
|
|
mcpTS := httptest.NewServer(mcpserver.NewStreamableHTTPServer(mcpSrv))
|
|
t.Cleanup(mcpTS.Close)
|
|
|
|
var (
|
|
callCount atomic.Int32
|
|
llmToolNames []string
|
|
llmToolsMu sync.Mutex
|
|
foundMCPResult atomic.Bool
|
|
)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
if callCount.Add(1) == 1 {
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
llmToolsMu.Lock()
|
|
llmToolNames = names
|
|
llmToolsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk(
|
|
"plan-mode-mcp__echo",
|
|
`{"input":"hello from root plan mode"}`,
|
|
),
|
|
)
|
|
}
|
|
|
|
for _, msg := range req.Messages {
|
|
if msg.Role == "tool" && strings.Contains(msg.Content, "echo: hello from root plan mode") {
|
|
foundMCPResult.Store(true)
|
|
}
|
|
}
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Planning complete.")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
mcpConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Plan Mode MCP",
|
|
Slug: "plan-mode-mcp",
|
|
Url: mcpTS.URL,
|
|
AllowInPlanMode: true,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "plan-mode-mcp-invocation",
|
|
ModelConfigID: model.ID,
|
|
PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true},
|
|
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Use the approved MCP tool while planning."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
waitForChatProcessed(ctx, t, db, chat.ID, server)
|
|
|
|
chatResult, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
|
|
|
|
llmToolsMu.Lock()
|
|
recordedNames := append([]string(nil), llmToolNames...)
|
|
llmToolsMu.Unlock()
|
|
require.Contains(t, recordedNames, "plan-mode-mcp__echo",
|
|
"approved external MCP tools should be available in root plan mode")
|
|
require.True(t, foundMCPResult.Load(),
|
|
"approved external MCP tool results should feed back into the follow-up plan-mode turn")
|
|
}
|
|
|
|
func TestPlanModeRootChatApprovedExternalMCPWorkflowCanReachProposePlan(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
mcpSrv := mcpserver.NewMCPServer("plan-workflow-mcp", "1.0.0")
|
|
mcpSrv.AddTools(mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("echo",
|
|
mcpgo.WithDescription("Echoes the input"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("echo: " + input), nil
|
|
},
|
|
})
|
|
mcpTS := httptest.NewServer(mcpserver.NewStreamableHTTPServer(mcpSrv))
|
|
t.Cleanup(mcpTS.Close)
|
|
|
|
var (
|
|
callCount atomic.Int32
|
|
llmToolNames []string
|
|
llmToolsMu sync.Mutex
|
|
sawMCPResult atomic.Bool
|
|
proposePlanReached atomic.Bool
|
|
)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
switch callCount.Add(1) {
|
|
case 1:
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
llmToolsMu.Lock()
|
|
llmToolNames = names
|
|
llmToolsMu.Unlock()
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk(
|
|
"plan-workflow-mcp__echo",
|
|
`{"input":"prepare the plan"}`,
|
|
),
|
|
)
|
|
case 2:
|
|
for _, msg := range req.Messages {
|
|
if msg.Role == "tool" && strings.Contains(msg.Content, "echo: prepare the plan") {
|
|
sawMCPResult.Store(true)
|
|
}
|
|
}
|
|
proposePlanReached.Store(true)
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("propose_plan", `{}`),
|
|
)
|
|
default:
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("should not continue")...,
|
|
)
|
|
}
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
mcpConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Plan Workflow MCP",
|
|
Slug: "plan-workflow-mcp",
|
|
Url: mcpTS.URL,
|
|
AllowInPlanMode: true,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
|
|
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
|
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{AbsolutePathString: "/home/coder"}, nil).AnyTimes()
|
|
mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(_ context.Context, path string, _, _ int64) (io.ReadCloser, string, error) {
|
|
if strings.HasSuffix(path, ".md") {
|
|
return io.NopCloser(strings.NewReader("# Plan\n- Use the approved MCP tool findings.\n")), "", nil
|
|
}
|
|
return io.NopCloser(strings.NewReader("")), "", nil
|
|
}).AnyTimes()
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
})
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "plan-mode-mcp-propose-plan",
|
|
ModelConfigID: model.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true},
|
|
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Use the approved MCP tool, then propose the plan."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
waitForChatProcessed(ctx, t, db, chat.ID, server)
|
|
|
|
chatResult, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
|
|
|
|
llmToolsMu.Lock()
|
|
recordedNames := append([]string(nil), llmToolNames...)
|
|
llmToolsMu.Unlock()
|
|
require.Contains(t, recordedNames, "plan-workflow-mcp__echo",
|
|
"approved external MCP tools should be available in the root plan-mode workflow")
|
|
require.True(t, sawMCPResult.Load(),
|
|
"the root plan-mode workflow should feed the approved MCP result into the propose_plan turn")
|
|
require.True(t, proposePlanReached.Load(),
|
|
"the root plan-mode workflow should reach propose_plan after using the approved MCP tool")
|
|
require.Equal(t, int32(2), callCount.Load(),
|
|
"the workflow should stop immediately after propose_plan succeeds")
|
|
|
|
var foundProposePlanResult bool
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for _, msg := range messages {
|
|
if msg.Role != database.ChatMessageRoleTool {
|
|
continue
|
|
}
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == "propose_plan" {
|
|
foundProposePlanResult = true
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}, testutil.IntervalFast)
|
|
require.True(t, foundProposePlanResult,
|
|
"the root plan-mode workflow should persist a propose_plan tool result")
|
|
}
|
|
|
|
// TestMCPServerOAuth2TokenRefresh verifies that when a chat uses an
|
|
// MCP server with OAuth2 auth and the stored access token is expired,
|
|
// chatd refreshes the token using the stored refresh_token before
|
|
// connecting. The refreshed token is persisted to the database and
|
|
// the MCP tool call succeeds.
|
|
func TestMCPServerOAuth2TokenRefresh(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// The "fresh" token that the mock OAuth2 server returns after
|
|
// a successful refresh_token grant.
|
|
freshAccessToken := "fresh-access-token-" + uuid.New().String()
|
|
|
|
// Mock OAuth2 token endpoint that exchanges a refresh token
|
|
// for a new access token.
|
|
var refreshCalled atomic.Int32
|
|
tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
refreshCalled.Add(1)
|
|
|
|
if r.Method != http.MethodPost {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
grantType := r.FormValue("grant_type")
|
|
if grantType != "refresh_token" {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
_, _ = w.Write([]byte(`{"error":"unsupported_grant_type"}`))
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = fmt.Fprintf(w, `{"access_token":%q,"token_type":"Bearer","expires_in":3600,"refresh_token":"rotated-refresh-token"}`, freshAccessToken)
|
|
}))
|
|
t.Cleanup(tokenSrv.Close)
|
|
|
|
// Start a real MCP server with an auth middleware that only
|
|
// accepts the fresh access token. An expired token (or any
|
|
// other value) gets a 401.
|
|
mcpSrv := mcpserver.NewMCPServer("authed-mcp", "1.0.0")
|
|
mcpSrv.AddTools(mcpserver.ServerTool{
|
|
Tool: mcpgo.NewTool("echo",
|
|
mcpgo.WithDescription("Echoes the input"),
|
|
mcpgo.WithString("input",
|
|
mcpgo.Description("The input string"),
|
|
mcpgo.Required(),
|
|
),
|
|
),
|
|
Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
|
|
input, _ := req.GetArguments()["input"].(string)
|
|
return mcpgo.NewToolResultText("echo: " + input), nil
|
|
},
|
|
})
|
|
mcpHTTP := mcpserver.NewStreamableHTTPServer(mcpSrv)
|
|
// Wrap with auth check.
|
|
authMux := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
auth := r.Header.Get("Authorization")
|
|
if auth != "Bearer "+freshAccessToken {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
_, _ = w.Write([]byte(`{"error":"invalid_token","error_description":"The access token is invalid or expired"}`))
|
|
return
|
|
}
|
|
mcpHTTP.ServeHTTP(w, r)
|
|
})
|
|
mcpTS := httptest.NewServer(authMux)
|
|
t.Cleanup(mcpTS.Close)
|
|
|
|
// Track LLM interactions.
|
|
var (
|
|
callCount atomic.Int32
|
|
llmToolNames []string
|
|
llmToolsMu sync.Mutex
|
|
foundMCPResult atomic.Bool
|
|
)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
if callCount.Add(1) == 1 {
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
llmToolsMu.Lock()
|
|
llmToolNames = names
|
|
llmToolsMu.Unlock()
|
|
|
|
// Ask the LLM to call the MCP echo tool.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk(
|
|
"authed-mcp__echo",
|
|
`{"input":"hello via refreshed token"}`,
|
|
),
|
|
)
|
|
}
|
|
|
|
// Second call: verify the tool result was fed back.
|
|
for _, msg := range req.Messages {
|
|
if msg.Role == "tool" && strings.Contains(msg.Content, "echo: hello via refreshed token") {
|
|
foundMCPResult.Store(true)
|
|
}
|
|
}
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Done!")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
// Seed the MCP server config with OAuth2 auth pointing to our
|
|
// mock token endpoint.
|
|
mcpConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Authed MCP",
|
|
Slug: "authed-mcp",
|
|
Url: mcpTS.URL,
|
|
AuthType: "oauth2",
|
|
OAuth2ClientID: "test-client-id",
|
|
OAuth2TokenURL: tokenSrv.URL,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
|
|
// Seed an expired OAuth2 token with a valid refresh_token.
|
|
_, err := db.UpsertMCPServerUserToken(ctx, database.UpsertMCPServerUserTokenParams{
|
|
MCPServerConfigID: mcpConfig.ID,
|
|
UserID: user.ID,
|
|
AccessToken: "old-expired-access-token",
|
|
RefreshToken: "old-refresh-token",
|
|
TokenType: "Bearer",
|
|
Expiry: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
|
mockConn.EXPECT().ListMCPTools(gomock.Any()).
|
|
Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes()
|
|
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{}, nil).AnyTimes()
|
|
mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes()
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
})
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "oauth2-refresh-test",
|
|
ModelConfigID: model.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Echo something via the authed MCP."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the chat to finish processing.
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
// The token should have been refreshed.
|
|
require.Greater(t, refreshCalled.Load(), int32(0),
|
|
"OAuth2 token endpoint should have been called to refresh the expired token")
|
|
|
|
// The MCP tool should appear in the tool list.
|
|
llmToolsMu.Lock()
|
|
recordedNames := append([]string(nil), llmToolNames...)
|
|
llmToolsMu.Unlock()
|
|
require.Contains(t, recordedNames, "authed-mcp__echo",
|
|
"MCP tool should be in the tool list sent to the LLM")
|
|
|
|
// The tool result should have been fed back to the LLM.
|
|
require.True(t, foundMCPResult.Load(),
|
|
"MCP tool result should appear in the second LLM call")
|
|
|
|
// Verify the refreshed token was persisted to the database.
|
|
dbToken, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
|
|
MCPServerConfigID: mcpConfig.ID,
|
|
UserID: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, freshAccessToken, dbToken.AccessToken,
|
|
"refreshed access token should be persisted in the database")
|
|
require.Equal(t, "rotated-refresh-token", dbToken.RefreshToken,
|
|
"rotated refresh token should be persisted in the database")
|
|
}
|
|
|
|
// TestMCPServerOAuth2TokenRefreshFailureGraceful verifies that when
|
|
// the OAuth2 token endpoint is down, the chat still proceeds without
|
|
// the MCP server's tools. The expired token is preserved unchanged.
|
|
func TestMCPServerOAuth2TokenRefreshFailureGraceful(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Token endpoint that always returns an error.
|
|
tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
_, _ = w.Write([]byte(`{"error":"server_error","error_description":"token endpoint unavailable"}`))
|
|
}))
|
|
t.Cleanup(tokenSrv.Close)
|
|
|
|
// The LLM just replies with text, no tool calls.
|
|
var callCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
callCount.Add(1)
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("I responded without MCP tools.")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
mcpConfig := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Broken MCP",
|
|
Slug: "broken-mcp",
|
|
Url: "http://127.0.0.1:0/does-not-exist",
|
|
AuthType: "oauth2",
|
|
OAuth2ClientID: "test-client-id",
|
|
OAuth2TokenURL: tokenSrv.URL,
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
})
|
|
_, err := db.UpsertMCPServerUserToken(ctx, database.UpsertMCPServerUserTokenParams{
|
|
MCPServerConfigID: mcpConfig.ID,
|
|
UserID: user.ID,
|
|
AccessToken: "old-expired-token",
|
|
RefreshToken: "old-refresh-token",
|
|
TokenType: "Bearer",
|
|
Expiry: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
|
|
})
|
|
|
|
require.NoError(t, err)
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "graceful-degradation-test",
|
|
ModelConfigID: model.ID,
|
|
MCPServerIDs: []uuid.UUID{mcpConfig.ID},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Hello, just reply."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Chat should finish successfully despite the failed refresh.
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat should not fail", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
// The LLM should have been called at least once.
|
|
require.Greater(t, callCount.Load(), int32(0),
|
|
"LLM should be called even when MCP token refresh fails")
|
|
|
|
// The original token should be unchanged in the database.
|
|
dbToken, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
|
|
MCPServerConfigID: mcpConfig.ID,
|
|
UserID: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "old-expired-token", dbToken.AccessToken,
|
|
"original token should be preserved when refresh fails")
|
|
}
|
|
|
|
func TestChatTemplateAllowlistEnforcement(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
// Declare templates before the handler so the closure can
|
|
// reference their IDs when building tool-call arguments.
|
|
var tplAllowed, tplBlocked database.Template
|
|
|
|
// Set up a mock OpenAI server that chains tool calls:
|
|
// 1. list_templates
|
|
// 2. read_template (blocked template, should fail)
|
|
// 3. read_template (allowed template, should succeed)
|
|
// 4. create_workspace (blocked template, should fail)
|
|
// 5. text response
|
|
var callCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
switch callCount.Add(1) {
|
|
case 1:
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("list_templates", `{}`),
|
|
)
|
|
case 2:
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("read_template",
|
|
fmt.Sprintf(`{"template_id":%q}`, tplBlocked.ID.String())),
|
|
)
|
|
case 3:
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("read_template",
|
|
fmt.Sprintf(`{"template_id":%q}`, tplAllowed.ID.String())),
|
|
)
|
|
case 4:
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("create_workspace",
|
|
fmt.Sprintf(`{"template_id":%q}`, tplBlocked.ID.String())),
|
|
)
|
|
default:
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Done testing.")...,
|
|
)
|
|
}
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
// Create two templates the user can see.
|
|
tplAllowed = dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
Name: "allowed-template",
|
|
})
|
|
tplBlocked = dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
Name: "blocked-template",
|
|
})
|
|
|
|
// Set the allowlist to only tplAllowed.
|
|
allowlistJSON, err := json.Marshal([]string{tplAllowed.ID.String()})
|
|
require.NoError(t, err)
|
|
err = db.UpsertChatTemplateAllowlist(dbauthz.AsSystemRestricted(ctx), string(allowlistJSON))
|
|
require.NoError(t, err)
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
// Provide a CreateWorkspace function so the tool reaches
|
|
// the allowlist check instead of bailing with "not
|
|
// configured". If the allowlist is enforced correctly
|
|
// this function will never be called.
|
|
cfg.CreateWorkspace = func(
|
|
_ context.Context,
|
|
_ uuid.UUID,
|
|
_ codersdk.CreateWorkspaceRequest,
|
|
) (codersdk.Workspace, error) {
|
|
t.Error("CreateWorkspace should not be called for a blocked template")
|
|
return codersdk.Workspace{}, xerrors.New("unexpected call")
|
|
}
|
|
})
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "allowlist-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Test allowlist enforcement"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the chat to finish processing.
|
|
var chatResult database.Chat
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatResult = got
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat run failed", "last_error=%q", chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
|
|
// Collect all tool results keyed by tool name. Each tool may
|
|
// have been called more than once, so we store a slice.
|
|
var toolResults map[string][]string
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
toolResults = map[string][]string{}
|
|
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for _, msg := range messages {
|
|
if msg.Role != database.ChatMessageRoleTool {
|
|
continue
|
|
}
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeToolResult {
|
|
toolResults[part.ToolName] = append(
|
|
toolResults[part.ToolName], string(part.Result))
|
|
}
|
|
}
|
|
}
|
|
// We expect results from all four tool calls.
|
|
return len(toolResults["list_templates"]) >= 1 &&
|
|
len(toolResults["read_template"]) >= 2 &&
|
|
len(toolResults["create_workspace"]) >= 1
|
|
}, testutil.IntervalFast)
|
|
|
|
// list_templates: only the allowed template should appear.
|
|
require.Contains(t, toolResults["list_templates"][0], tplAllowed.ID.String(),
|
|
"allowed template should appear in list_templates result")
|
|
require.NotContains(t, toolResults["list_templates"][0], tplBlocked.ID.String(),
|
|
"blocked template should NOT appear in list_templates result")
|
|
|
|
// read_template: blocked ID → error, allowed ID → success.
|
|
require.Contains(t, toolResults["read_template"][0], "not found",
|
|
"read_template for blocked template should return not-found error")
|
|
require.Contains(t, toolResults["read_template"][1], tplAllowed.ID.String(),
|
|
"read_template for allowed template should return template details")
|
|
|
|
// create_workspace: blocked ID → rejected.
|
|
require.Contains(t, toolResults["create_workspace"][0], "not available",
|
|
"create_workspace for blocked template should be rejected")
|
|
}
|
|
|
|
// TestSignalWakeImmediateAcquisition verifies that CreateChat triggers
|
|
// immediate processing via signalWake without waiting for the polling
|
|
// ticker to fire. The ticker interval is set to an hour so it never
|
|
// fires during the test. Any processing must come from the wake
|
|
// channel.
|
|
func TestSignalWakeImmediateAcquisition(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
processed := make(chan struct{})
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
// Signal that the LLM was reached. This proves the chat
|
|
// was acquired and processing started.
|
|
select {
|
|
case <-processed:
|
|
default:
|
|
close(processed)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("hello from the model")...,
|
|
)
|
|
})
|
|
|
|
// Use a 1-hour acquire interval so the ticker never fires.
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.PendingChatAcquireInterval = time.Hour
|
|
cfg.InFlightChatStaleAfter = testutil.WaitSuperLong
|
|
})
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
// CreateChat sets status=pending and calls signalWake().
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "wake-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// The chat should be processed immediately. The LLM handler
|
|
// closes the `processed` channel when it receives a streaming
|
|
// request. Without signalWake this would hang forever because
|
|
// the 1-hour ticker never fires.
|
|
testutil.TryReceive(ctx, t, processed)
|
|
|
|
chatd.WaitUntilIdleForTest(server)
|
|
|
|
// Verify the chat was fully processed.
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, fromDB.Status,
|
|
"chat should be in waiting status after processing completes")
|
|
}
|
|
|
|
// TestSignalWakeSendMessage verifies that SendMessage on an idle chat
|
|
// triggers immediate processing via signalWake.
|
|
func TestSignalWakeSendMessage(t *testing.T) {
|
|
t.Parallel()
|
|
// TODO(CODAGT-353): Re-enable this after the chatd notification
|
|
// flow can distinguish stale status notifications from interrupts.
|
|
t.Skip("skipped until chatd notification flow refactor handles stale control notifications")
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
|
|
|
firstProcessed := make(chan struct{})
|
|
var requestCount atomic.Int32
|
|
secondProcessed := make(chan struct{})
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
switch requestCount.Add(1) {
|
|
case 1:
|
|
select {
|
|
case <-firstProcessed:
|
|
default:
|
|
close(firstProcessed)
|
|
}
|
|
case 2:
|
|
close(secondProcessed)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("response")...,
|
|
)
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.PendingChatAcquireInterval = time.Hour
|
|
cfg.InFlightChatStaleAfter = testutil.WaitSuperLong
|
|
})
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
// CreateChat triggers wake -> processes first turn.
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "wake-send-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("first")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the first turn to actually reach the LLM, then
|
|
// wait for the processing goroutine to finish so the chat
|
|
// transitions to "waiting" status.
|
|
testutil.TryReceive(ctx, t, firstProcessed)
|
|
chatd.WaitUntilIdleForTest(server)
|
|
|
|
// Now send a follow-up message, which should also be
|
|
// processed immediately via signalWake.
|
|
_, err = server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("second")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
testutil.TryReceive(ctx, t, secondProcessed)
|
|
chatd.WaitUntilIdleForTest(server)
|
|
|
|
// Both turns processed. Verify the second request reached the LLM.
|
|
require.GreaterOrEqual(t, requestCount.Load(), int32(2),
|
|
"LLM should have received at least 2 streaming requests")
|
|
}
|
|
|
|
// TestAgentContextFilesAndSkillsLoadedIntoChat verifies the full
|
|
// end-to-end path: the workspace agent reads instruction files and
|
|
// discovers skills from the filesystem, chatd fetches them via a
|
|
// real tailnet agent connection, and both the <workspace-context>
|
|
// block and <available-skills> index appear in the LLM prompt.
|
|
//
|
|
// This test is NOT parallel because it sets process-wide environment
|
|
// variables via t.Setenv to configure the agent's context config.
|
|
func TestAgentContextFilesAndSkillsLoadedIntoChat(t *testing.T) {
|
|
fakeHome := t.TempDir()
|
|
t.Setenv("HOME", fakeHome)
|
|
t.Setenv("USERPROFILE", fakeHome)
|
|
|
|
instructionsDir := filepath.Join(fakeHome, ".coder")
|
|
skillsDir := filepath.Join(fakeHome, ".coder", "skills")
|
|
require.NoError(t, os.MkdirAll(instructionsDir, 0o755))
|
|
require.NoError(t, os.MkdirAll(skillsDir, 0o755))
|
|
|
|
t.Setenv(agentcontextconfig.EnvInstructionsDirs, instructionsDir)
|
|
t.Setenv(agentcontextconfig.EnvInstructionsFile, "AGENTS.md")
|
|
t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir)
|
|
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "SKILL.md")
|
|
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, filepath.Join(fakeHome, "nonexistent-mcp.json"))
|
|
|
|
require.NoError(t, os.WriteFile(
|
|
filepath.Join(instructionsDir, "AGENTS.md"),
|
|
[]byte("# Project Rules\nAlways write tests."),
|
|
0o600,
|
|
))
|
|
|
|
skillDir := filepath.Join(skillsDir, "my-cool-skill")
|
|
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
|
require.NoError(t, os.WriteFile(
|
|
filepath.Join(skillDir, "SKILL.md"),
|
|
[]byte("---\nname: my-cool-skill\ndescription: A test skill\n---\nDo the cool thing.\n"),
|
|
0o600,
|
|
))
|
|
|
|
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
|
deploymentValues := directChatRoutingDeploymentValues(t)
|
|
client := coderdtest.New(t, &coderdtest.Options{
|
|
DeploymentValues: deploymentValues,
|
|
IncludeProvisionerDaemon: true,
|
|
ChatdInstructionLookupTimeout: testutil.WaitLong,
|
|
})
|
|
user := coderdtest.CreateFirstUser(t, client)
|
|
expClient := codersdk.NewExperimentalClient(client)
|
|
|
|
agentToken := uuid.NewString()
|
|
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
|
Parse: echo.ParseComplete,
|
|
ProvisionPlan: echo.PlanComplete,
|
|
ProvisionApply: echo.ApplyComplete,
|
|
ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken),
|
|
})
|
|
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
|
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
|
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
|
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
|
|
|
_ = agenttest.New(t, client.URL, agentToken, agenttest.WithContextConfigFromEnv())
|
|
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
|
|
|
|
// Capture LLM requests so we can inspect the system prompt.
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("context test")
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
|
|
streamedCallsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Got it.")...,
|
|
)
|
|
})
|
|
|
|
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
|
|
|
workspaceID := workspace.ID
|
|
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
|
OrganizationID: user.OrganizationID,
|
|
WorkspaceID: &workspaceID,
|
|
Content: []codersdk.ChatInputPart{
|
|
{
|
|
Type: codersdk.ChatInputPartTypeText,
|
|
Text: "Hello, what are the project rules?",
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := expClient.GetChat(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
return got.Status == codersdk.ChatStatusWaiting || got.Status == codersdk.ChatStatusError
|
|
}, testutil.WaitSuperLong, testutil.IntervalFast)
|
|
|
|
streamedCallsMu.Lock()
|
|
recordedCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
|
|
streamedCallsMu.Unlock()
|
|
require.NotEmpty(t, recordedCalls, "LLM should have received at least one streaming request")
|
|
|
|
var allSystemContent string
|
|
for _, msg := range recordedCalls[0] {
|
|
if msg.Role == "system" {
|
|
allSystemContent += msg.Content + "\n"
|
|
}
|
|
}
|
|
|
|
require.Contains(t, allSystemContent, "<workspace-context>",
|
|
"system prompt should contain workspace-context block")
|
|
require.Contains(t, allSystemContent, "Always write tests.",
|
|
"system prompt should contain AGENTS.md content")
|
|
require.Contains(t, allSystemContent, "AGENTS.md",
|
|
"system prompt should reference the source file")
|
|
|
|
planBlockCount := 0
|
|
standalonePlanBlockCount := 0
|
|
for _, msg := range recordedCalls[0] {
|
|
if msg.Role != "system" {
|
|
continue
|
|
}
|
|
planBlockCount += strings.Count(
|
|
msg.Content,
|
|
"<plan-file-path>\nYour plan file path for this chat is:",
|
|
)
|
|
trimmed := strings.TrimSpace(msg.Content)
|
|
if strings.HasPrefix(trimmed, "<plan-file-path>") &&
|
|
strings.HasSuffix(trimmed, "</plan-file-path>") {
|
|
standalonePlanBlockCount++
|
|
}
|
|
}
|
|
|
|
require.Contains(t, allSystemContent, "<available-skills>",
|
|
"system prompt should contain available-skills block")
|
|
require.Contains(t, allSystemContent, "my-cool-skill",
|
|
"system prompt should list the discovered skill")
|
|
require.Contains(t, allSystemContent, "A test skill",
|
|
"system prompt should include the skill description")
|
|
require.Contains(t, allSystemContent, "<plan-file-path>",
|
|
"system prompt should contain the plan-file-path block")
|
|
require.Contains(t, allSystemContent, "PLAN-"+chat.ID.String()+".md",
|
|
"system prompt should use the chat-specific plan path")
|
|
require.Contains(t, allSystemContent,
|
|
"Do not use "+strings.TrimRight(fakeHome, "/")+"/PLAN.md.",
|
|
"system prompt should warn against the home-root plan path")
|
|
require.Equal(t, 1, planBlockCount,
|
|
"system prompt should contain a single plan-file-path block")
|
|
require.Zero(t, standalonePlanBlockCount,
|
|
"plan-file-path block should be part of the main system prompt, not a standalone message")
|
|
}
|
|
|
|
func TestSendMessageRejectsArchivedChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "send-archived",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = replica.ArchiveChat(ctx, chat)
|
|
require.NoError(t, err)
|
|
|
|
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("should fail")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.ErrorIs(t, err, chatd.ErrChatArchived)
|
|
}
|
|
|
|
func TestEditMessageRejectsArchivedChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "edit-archived",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 1)
|
|
|
|
err = replica.ArchiveChat(ctx, chat)
|
|
require.NoError(t, err)
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: messages[0].ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
|
})
|
|
require.ErrorIs(t, err, chatd.ErrChatArchived)
|
|
}
|
|
|
|
// TestEditMessageWithModelConfigOverride verifies that callers can
|
|
// change the model when editing a previous user message. The
|
|
// replacement message must persist with the new model and the chat's
|
|
// LastModelConfigID must be advanced so the assistant turn that follows
|
|
// runs against the new selection.
|
|
func TestEditMessageWithModelConfigOverride(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, modelA := seedChatDependencies(t, db)
|
|
modelB := insertChatModelConfigWithCallConfig(
|
|
t,
|
|
db,
|
|
user.ID,
|
|
"openai",
|
|
"gpt-4o-mini-edit-"+uuid.NewString(),
|
|
codersdk.ChatModelCallConfig{},
|
|
)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "edit-with-model-override",
|
|
ModelConfigID: modelA.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
initial, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, initial, 1)
|
|
require.Equal(t, modelA.ID, initial[0].ModelConfigID.UUID)
|
|
|
|
result, err := replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: initial[0].ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
|
ModelConfigID: modelB.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, result.Message.ModelConfigID.Valid)
|
|
require.Equal(t, modelB.ID, result.Message.ModelConfigID.UUID)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, modelB.ID, storedChat.LastModelConfigID,
|
|
"edit must update last_model_config_id so the assistant turn picks up the new model")
|
|
}
|
|
|
|
// TestEditMessagePreservesModelConfigByDefault verifies that omitting
|
|
// ModelConfigID on edit keeps the original message's model. This is the
|
|
// existing default for callers that only edit the text.
|
|
func TestEditMessagePreservesModelConfigByDefault(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, modelA := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "edit-preserves-model",
|
|
ModelConfigID: modelA.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
initial, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, initial, 1)
|
|
|
|
result, err := replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: initial[0].ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, result.Message.ModelConfigID.Valid)
|
|
require.Equal(t, modelA.ID, result.Message.ModelConfigID.UUID)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, modelA.ID, storedChat.LastModelConfigID,
|
|
"edit without model override must not change last_model_config_id")
|
|
}
|
|
|
|
// TestEditMessageRejectsUnknownModelConfig verifies the edit handler
|
|
// returns ErrInvalidModelConfigID when the requested model does not
|
|
// exist, mirroring SendMessage's validation.
|
|
func TestEditMessageRejectsUnknownModelConfig(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, modelA := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "edit-unknown-model",
|
|
ModelConfigID: modelA.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
initial, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, initial, 1)
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: initial[0].ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
|
ModelConfigID: uuid.New(),
|
|
})
|
|
require.ErrorIs(t, err, chatd.ErrInvalidModelConfigID)
|
|
|
|
// The edit must roll back: the original message should still be
|
|
// present and the chat's LastModelConfigID unchanged.
|
|
stillThere, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, stillThere, 1)
|
|
require.Equal(t, initial[0].ID, stillThere[0].ID)
|
|
|
|
storedChat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, modelA.ID, storedChat.LastModelConfigID)
|
|
}
|
|
|
|
func TestPromoteQueuedRejectsArchivedChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "promote-archived",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Queue a message by setting the chat to running first.
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
queuedResult, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, queuedResult.Queued)
|
|
|
|
// Move back to waiting, then archive.
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = replica.ArchiveChat(ctx, chat)
|
|
require.NoError(t, err)
|
|
|
|
_, err = replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queuedResult.QueuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.ErrorIs(t, err, chatd.ErrChatArchived)
|
|
}
|
|
|
|
// TestPromoteQueuedWhileRequiresAction guards against the
|
|
// stops-dead failure mode: promoting on requires_action without
|
|
// closing pending dynamic tool calls leaves the assistant turn
|
|
// with unresolved tool_call parts that the LLM API rejects. It
|
|
// also asserts the synthetic tool-result row is published to live
|
|
// SSE subscribers before the promoted user message.
|
|
func TestPromoteQueuedWhileRequiresAction(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var streamedCallCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("requires-action-promote")
|
|
}
|
|
if streamedCallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk(
|
|
"my_dynamic_tool",
|
|
`{"input":"hello"}`,
|
|
),
|
|
)
|
|
}
|
|
// Second call: the resumed run after promote completes.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Resumed after promotion.")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "my_dynamic_tool",
|
|
Description: "A test dynamic tool.",
|
|
InputSchema: mcpgo.ToolInputSchema{
|
|
Type: "object",
|
|
Properties: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "promote-while-requires-action",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Please call the dynamic tool."),
|
|
},
|
|
DynamicTools: dynamicToolsJSON,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var chatBeforePromote database.Chat
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatBeforePromote = got
|
|
return got.Status == database.ChatStatusRequiresAction ||
|
|
got.Status == database.ChatStatusError
|
|
}, testutil.IntervalFast)
|
|
require.Equal(t, database.ChatStatusRequiresAction, chatBeforePromote.Status,
|
|
"expected requires_action, got %s (last_error=%q)",
|
|
chatBeforePromote.Status, chatLastErrorMessage(chatBeforePromote.LastError))
|
|
|
|
var pendingToolCallID string
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for _, msg := range messages {
|
|
if msg.Role != database.ChatMessageRoleAssistant {
|
|
continue
|
|
}
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
|
pendingToolCallID = part.ToolCallID
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}, testutil.IntervalFast)
|
|
require.NotEmpty(t, pendingToolCallID, "expected pending dynamic tool call")
|
|
|
|
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("promote me")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, queuedResult.Queued)
|
|
require.NotNil(t, queuedResult.QueuedMessage)
|
|
|
|
// Subscribe before promoting to capture published events.
|
|
_, events, subCancel, ok := server.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
defer subCancel()
|
|
promoteResult, err := server.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queuedResult.QueuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatMessageRoleUser, promoteResult.PromotedMessage.Role)
|
|
|
|
// Synthetic row must publish before the promoted user message.
|
|
var (
|
|
syntheticPublishedAt int
|
|
userPublishedAt int
|
|
messagesSeen int
|
|
)
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
for {
|
|
select {
|
|
case ev := <-events:
|
|
if ev.Type != codersdk.ChatStreamEventTypeMessage || ev.Message == nil {
|
|
continue
|
|
}
|
|
messagesSeen++
|
|
switch ev.Message.Role {
|
|
case codersdk.ChatMessageRoleTool:
|
|
if syntheticPublishedAt == 0 {
|
|
syntheticPublishedAt = messagesSeen
|
|
}
|
|
case codersdk.ChatMessageRoleUser:
|
|
if ev.Message.ID == promoteResult.PromotedMessage.ID {
|
|
userPublishedAt = messagesSeen
|
|
}
|
|
}
|
|
if syntheticPublishedAt > 0 && userPublishedAt > 0 {
|
|
return true
|
|
}
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
require.Less(t, syntheticPublishedAt, userPublishedAt,
|
|
"synthetic tool-result must be published before the promoted user message")
|
|
|
|
queuedAfter, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, queuedAfter, "queued message should be removed after sync promotion")
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var (
|
|
syntheticToolResult *database.ChatMessage
|
|
promotedUserMessage *database.ChatMessage
|
|
)
|
|
for i := range messages {
|
|
msg := messages[i]
|
|
if msg.Role == database.ChatMessageRoleTool {
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
require.NoError(t, parseErr)
|
|
for _, part := range parts {
|
|
if part.Type != codersdk.ChatMessagePartTypeToolResult {
|
|
continue
|
|
}
|
|
if part.ToolCallID != pendingToolCallID {
|
|
continue
|
|
}
|
|
require.True(t, part.IsError,
|
|
"synthetic tool result should have IsError=true")
|
|
syntheticToolResult = &messages[i]
|
|
}
|
|
}
|
|
if msg.ID == promoteResult.PromotedMessage.ID {
|
|
promotedUserMessage = &messages[i]
|
|
}
|
|
}
|
|
require.NotNil(t, syntheticToolResult,
|
|
"expected a synthetic error tool result for the pending tool call")
|
|
require.NotNil(t, promotedUserMessage)
|
|
require.Less(t, syntheticToolResult.ID, promotedUserMessage.ID,
|
|
"synthetic tool result must precede the promoted user message")
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
|
}, testutil.IntervalFast)
|
|
final, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, final.Status,
|
|
"chat should resume to waiting after promotion (last_error=%q)",
|
|
chatLastErrorMessage(final.LastError))
|
|
}
|
|
|
|
// TestPromoteQueuedWhileRequiresActionMixedTools guards against
|
|
// duplicating already-resolved built-in tool results: synthetic
|
|
// results must be scoped to dynamic tool names only.
|
|
func TestPromoteQueuedWhileRequiresActionMixedTools(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var streamedCallCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("mixed-tools-promote")
|
|
}
|
|
if streamedCallCount.Add(1) == 1 {
|
|
builtinChunk := chattest.OpenAIToolCallChunk(
|
|
"read_file",
|
|
`{"path":"/tmp/test.txt"}`,
|
|
)
|
|
dynamicChunk := chattest.OpenAIToolCallChunk(
|
|
"my_dynamic_tool",
|
|
`{"input":"hello world"}`,
|
|
)
|
|
mergedChunk := builtinChunk
|
|
dynCall := dynamicChunk.Choices[0].ToolCalls[0]
|
|
dynCall.Index = 1
|
|
mergedChunk.Choices[0].ToolCalls = append(
|
|
mergedChunk.Choices[0].ToolCalls,
|
|
dynCall,
|
|
)
|
|
return chattest.OpenAIStreamingResponse(mergedChunk)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Resumed after mixed-tool promotion.")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "my_dynamic_tool",
|
|
Description: "A test dynamic tool.",
|
|
InputSchema: mcpgo.ToolInputSchema{
|
|
Type: "object",
|
|
Properties: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "promote-while-requires-action-mixed",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Call both tools."),
|
|
},
|
|
DynamicTools: dynamicToolsJSON,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var chatBeforePromote database.Chat
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatBeforePromote = got
|
|
return got.Status == database.ChatStatusRequiresAction ||
|
|
got.Status == database.ChatStatusError
|
|
}, testutil.IntervalFast)
|
|
require.Equal(t, database.ChatStatusRequiresAction, chatBeforePromote.Status,
|
|
"expected requires_action, got %s (last_error=%q)",
|
|
chatBeforePromote.Status, chatLastErrorMessage(chatBeforePromote.LastError))
|
|
|
|
// The built-in tool resolves before requires_action; capture
|
|
// its row ID to assert the dynamic synthetic comes after.
|
|
var (
|
|
dynamicToolCallID string
|
|
builtinToolResultID int64
|
|
builtinToolResultSeen bool
|
|
)
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
for _, msg := range messages {
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == "read_file" {
|
|
builtinToolResultID = msg.ID
|
|
builtinToolResultSeen = true
|
|
}
|
|
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
|
dynamicToolCallID = part.ToolCallID
|
|
}
|
|
}
|
|
}
|
|
return builtinToolResultSeen && dynamicToolCallID != ""
|
|
}, testutil.IntervalFast)
|
|
require.NotEmpty(t, dynamicToolCallID)
|
|
require.NotZero(t, builtinToolResultID)
|
|
|
|
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("promote me")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, queuedResult.Queued)
|
|
require.NotNil(t, queuedResult.QueuedMessage)
|
|
|
|
_, events, subCancel, ok := server.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
defer subCancel()
|
|
promoteResult, err := server.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queuedResult.QueuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotZero(t, promoteResult.PromotedMessage.ID,
|
|
"requires_action promotion is synchronous and returns the inserted message")
|
|
|
|
// Only the dynamic tool's synth row publishes; the built-in's
|
|
// pre-existing result is not republished.
|
|
var (
|
|
syntheticPublishCount int
|
|
userPublished bool
|
|
)
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
for {
|
|
select {
|
|
case ev := <-events:
|
|
if ev.Type != codersdk.ChatStreamEventTypeMessage || ev.Message == nil {
|
|
t.Logf("subscriber consumed non-message event type=%s", ev.Type)
|
|
continue
|
|
}
|
|
t.Logf("subscriber consumed message id=%d role=%s match_promoted=%t", ev.Message.ID, ev.Message.Role, ev.Message.ID == promoteResult.PromotedMessage.ID)
|
|
switch ev.Message.Role {
|
|
case codersdk.ChatMessageRoleTool:
|
|
syntheticPublishCount++
|
|
case codersdk.ChatMessageRoleUser:
|
|
if ev.Message.ID == promoteResult.PromotedMessage.ID {
|
|
userPublished = true
|
|
}
|
|
}
|
|
if userPublished {
|
|
return true
|
|
}
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
require.Equal(t, 1, syntheticPublishCount,
|
|
"only the dynamic tool's synthetic result must be published; the built-in's pre-existing result must not be republished")
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var (
|
|
dynamicSyntheticCount int
|
|
builtinResultsForReadFile int
|
|
)
|
|
for _, msg := range messages {
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
require.NoError(t, parseErr)
|
|
for _, part := range parts {
|
|
if part.Type != codersdk.ChatMessagePartTypeToolResult {
|
|
continue
|
|
}
|
|
switch part.ToolName {
|
|
case "read_file":
|
|
builtinResultsForReadFile++
|
|
case "my_dynamic_tool":
|
|
if part.IsError && part.ToolCallID == dynamicToolCallID && msg.ID > builtinToolResultID {
|
|
dynamicSyntheticCount++
|
|
}
|
|
}
|
|
}
|
|
}
|
|
require.Equal(t, 1, dynamicSyntheticCount,
|
|
"expected exactly one synthetic error tool result for the dynamic tool call")
|
|
require.Equal(t, 1, builtinResultsForReadFile,
|
|
"built-in tool result should not be duplicated by promotion")
|
|
|
|
require.Greater(t, promoteResult.PromotedMessage.ID, builtinToolResultID)
|
|
}
|
|
|
|
func TestSubmitToolResultsRejectsArchivedChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "submit-tool-archived",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = replica.ArchiveChat(ctx, chat)
|
|
require.NoError(t, err)
|
|
|
|
// Set requires_action so the test exercises a realistic
|
|
// scenario where SubmitToolResults would be called.
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRequiresAction,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = replica.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
|
ChatID: chat.ID,
|
|
UserID: user.ID,
|
|
ModelConfigID: model.ID,
|
|
Results: []codersdk.ToolResult{{
|
|
ToolCallID: "fake-tool-call-id",
|
|
Output: json.RawMessage(`{"result":"ignored"}`),
|
|
}},
|
|
})
|
|
require.ErrorIs(t, err, chatd.ErrChatArchived)
|
|
}
|
|
|
|
func TestAcquireChatsSkipsArchivedPendingChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
_ = newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
archivedChat := dbgen.Chat(t, db, database.Chat{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "acquire-skip-archived",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
|
|
// Archive the chat, then force it to pending.
|
|
_, err := db.ArchiveChatByID(ctx, archivedChat.ID)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: archivedChat.ID,
|
|
Status: database.ChatStatusPending,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert a second, non-archived pending chat so the result
|
|
// slice is non-empty and the assertion is not vacuously true.
|
|
activeChat := dbgen.Chat(t, db, database.Chat{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "acquire-active",
|
|
LastModelConfigID: model.ID,
|
|
Status: database.ChatStatusPending,
|
|
})
|
|
|
|
now := time.Now()
|
|
acquired, err := db.AcquireChats(ctx, database.AcquireChatsParams{
|
|
WorkerID: uuid.New(),
|
|
StartedAt: now,
|
|
NumChats: 10,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, acquired, 1, "only the non-archived chat should be acquired")
|
|
require.Equal(t, activeChat.ID, acquired[0].ID)
|
|
}
|
|
|
|
func TestAdvisorGating_Disabled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var toolsMu sync.Mutex
|
|
var capturedTools []string
|
|
var capturedMessages []chattest.OpenAIMessage
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
toolsMu.Lock()
|
|
capturedTools = names
|
|
capturedMessages = append([]chattest.OpenAIMessage(nil), req.Messages...)
|
|
toolsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("advisor is not available")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{
|
|
Enabled: false,
|
|
MaxUsesPerRun: 3,
|
|
MaxOutputTokens: 16384,
|
|
})
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "advisor-disabled",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("hello"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
return got.Status == database.ChatStatusWaiting ||
|
|
got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
toolsMu.Lock()
|
|
tools := append([]string(nil), capturedTools...)
|
|
messages := append([]chattest.OpenAIMessage(nil), capturedMessages...)
|
|
toolsMu.Unlock()
|
|
|
|
require.NotEmpty(t, messages, "expected a streamed LLM request")
|
|
require.NotContains(t, tools, "advisor",
|
|
"advisor tool should not be registered when disabled")
|
|
for _, msg := range messages {
|
|
require.NotContains(t, msg.Content, chatadvisor.ParentGuidanceBlock,
|
|
"advisor guidance should not be injected when disabled")
|
|
}
|
|
}
|
|
|
|
func TestAdvisorGating_RootChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var streamedCallCount atomic.Int32
|
|
var streamedCallsMu sync.Mutex
|
|
var firstCallTools []string
|
|
var firstCallMessages []chattest.OpenAIMessage
|
|
var secondCallMessages []chattest.OpenAIMessage
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
switch streamedCallCount.Add(1) {
|
|
case 1:
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
streamedCallsMu.Lock()
|
|
firstCallTools = names
|
|
firstCallMessages = append([]chattest.OpenAIMessage(nil), req.Messages...)
|
|
streamedCallsMu.Unlock()
|
|
|
|
advisorChunk := chattest.OpenAIToolCallChunk(
|
|
"advisor",
|
|
`{"question":"help me plan"}`,
|
|
)
|
|
readChunk := chattest.OpenAIToolCallChunk(
|
|
"read_file",
|
|
`{"path":"/tmp/test.txt"}`,
|
|
)
|
|
mergedChunk := advisorChunk
|
|
readCall := readChunk.Choices[0].ToolCalls[0]
|
|
readCall.Index = 1
|
|
mergedChunk.Choices[0].ToolCalls = append(
|
|
mergedChunk.Choices[0].ToolCalls,
|
|
readCall,
|
|
)
|
|
return chattest.OpenAIStreamingResponse(mergedChunk)
|
|
case 2:
|
|
streamedCallsMu.Lock()
|
|
secondCallMessages = append([]chattest.OpenAIMessage(nil), req.Messages...)
|
|
streamedCallsMu.Unlock()
|
|
}
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{
|
|
Enabled: true,
|
|
MaxUsesPerRun: 3,
|
|
MaxOutputTokens: 16384,
|
|
})
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "advisor-root",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("help me plan this"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
if got.Status != database.ChatStatusWaiting &&
|
|
got.Status != database.ChatStatusError {
|
|
return false
|
|
}
|
|
return streamedCallCount.Load() >= 2
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
streamedCallsMu.Lock()
|
|
tools := append([]string(nil), firstCallTools...)
|
|
messages := append([]chattest.OpenAIMessage(nil), firstCallMessages...)
|
|
secondMessages := append([]chattest.OpenAIMessage(nil), secondCallMessages...)
|
|
streamedCallsMu.Unlock()
|
|
|
|
// Exactly two streamed LLM calls are expected: the first that
|
|
// returned the mixed advisor + read_file batch, and the second
|
|
// that received the exclusive-policy rejection. A third call
|
|
// would indicate that either tool had slipped past the exclusive
|
|
// policy; the >= 2 wait would have missed that regression.
|
|
require.Equal(t, int32(2), streamedCallCount.Load(),
|
|
"exclusive policy must block execution of both tools; no third call expected")
|
|
require.NotEmpty(t, messages, "expected a first streamed LLM request")
|
|
require.NotEmpty(t, secondMessages, "expected a second streamed LLM request")
|
|
require.Contains(t, tools, "advisor",
|
|
"advisor tool should be registered for root chats when enabled")
|
|
|
|
var hasGuidance bool
|
|
for _, msg := range messages {
|
|
if strings.Contains(msg.Content, chatadvisor.ParentGuidanceBlock) {
|
|
hasGuidance = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, hasGuidance,
|
|
"root chat should contain advisor guidance in the prompt")
|
|
|
|
var hasExclusiveAdvisorError bool
|
|
var hasSkippedToolError bool
|
|
for _, msg := range secondMessages {
|
|
if strings.Contains(msg.Content, "advisor must be called alone") {
|
|
hasExclusiveAdvisorError = true
|
|
}
|
|
if strings.Contains(msg.Content, "this tool was skipped because advisor must run alone") {
|
|
hasSkippedToolError = true
|
|
}
|
|
}
|
|
require.True(t, hasExclusiveAdvisorError,
|
|
"mixed advisor batches should surface the exclusive advisor error")
|
|
require.True(t, hasSkippedToolError,
|
|
"mixed advisor batches should skip sibling tools with an explanatory error")
|
|
}
|
|
|
|
// TestAdvisorHappyPath_RootChat walks the advisor tool end-to-end:
|
|
// parent calls advisor alone, the nested advisor call produces text, and
|
|
// the structured result flows back into the parent conversation. The
|
|
// exclusive-policy test above only proves the rejection path; this test
|
|
// covers the glue from chatd wiring -> chatadvisor.Tool -> Runtime.Run ->
|
|
// nested model call -> structured result back to the outer model.
|
|
func TestAdvisorHappyPath_RootChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
const advisorReply = "break the problem into smaller pieces first"
|
|
advisorDeltas := []string{"break the problem ", "into smaller pieces first"}
|
|
|
|
var (
|
|
streamedCallCount atomic.Int32
|
|
streamedCallsMu sync.Mutex
|
|
advisorCallSeen atomic.Bool
|
|
advisorMessages []chattest.OpenAIMessage
|
|
finalCallMessages []chattest.OpenAIMessage
|
|
)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
switch streamedCallCount.Add(1) {
|
|
case 1:
|
|
// Parent turn 1: call advisor solo.
|
|
return chattest.OpenAIStreamingResponse(chattest.OpenAIToolCallChunk(
|
|
"advisor",
|
|
`{"question":"how should I approach this refactor?"}`,
|
|
))
|
|
case 2:
|
|
// Nested advisor turn. The nested call has no tools because
|
|
// chatadvisor.RunAdvisor runs with MaxSteps=1 and no tool
|
|
// set.
|
|
require.Empty(t, req.Tools,
|
|
"advisor's nested call must run without tools")
|
|
streamedCallsMu.Lock()
|
|
advisorMessages = append([]chattest.OpenAIMessage(nil), req.Messages...)
|
|
streamedCallsMu.Unlock()
|
|
advisorCallSeen.Store(true)
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks(advisorDeltas...)...,
|
|
)
|
|
default:
|
|
// Parent turn 2: observe the advisor tool result and close
|
|
// out with a final text reply.
|
|
streamedCallsMu.Lock()
|
|
finalCallMessages = append([]chattest.OpenAIMessage(nil), req.Messages...)
|
|
streamedCallsMu.Unlock()
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("acknowledged")...,
|
|
)
|
|
}
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{
|
|
Enabled: true,
|
|
MaxUsesPerRun: 3,
|
|
MaxOutputTokens: 16384,
|
|
})
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "advisor-happy-path",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("help me refactor this module"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Subscribe before the worker commits any durable messages so we
|
|
// observe the advisor tool-result deltas live. Buffered parts are
|
|
// claimed by their committed durable message ID at publishMessage
|
|
// time and dropped from snapshots of late-connecting subscribers, so
|
|
// a post-completion Subscribe() would no longer see streaming
|
|
// deltas. Collecting events from the live channel covers the
|
|
// streaming UX contract this test exists to verify.
|
|
_, liveEvents, cancelLive, ok := server.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
var (
|
|
livePartsMu sync.Mutex
|
|
liveAdvisorDeltas []string
|
|
liveCollectorDone = make(chan struct{})
|
|
)
|
|
go func() {
|
|
defer close(liveCollectorDone)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case event, eventsOK := <-liveEvents:
|
|
if !eventsOK {
|
|
return
|
|
}
|
|
if event.Type != codersdk.ChatStreamEventTypeMessagePart ||
|
|
event.MessagePart == nil {
|
|
continue
|
|
}
|
|
part := event.MessagePart.Part
|
|
if event.MessagePart.Role != codersdk.ChatMessageRoleTool ||
|
|
part.Type != codersdk.ChatMessagePartTypeToolResult ||
|
|
part.ToolName != chatadvisor.ToolName ||
|
|
part.ResultDelta == "" {
|
|
continue
|
|
}
|
|
livePartsMu.Lock()
|
|
liveAdvisorDeltas = append(liveAdvisorDeltas, part.ResultDelta)
|
|
livePartsMu.Unlock()
|
|
}
|
|
}
|
|
}()
|
|
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
if got.Status != database.ChatStatusWaiting &&
|
|
got.Status != database.ChatStatusError {
|
|
return false
|
|
}
|
|
return streamedCallCount.Load() >= 3
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
streamedCallsMu.Lock()
|
|
gotAdvisorMessages := append([]chattest.OpenAIMessage(nil), advisorMessages...)
|
|
gotFinalMessages := append([]chattest.OpenAIMessage(nil), finalCallMessages...)
|
|
streamedCallsMu.Unlock()
|
|
|
|
require.True(t, advisorCallSeen.Load(),
|
|
"the nested advisor call must execute; missing it means the tool never ran")
|
|
require.NotEmpty(t, gotAdvisorMessages,
|
|
"advisor call must receive the nested prompt messages")
|
|
require.NotEmpty(t, gotFinalMessages,
|
|
"parent must make a follow-up call after the advisor result")
|
|
|
|
var advisorSawQuestion bool
|
|
var advisorSawUserTurn bool
|
|
for _, msg := range gotAdvisorMessages {
|
|
if strings.Contains(msg.Content, "how should I approach this refactor?") {
|
|
advisorSawQuestion = true
|
|
}
|
|
if msg.Role == "user" && strings.Contains(msg.Content, "help me refactor this module") {
|
|
advisorSawUserTurn = true
|
|
}
|
|
}
|
|
require.True(t, advisorSawQuestion,
|
|
"advisor must receive the parent's question verbatim")
|
|
require.True(t, advisorSawUserTurn,
|
|
"advisor must receive the parent's conversation snapshot as nested context")
|
|
|
|
for _, msg := range gotAdvisorMessages {
|
|
require.NotContains(t, msg.Content, chatadvisor.ParentGuidanceBlock,
|
|
"ParentGuidanceBlock must be stripped before reaching the advisor")
|
|
}
|
|
|
|
var parentSawAdvisorResult bool
|
|
for _, msg := range gotFinalMessages {
|
|
if msg.Role == "tool" && strings.Contains(msg.Content, advisorReply) {
|
|
parentSawAdvisorResult = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, parentSawAdvisorResult,
|
|
"parent must see the advisor reply in its continuation call")
|
|
|
|
// Stop the live collector and assert it captured the streaming
|
|
// advisor deltas during processing. Late subscribers no longer
|
|
// see committed parts because publishMessage claims them out of
|
|
// new snapshots, so the assertion must use the live collector.
|
|
cancelLive()
|
|
<-liveCollectorDone
|
|
livePartsMu.Lock()
|
|
collectedAdvisorDeltas := append([]string(nil), liveAdvisorDeltas...)
|
|
livePartsMu.Unlock()
|
|
require.Equal(t, advisorDeltas, collectedAdvisorDeltas,
|
|
"advisor nested text deltas must stream into the parent tool card")
|
|
|
|
persisted, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
for _, msg := range persisted {
|
|
require.NotContains(t, string(msg.Content.RawMessage), "result_delta",
|
|
"advisor deltas are stream-only and must not be persisted")
|
|
}
|
|
}
|
|
|
|
// TestAdvisorGating_ChildChat guards the second dimension of the advisor
|
|
// eligibility condition: even with advisor enabled, a chat whose
|
|
// ParentChatID is set must not register the advisor tool or receive the
|
|
// advisor guidance block. Without this coverage, a refactor that removes
|
|
// or weakens the !chat.ParentChatID.Valid guard would leak advisor into
|
|
// child chats, and the recursive advisor-inside-subagent cost risk the
|
|
// guard exists to prevent would ship silently.
|
|
//
|
|
// The earlier version of this test drove the gating path through
|
|
// spawn_agent, which made it dependent on subagent wiring that changed
|
|
// repeatedly upstream. This version seeds the parent chat directly in the
|
|
// database and asks the server to create a child chat with a valid
|
|
// ParentChatID, exercising the same gating path with no subagent tooling
|
|
// in the way.
|
|
func TestAdvisorGating_ChildChat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var toolsMu sync.Mutex
|
|
var capturedTools []string
|
|
var capturedMessages []chattest.OpenAIMessage
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
toolsMu.Lock()
|
|
capturedTools = names
|
|
capturedMessages = append([]chattest.OpenAIMessage(nil), req.Messages...)
|
|
toolsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{
|
|
Enabled: true,
|
|
MaxUsesPerRun: 3,
|
|
MaxOutputTokens: 16384,
|
|
})
|
|
|
|
// Seed the parent chat directly in the database so the test server
|
|
// never executes the root turn. That keeps this test focused on the
|
|
// child-chat gating path without depending on subagent wiring.
|
|
parent := dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
LastModelConfigID: model.ID,
|
|
Title: "advisor-root-parent",
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
childChat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "advisor-child",
|
|
ModelConfigID: model.ID,
|
|
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
|
RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("hi"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, childChat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
return got.Status == database.ChatStatusWaiting ||
|
|
got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
toolsMu.Lock()
|
|
tools := append([]string(nil), capturedTools...)
|
|
messages := append([]chattest.OpenAIMessage(nil), capturedMessages...)
|
|
toolsMu.Unlock()
|
|
|
|
require.NotEmpty(t, messages, "expected a streamed LLM request for the child chat")
|
|
require.NotContains(t, tools, chatadvisor.ToolName,
|
|
"advisor tool must not be registered for child chats even when enabled")
|
|
for _, msg := range messages {
|
|
require.NotContains(t, msg.Content, chatadvisor.ParentGuidanceBlock,
|
|
"child chat must not contain advisor guidance")
|
|
}
|
|
}
|
|
|
|
// TestAdvisorGating_PlanMode guards the third dimension of the advisor
|
|
// eligibility condition: plan-mode turns must not register the advisor tool
|
|
// or inject the parent guidance block. Without this test, deleting the
|
|
// !isPlanModeTurn guard would still leave the other two gating tests green
|
|
// even though advisor would now leak into plan mode.
|
|
func TestAdvisorGating_PlanMode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var toolsMu sync.Mutex
|
|
var capturedTools []string
|
|
var capturedMessages []chattest.OpenAIMessage
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
toolsMu.Lock()
|
|
capturedTools = names
|
|
capturedMessages = append([]chattest.OpenAIMessage(nil), req.Messages...)
|
|
toolsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("plan mode reply")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{
|
|
Enabled: true,
|
|
MaxUsesPerRun: 3,
|
|
MaxOutputTokens: 16384,
|
|
})
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "advisor-plan-mode",
|
|
ModelConfigID: model.ID,
|
|
PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("draft a plan"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
return got.Status == database.ChatStatusWaiting ||
|
|
got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
toolsMu.Lock()
|
|
tools := append([]string(nil), capturedTools...)
|
|
messages := append([]chattest.OpenAIMessage(nil), capturedMessages...)
|
|
toolsMu.Unlock()
|
|
|
|
require.NotEmpty(t, messages, "expected a streamed LLM request")
|
|
require.NotContains(t, tools, "advisor",
|
|
"plan-mode turns must not register the advisor tool even when enabled")
|
|
for _, msg := range messages {
|
|
require.NotContains(t, msg.Content, chatadvisor.ParentGuidanceBlock,
|
|
"plan-mode turns must not inject advisor guidance")
|
|
}
|
|
}
|
|
|
|
// TestAdvisorGating_ExploreSubagent guards the fourth dimension of the
|
|
// advisor eligibility condition: Explore chats (root or subagent) run
|
|
// under allowedExploreToolNames, whose policy does not include advisor,
|
|
// so the runtime must not register the advisor tool or inject the
|
|
// parent guidance block there. Without this test, deleting the
|
|
// !isExploreSubagent guard would leave the other gating tests green
|
|
// while leaking advisor into explore chats.
|
|
func TestAdvisorGating_ExploreSubagent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var toolsMu sync.Mutex
|
|
var capturedTools []string
|
|
var capturedMessages []chattest.OpenAIMessage
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
toolsMu.Lock()
|
|
capturedTools = names
|
|
capturedMessages = append([]chattest.OpenAIMessage(nil), req.Messages...)
|
|
toolsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("explore reply")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{
|
|
Enabled: true,
|
|
MaxUsesPerRun: 3,
|
|
MaxOutputTokens: 16384,
|
|
})
|
|
server := newActiveTestServer(t, db, ps)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "advisor-explore",
|
|
ModelConfigID: model.ID,
|
|
ChatMode: database.NullChatMode{
|
|
ChatMode: database.ChatModeExplore,
|
|
Valid: true,
|
|
},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("inspect the codebase"),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
return got.Status == database.ChatStatusWaiting ||
|
|
got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
toolsMu.Lock()
|
|
tools := append([]string(nil), capturedTools...)
|
|
messages := append([]chattest.OpenAIMessage(nil), capturedMessages...)
|
|
toolsMu.Unlock()
|
|
|
|
require.NotEmpty(t, messages, "expected a streamed LLM request")
|
|
require.NotContains(t, tools, chatadvisor.ToolName,
|
|
"explore chats must not register the advisor tool even when enabled")
|
|
for _, msg := range messages {
|
|
require.NotContains(t, msg.Content, chatadvisor.ParentGuidanceBlock,
|
|
"explore chats must not inject advisor guidance")
|
|
}
|
|
}
|
|
|
|
// TestAdvisorChainMode_SnapshotKeepsFullHistory exercises the advisor
|
|
// runtime together with chain mode and asserts the snapshot captured for
|
|
// the nested advisor call retains the full pre-chain prompt. Chain mode
|
|
// otherwise strips assistant and tool turns from the prompt the outer
|
|
// loop sees, so a regression that moves setAdvisorPromptSnapshot behind
|
|
// filterPromptForChainMode, or drops the !chainModeActive guards in
|
|
// PrepareMessages, would leak the filtered view into the advisor's
|
|
// nested call. The advisor would then only see the trailing user
|
|
// message, losing the context the outer model had been building on.
|
|
func TestAdvisorChainMode_SnapshotKeepsFullHistory(t *testing.T) {
|
|
t.Parallel()
|
|
// TODO(CODAGT-353): Re-enable this test after the chatd notification flow
|
|
// refactor gives workers enough causal information to distinguish stale
|
|
// control NOTIFY messages from real interrupts. The current design reuses
|
|
// the same status notification shape for wake-only and interrupt intents,
|
|
// so a stale NOTIFY can cancel a new processChat run.
|
|
t.Skip("skipped until chatd notification flow refactor handles stale control notifications")
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
const (
|
|
turn1User = "help me refactor this module"
|
|
turn1Reply = "happy to help, tell me more"
|
|
turn1RespID = "resp_turn1_advisor_chain"
|
|
turn2User = "follow up question"
|
|
advisorReply = "narrow the scope to one module"
|
|
finalReply = "acknowledged"
|
|
)
|
|
|
|
var (
|
|
requestsMu sync.Mutex
|
|
requests []recordedOpenAIRequest
|
|
advisorRequestRaw []byte
|
|
advisorCallSeen atomic.Bool
|
|
)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
// The advisor's nested call runs with no tools (MaxSteps=1,
|
|
// empty tool set). Parent calls always carry the chat's tool
|
|
// set, which includes the advisor tool.
|
|
isAdvisorNested := len(req.Tools) == 0
|
|
|
|
requestsMu.Lock()
|
|
requests = append(requests, recordOpenAIRequest(req))
|
|
if isAdvisorNested {
|
|
advisorRequestRaw = append([]byte(nil), req.RawBody...)
|
|
advisorCallSeen.Store(true)
|
|
}
|
|
requestsMu.Unlock()
|
|
|
|
if isAdvisorNested {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks(advisorReply)...,
|
|
)
|
|
}
|
|
|
|
// Turn 1 parent request: no previous_response_id yet, so chain
|
|
// mode cannot activate. Respond with a plain text reply and
|
|
// tag the stored response id so turn 2 can chain off it.
|
|
if req.PreviousResponseID == nil {
|
|
resp := chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks(turn1Reply)...,
|
|
)
|
|
resp.ResponseID = turn1RespID
|
|
return resp
|
|
}
|
|
|
|
// Turn 2 parent: chain mode is active. On the first pass call
|
|
// advisor; on the continuation after the tool result arrives,
|
|
// close out with a final text reply.
|
|
var hasAdvisorResult bool
|
|
for _, m := range req.Messages {
|
|
if m.Role == "tool" && strings.Contains(m.Content, advisorReply) {
|
|
hasAdvisorResult = true
|
|
break
|
|
}
|
|
}
|
|
if !hasAdvisorResult {
|
|
return chattest.OpenAIStreamingResponse(chattest.OpenAIToolCallChunk(
|
|
"advisor",
|
|
`{"question":"should I keep going?"}`,
|
|
))
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks(finalReply)...,
|
|
)
|
|
})
|
|
|
|
user, org, _ := seedChatDependenciesWithProvider(t, db, "openai", openAIURL)
|
|
storeEnabled := true
|
|
// The OpenAI Responses API is the only provider code path where
|
|
// chain mode activates. Store=true is the switch that routes this
|
|
// provider/model through the Responses API and lets
|
|
// IsResponsesStoreEnabled return true.
|
|
responsesModel := insertChatModelConfigWithCallConfig(
|
|
t, db, user.ID, "openai", "gpt-4o",
|
|
codersdk.ChatModelCallConfig{
|
|
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
|
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
|
Store: &storeEnabled,
|
|
},
|
|
},
|
|
},
|
|
)
|
|
seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{
|
|
Enabled: true,
|
|
MaxUsesPerRun: 3,
|
|
MaxOutputTokens: 16384,
|
|
})
|
|
server := newOpenAIResponsesTestServer(t, db, ps)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "advisor-chain-mode",
|
|
ModelConfigID: responsesModel.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText(turn1User),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Turn 1 must settle before turn 2 starts so the assistant row
|
|
// with ProviderResponseID is visible to resolveChainMode.
|
|
waitForChatProcessed(ctx, t, db, chat.ID, server)
|
|
turn1Chat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, turn1Chat.Status,
|
|
"turn 1 must complete before turn 2 can be sent; last_error=%q", chatLastErrorMessage(turn1Chat.LastError))
|
|
|
|
_, err = server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
CreatedBy: user.ID,
|
|
Content: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText(turn2User),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
if !advisorCallSeen.Load() {
|
|
return false
|
|
}
|
|
got, getErr := db.GetChatByID(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
return got.Status == database.ChatStatusWaiting ||
|
|
got.Status == database.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
requestsMu.Lock()
|
|
gotAdvisorBody := append([]byte(nil), advisorRequestRaw...)
|
|
gotRequests := append([]recordedOpenAIRequest(nil), requests...)
|
|
requestsMu.Unlock()
|
|
|
|
// Chain mode must have actually fired on turn 2, otherwise this
|
|
// test degenerates to TestAdvisorHappyPath_RootChat.
|
|
var chainModeActivated bool
|
|
for _, r := range gotRequests {
|
|
if r.PreviousResponseID != nil && *r.PreviousResponseID == turn1RespID {
|
|
chainModeActivated = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, chainModeActivated,
|
|
"turn 2 parent request must carry previous_response_id; without it this test does not exercise chain mode")
|
|
|
|
require.True(t, advisorCallSeen.Load(),
|
|
"the nested advisor call must execute under chain mode")
|
|
require.NotEmpty(t, gotAdvisorBody,
|
|
"advisor call must receive a non-empty request body")
|
|
|
|
// The core assertion: the advisor snapshot must retain turn 1
|
|
// context. Chain mode filtering strips assistant and tool turns
|
|
// from the prompt the outer loop sees, so if that filtered view
|
|
// leaked into the snapshot the advisor would only see turn 2's
|
|
// trailing user message. The advisor's nested call goes through
|
|
// the OpenAI Responses API, which encodes its prompt in the
|
|
// "input" field rather than "messages", so we inspect the raw
|
|
// request body for both turn-1 substrings.
|
|
require.Contains(t, string(gotAdvisorBody), turn1User,
|
|
"advisor snapshot must retain the turn 1 user message even when chain mode is active")
|
|
require.Contains(t, string(gotAdvisorBody), turn1Reply,
|
|
"advisor snapshot must retain the turn 1 assistant message even when chain mode is active")
|
|
}
|
|
|
|
func seedAdvisorConfig(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
cfg codersdk.AdvisorConfig,
|
|
) {
|
|
t.Helper()
|
|
|
|
data, err := json.Marshal(cfg)
|
|
require.NoError(t, err)
|
|
err = db.UpsertChatAdvisorConfig(
|
|
dbauthz.AsSystemRestricted(ctx),
|
|
string(data),
|
|
)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// TestPromoteQueuedWhileRunning guards against the data-loss
|
|
// failure mode: promoting on a streaming chat must preserve
|
|
// partial assistant output by deferring the user-message insert
|
|
// to the worker's auto-promote.
|
|
func TestPromoteQueuedWhileRunning(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
streamStarted := make(chan struct{})
|
|
streamCanceled := make(chan struct{})
|
|
var streamCallCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("running-promote")
|
|
}
|
|
if streamCallCount.Add(1) > 1 {
|
|
// Subsequent calls are the resumed run; let it settle.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("resumed after promotion")...,
|
|
)
|
|
}
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("partial-running-output")[0]
|
|
select {
|
|
case <-streamStarted:
|
|
default:
|
|
close(streamStarted)
|
|
}
|
|
<-req.Context().Done()
|
|
select {
|
|
case <-streamCanceled:
|
|
default:
|
|
close(streamCanceled)
|
|
}
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "promote-while-running",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
|
}, testutil.IntervalFast)
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
select {
|
|
case <-streamStarted:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("promote me")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, queuedResult.Queued)
|
|
require.NotNil(t, queuedResult.QueuedMessage)
|
|
|
|
promoteResult, err := server.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queuedResult.QueuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
// Deferred promotion: no synchronous user message.
|
|
require.Zero(t, promoteResult.PromotedMessage.ID)
|
|
|
|
// Worker observes waiting and cancels.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
select {
|
|
case <-streamCanceled:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
// Partial assistant output is preserved (not lost as it was
|
|
// pre-fix) and precedes the promoted user message. Poll on the
|
|
// messages themselves: the status passes through Waiting
|
|
// transiently before finishActiveChat's external-Waiting case
|
|
// promotes the queued message and flips the chat to Pending.
|
|
// Both messages being persisted implies cleanup completed.
|
|
var (
|
|
partialAssistantID int64
|
|
promotedUserID int64
|
|
)
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if err != nil {
|
|
return false
|
|
}
|
|
var (
|
|
assistantID int64
|
|
userID int64
|
|
)
|
|
for _, msg := range messages {
|
|
switch msg.Role {
|
|
case database.ChatMessageRoleAssistant:
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeText && strings.Contains(part.Text, "partial-running-output") {
|
|
assistantID = msg.ID
|
|
}
|
|
}
|
|
case database.ChatMessageRoleUser:
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeText && strings.Contains(part.Text, "promote me") {
|
|
userID = msg.ID
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if assistantID == 0 || userID == 0 {
|
|
return false
|
|
}
|
|
partialAssistantID = assistantID
|
|
promotedUserID = userID
|
|
return true
|
|
}, testutil.IntervalFast)
|
|
require.Less(t, partialAssistantID, promotedUserID,
|
|
"promoted user message must follow the persisted partial output")
|
|
}
|
|
|
|
// TestPromoteQueuedWhileRunningRespectsMessageOrder guards
|
|
// against losing or reshuffling sibling queued messages when one
|
|
// is promoted out-of-order.
|
|
func TestPromoteQueuedWhileRunningRespectsMessageOrder(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
streamStarted := make(chan struct{})
|
|
var streamCallCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("running-promote-order")
|
|
}
|
|
if streamCallCount.Add(1) > 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("resumed")...,
|
|
)
|
|
}
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("partial")[0]
|
|
select {
|
|
case <-streamStarted:
|
|
default:
|
|
close(streamStarted)
|
|
}
|
|
<-req.Context().Done()
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
})
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
Title: "promote-while-running-order",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
|
}, testutil.IntervalFast)
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
select {
|
|
case <-streamStarted:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
queueA, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("A")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, queueA.QueuedMessage)
|
|
queueB, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("B")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, queueB.QueuedMessage)
|
|
queueC, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("C")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, queueC.QueuedMessage)
|
|
|
|
promoteResult, err := server.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queueB.QueuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Zero(t, promoteResult.PromotedMessage.ID,
|
|
"running-case promotion is deferred to auto-promote")
|
|
|
|
// Wait for the worker to drain all three queued messages into
|
|
// chat history, then verify ordering. Reading queue state right
|
|
// after PromoteQueued races the worker's auto-promote pipeline
|
|
// (TOCTOU), so we wait for the settled outcome instead.
|
|
var posB, posA, posC int
|
|
var foundA, foundB, foundC bool
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
messages, getErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
foundA, foundB, foundC = false, false, false
|
|
for i, msg := range messages {
|
|
if msg.Role != database.ChatMessageRoleUser {
|
|
continue
|
|
}
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
return false
|
|
}
|
|
for _, part := range parts {
|
|
if part.Type != codersdk.ChatMessagePartTypeText {
|
|
continue
|
|
}
|
|
// Only A, B, C are tracked; other user messages are ignored.
|
|
switch part.Text {
|
|
case "A":
|
|
posA = i
|
|
foundA = true
|
|
case "B":
|
|
posB = i
|
|
foundB = true
|
|
case "C":
|
|
posC = i
|
|
foundC = true
|
|
}
|
|
}
|
|
}
|
|
return foundA && foundB && foundC
|
|
}, testutil.IntervalFast,
|
|
"queued messages not found in chat history: foundA=%v, foundB=%v, foundC=%v", foundA, foundB, foundC)
|
|
|
|
// PromoteQueued reorders the queue to [B, A, C], so the worker
|
|
// processes B first, then A, then C. Verify that ordering.
|
|
require.Less(t, posB, posA,
|
|
"promoted message B must appear before A in history")
|
|
require.Less(t, posA, posC,
|
|
"non-promoted messages must preserve relative order (A before C)")
|
|
}
|
|
|
|
// TestFinishActiveChatExternalWaitingInsertsSyntheticResults
|
|
// asserts the cleanup TX inserts synthetic tool-result rows when
|
|
// PromoteQueued's deferred path set Status=Waiting while the
|
|
// worker concluded with RequiresAction. Without it, the next
|
|
// chatloop run would feed the LLM an assistant turn with
|
|
// unresolved tool_call parts and the API would reject it.
|
|
func TestFinishActiveChatExternalWaitingInsertsSyntheticResults(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
server := newActiveTestServer(t, db, ps)
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "my_dynamic_tool",
|
|
Description: "A test dynamic tool.",
|
|
InputSchema: mcpgo.ToolInputSchema{
|
|
Type: "object",
|
|
Properties: map[string]any{},
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
Title: "external-waiting-stops-dead-guard",
|
|
LastModelConfigID: model.ID,
|
|
DynamicTools: nullRawMessage(dynamicToolsJSON),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Seed a user message and an assistant message with an
|
|
// unresolved dynamic tool call. This mirrors what the worker
|
|
// would have persisted before the deferred promote arrived.
|
|
insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "user input")
|
|
|
|
pendingCallID := "call_pending_dynamic"
|
|
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeToolCall,
|
|
ToolCallID: pendingCallID,
|
|
ToolName: "my_dynamic_tool",
|
|
Args: json.RawMessage(`{}`),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chat.ID,
|
|
CreatedBy: []uuid.UUID{uuid.Nil},
|
|
ModelConfigID: []uuid.UUID{model.ID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
|
Content: []string{string(assistantContent.RawMessage)},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
ProviderResponseID: []string{""},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Queue a message and put the chat in the post-promote
|
|
// Waiting state (no worker, queue at front).
|
|
queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("queued-after-promote"),
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent.RawMessage,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Refresh chat with current status (Waiting, no worker).
|
|
latestChat, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Drive the cleanup path with the local-RequiresAction outcome.
|
|
updated, promoted, syntheticToolResults, finishErr := chatd.FinishActiveChatForTest(
|
|
ctx, server, latestChat, database.ChatStatusRequiresAction, "",
|
|
)
|
|
require.NoError(t, finishErr)
|
|
require.NotNil(t, promoted, "queued message must be auto-promoted into history")
|
|
require.Equal(t, database.ChatStatusPending, updated.Status,
|
|
"chat must end Pending so the run loop picks it up")
|
|
require.Len(t, syntheticToolResults, 1,
|
|
"cleanup TX must return the inserted synthetic tool-result row so the post-TX caller can publish it")
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var (
|
|
assistantIdx = -1
|
|
synthToolIdx = -1
|
|
promotedUserIdx = -1
|
|
)
|
|
for i, msg := range messages {
|
|
switch msg.Role {
|
|
case database.ChatMessageRoleAssistant:
|
|
assistantIdx = i
|
|
case database.ChatMessageRoleTool:
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
require.NoError(t, parseErr)
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeToolResult &&
|
|
part.ToolCallID == pendingCallID && part.IsError {
|
|
synthToolIdx = i
|
|
}
|
|
}
|
|
case database.ChatMessageRoleUser:
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
require.NoError(t, parseErr)
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeText &&
|
|
part.Text == "queued-after-promote" {
|
|
promotedUserIdx = i
|
|
}
|
|
}
|
|
}
|
|
}
|
|
require.NotEqual(t, -1, assistantIdx, "assistant tool-call message present")
|
|
require.NotEqual(t, -1, synthToolIdx,
|
|
"synthetic tool result for the unresolved dynamic tool call must be inserted")
|
|
require.NotEqual(t, -1, promotedUserIdx,
|
|
"promoted queued message must be inserted as a user message")
|
|
require.Less(t, assistantIdx, synthToolIdx,
|
|
"synthetic tool result must follow the assistant message")
|
|
require.Less(t, synthToolIdx, promotedUserIdx,
|
|
"promoted user message must follow the synthetic tool result")
|
|
}
|
|
|
|
// TestPromoteQueuedFallsThroughOnStaleHeartbeat asserts a stale
|
|
// heartbeat takes the synchronous path so the chat does not strand
|
|
// in Waiting waiting on a worker that will not return.
|
|
func TestPromoteQueuedFallsThroughOnStaleHeartbeat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
staleAfter := 100 * time.Millisecond
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
InFlightChatStaleAfter: staleAfter,
|
|
})
|
|
t.Cleanup(func() { require.NoError(t, server.Close()) })
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
Title: "stale-heartbeat-promote-fallthrough",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Place the chat in Running with a stale heartbeat. We do not
|
|
// start the server's run loop, so no worker will ever pick this
|
|
// chat up; the test isolates the fall-through decision in
|
|
// PromoteQueued.
|
|
deadWorker := uuid.New()
|
|
staleTime := time.Now().Add(-2 * staleAfter)
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: deadWorker, Valid: true},
|
|
StartedAt: sql.NullTime{Time: staleTime, Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: staleTime, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
queued, err := server.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("promote me")},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, queued.Queued)
|
|
require.NotNil(t, queued.QueuedMessage)
|
|
|
|
result, err := server.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
|
|
ChatID: chat.ID,
|
|
QueuedMessageID: queued.QueuedMessage.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotZero(t, result.PromotedMessage.ID,
|
|
"stale heartbeat must take the synchronous path and insert a user message inline")
|
|
|
|
got, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, got.Status,
|
|
"synchronous promote ends Pending")
|
|
require.False(t, got.WorkerID.Valid,
|
|
"worker_id is cleared by the synchronous promote")
|
|
}
|
|
|
|
// TestRecoverStaleChatsRecoversWaitingWithQueue asserts a Waiting
|
|
// chat with a non-empty queue and stale updated_at gets recovered
|
|
// to Pending, closing the post-promote-stranding hole.
|
|
func TestRecoverStaleChatsRecoversWaitingWithQueue(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
staleAfter := 100 * time.Millisecond
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
InFlightChatStaleAfter: staleAfter,
|
|
})
|
|
t.Cleanup(func() { require.NoError(t, server.Close()) })
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
Title: "stale-waiting-with-queue",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("queued-stranded"),
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent.RawMessage,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
// Backdate updated_at directly so the chat is past the stale
|
|
// threshold without sleeping.
|
|
_, err = rawDB.ExecContext(ctx,
|
|
"UPDATE chats SET updated_at = $1 WHERE id = $2",
|
|
time.Now().Add(-time.Hour), chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
chatd.RecoverStaleChatsForTest(ctx, server)
|
|
|
|
got, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, got.Status,
|
|
"stale-recovery must promote the front-of-queue and set Pending")
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
var foundPromoted bool
|
|
for _, msg := range messages {
|
|
if msg.Role != database.ChatMessageRoleUser {
|
|
continue
|
|
}
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
require.NoError(t, parseErr)
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeText &&
|
|
part.Text == "queued-stranded" {
|
|
foundPromoted = true
|
|
}
|
|
}
|
|
}
|
|
require.True(t, foundPromoted,
|
|
"the front-of-queue message must be promoted into history")
|
|
|
|
remaining, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, remaining,
|
|
"the queue is drained after the recovery promotes its only entry")
|
|
}
|
|
|
|
// TestRecoverStaleChatsWaitingWithUnresolvedToolCallInsertsSyntheticResults
|
|
// asserts stale recovery closes pending dynamic tool calls before
|
|
// promoting, so the recovery path does not stop the chat dead by
|
|
// feeding the LLM unresolved tool_call parts.
|
|
func TestRecoverStaleChatsWaitingWithUnresolvedToolCallInsertsSyntheticResults(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
staleAfter := 100 * time.Millisecond
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
InFlightChatStaleAfter: staleAfter,
|
|
})
|
|
t.Cleanup(func() { require.NoError(t, server.Close()) })
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "my_dynamic_tool",
|
|
Description: "A test dynamic tool.",
|
|
InputSchema: mcpgo.ToolInputSchema{
|
|
Type: "object",
|
|
Properties: map[string]any{},
|
|
},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
Title: "stale-waiting-with-unresolved-tool-call",
|
|
LastModelConfigID: model.ID,
|
|
DynamicTools: nullRawMessage(dynamicToolsJSON),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "please call the tool")
|
|
|
|
pendingCallID := "call_unresolved_dynamic"
|
|
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeToolCall,
|
|
ToolCallID: pendingCallID,
|
|
ToolName: "my_dynamic_tool",
|
|
Args: json.RawMessage(`{}`),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chat.ID,
|
|
CreatedBy: []uuid.UUID{uuid.Nil},
|
|
ModelConfigID: []uuid.UUID{model.ID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
|
Content: []string{string(assistantContent.RawMessage)},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
ProviderResponseID: []string{""},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("queued-after-crash"),
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent.RawMessage,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = rawDB.ExecContext(ctx,
|
|
"UPDATE chats SET updated_at = $1 WHERE id = $2",
|
|
time.Now().Add(-time.Hour), chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
chatd.RecoverStaleChatsForTest(ctx, server)
|
|
|
|
got, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, got.Status)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var (
|
|
assistantIdx = -1
|
|
synthIdx = -1
|
|
promotedUserIdx = -1
|
|
)
|
|
for i, msg := range messages {
|
|
switch msg.Role {
|
|
case database.ChatMessageRoleAssistant:
|
|
assistantIdx = i
|
|
case database.ChatMessageRoleTool:
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
require.NoError(t, parseErr)
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeToolResult &&
|
|
part.ToolCallID == pendingCallID && part.IsError {
|
|
synthIdx = i
|
|
}
|
|
}
|
|
case database.ChatMessageRoleUser:
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
require.NoError(t, parseErr)
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeText &&
|
|
part.Text == "queued-after-crash" {
|
|
promotedUserIdx = i
|
|
}
|
|
}
|
|
}
|
|
}
|
|
require.NotEqual(t, -1, assistantIdx, "assistant tool-call message present")
|
|
require.NotEqual(t, -1, synthIdx,
|
|
"stale recovery must insert synthetic tool result for the unresolved dynamic tool call")
|
|
require.NotEqual(t, -1, promotedUserIdx,
|
|
"queued message must be promoted into history")
|
|
require.Less(t, assistantIdx, synthIdx)
|
|
require.Less(t, synthIdx, promotedUserIdx)
|
|
}
|
|
|
|
// TestInsertSyntheticToolResultsTxSkipsAlreadyHandledCalls asserts
|
|
// the helper skips tool calls already handled (e.g. when a dynamic
|
|
// tool name collides with a built-in the chatloop dispatched).
|
|
// Without dedup the LLM would see two results for the same call ID.
|
|
func TestInsertSyntheticToolResultsTxSkipsAlreadyHandledCalls(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{
|
|
{
|
|
Name: "duplicate_call_tool",
|
|
Description: "Tool whose call already has a result.",
|
|
InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}},
|
|
},
|
|
{
|
|
Name: "still_pending_tool",
|
|
Description: "Tool whose call has no result yet.",
|
|
InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusRequiresAction,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
Title: "synth-results-dedup",
|
|
LastModelConfigID: model.ID,
|
|
DynamicTools: nullRawMessage(dynamicToolsJSON),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "please call both tools")
|
|
|
|
handledCallID := "call_already_handled"
|
|
pendingCallID := "call_still_pending"
|
|
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeToolCall,
|
|
ToolCallID: handledCallID,
|
|
ToolName: "duplicate_call_tool",
|
|
Args: json.RawMessage(`{}`),
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeToolCall,
|
|
ToolCallID: pendingCallID,
|
|
ToolName: "still_pending_tool",
|
|
Args: json.RawMessage(`{}`),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chat.ID,
|
|
CreatedBy: []uuid.UUID{uuid.Nil},
|
|
ModelConfigID: []uuid.UUID{model.ID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
|
Content: []string{string(assistantContent.RawMessage)},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
ProviderResponseID: []string{""},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Pre-insert a tool-result for the handled call ID. This
|
|
// simulates the chatloop having dispatched the colliding
|
|
// dynamic tool name as a built-in.
|
|
handledResultContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeToolResult,
|
|
ToolCallID: handledCallID,
|
|
ToolName: "duplicate_call_tool",
|
|
Result: json.RawMessage(`"already done"`),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chat.ID,
|
|
CreatedBy: []uuid.UUID{uuid.Nil},
|
|
ModelConfigID: []uuid.UUID{model.ID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleTool},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
|
Content: []string{string(handledResultContent.RawMessage)},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
ProviderResponseID: []string{""},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatRow, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
_, err = chatd.InsertSyntheticToolResultsTxForTest(
|
|
ctx, db, chatRow, "synth reason",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var (
|
|
handledCount int
|
|
pendingCount int
|
|
syntheticForPending bool
|
|
)
|
|
for _, msg := range messages {
|
|
if msg.Role != database.ChatMessageRoleTool {
|
|
continue
|
|
}
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
require.NoError(t, parseErr)
|
|
for _, part := range parts {
|
|
if part.Type != codersdk.ChatMessagePartTypeToolResult {
|
|
continue
|
|
}
|
|
switch part.ToolCallID {
|
|
case handledCallID:
|
|
handledCount++
|
|
case pendingCallID:
|
|
pendingCount++
|
|
if part.IsError {
|
|
syntheticForPending = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
require.Equal(t, 1, handledCount,
|
|
"handled call must keep exactly one tool result")
|
|
require.Equal(t, 1, pendingCount,
|
|
"pending call must get exactly one synthetic tool result")
|
|
require.True(t, syntheticForPending,
|
|
"the new tool result for the pending call must be marked IsError")
|
|
}
|
|
|
|
// nullRawMessage wraps raw JSON in a NullRawMessage. An empty input
|
|
// becomes the zero value (Valid=false).
|
|
func nullRawMessage(raw []byte) pqtype.NullRawMessage {
|
|
if len(raw) == 0 {
|
|
return pqtype.NullRawMessage{}
|
|
}
|
|
return pqtype.NullRawMessage{RawMessage: raw, Valid: true}
|
|
}
|
|
|
|
// TestInsertSyntheticToolResultsTxReturnsNilWhenNoAssistantMessage
|
|
// asserts the helper short-circuits cleanly when no assistant
|
|
// message exists yet, so a deferred promote racing a worker that
|
|
// fails before any persist does not roll back the cleanup TX.
|
|
func TestInsertSyntheticToolResultsTxReturnsNilWhenNoAssistantMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "my_dynamic_tool",
|
|
Description: "A test dynamic tool.",
|
|
InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
Title: "no-assistant-message",
|
|
LastModelConfigID: model.ID,
|
|
DynamicTools: nullRawMessage(dynamicToolsJSON),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// No assistant message persisted. The helper must return nil so
|
|
// the caller's transaction can still advance.
|
|
_, err = chatd.InsertSyntheticToolResultsTxForTest(
|
|
ctx, db, chat, "no assistant",
|
|
)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// TestRecoverStaleChatsWaitingPropagatesSynthError asserts stale
|
|
// recovery rolls back when synth-result insertion fails, leaving
|
|
// the chat Waiting for the next tick instead of promoting on top
|
|
// of incomplete history.
|
|
func TestRecoverStaleChatsWaitingPropagatesSynthError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
staleAfter := 100 * time.Millisecond
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
InFlightChatStaleAfter: staleAfter,
|
|
})
|
|
t.Cleanup(func() { require.NoError(t, server.Close()) })
|
|
|
|
user, org, model := seedChatDependencies(t, db)
|
|
|
|
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
|
Name: "my_dynamic_tool",
|
|
Description: "A test dynamic tool.",
|
|
InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}},
|
|
}})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
Title: "stale-waiting-synth-error",
|
|
LastModelConfigID: model.ID,
|
|
DynamicTools: nullRawMessage(dynamicToolsJSON),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "user input")
|
|
|
|
// Inject a synth-results error via an unsupported
|
|
// ContentVersion: the row is valid JSON so the insert
|
|
// succeeds, but chatprompt.ParseContent rejects it inside the
|
|
// helper. Brittle if a future migration adds a content_version
|
|
// CHECK constraint; switch to a mock store at that point.
|
|
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chat.ID,
|
|
CreatedBy: []uuid.UUID{uuid.Nil},
|
|
ModelConfigID: []uuid.UUID{model.ID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
|
ContentVersion: []int16{99},
|
|
Content: []string{`{}`},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
ProviderResponseID: []string{""},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("queued-not-promoted-on-synth-error"),
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: queuedContent.RawMessage,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = rawDB.ExecContext(ctx,
|
|
"UPDATE chats SET updated_at = $1 WHERE id = $2",
|
|
time.Now().Add(-time.Hour), chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
chatd.RecoverStaleChatsForTest(ctx, server)
|
|
|
|
got, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, got.Status,
|
|
"recovery must leave the chat in Waiting when synth-results fails so the next tick retries")
|
|
|
|
// The queued message must still be in the queue, not promoted.
|
|
remaining, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, remaining, 1,
|
|
"queued message must not be promoted when synth-results fails")
|
|
|
|
// No promoted user message should appear in history.
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
for _, msg := range messages {
|
|
if msg.Role != database.ChatMessageRoleUser {
|
|
continue
|
|
}
|
|
parts, parseErr := chatprompt.ParseContent(msg)
|
|
if parseErr != nil {
|
|
continue
|
|
}
|
|
for _, part := range parts {
|
|
require.NotEqual(t, "queued-not-promoted-on-synth-error", part.Text,
|
|
"queued message must not be promoted when synth-results fails")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Regression for the cold-start race: chatd must wait long enough
|
|
// for ListMCPTools to return after the agent's MCP reload settles.
|
|
func TestRunChat_WorkspaceMCPDiscoveryWaitsForSlowAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
const slowAgentMCPListDelay = 7 * time.Second
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var (
|
|
requestsMu sync.Mutex
|
|
requests []recordedOpenAIRequest
|
|
)
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
requestsMu.Lock()
|
|
requests = append(requests, recordOpenAIRequest(req))
|
|
requestsMu.Unlock()
|
|
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
|
|
|
workspaceToolName := "workspace-slow-mcp__echo"
|
|
workspaceToolsResp := workspacesdk.ListMCPToolsResponse{
|
|
Tools: []workspacesdk.MCPToolInfo{{
|
|
ServerName: "workspace-slow-mcp",
|
|
Name: workspaceToolName,
|
|
Description: "Slow workspace echo tool",
|
|
Schema: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
}},
|
|
}
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
|
// Honor ctx so the goroutine exits if chatd cancels.
|
|
mockConn.EXPECT().ListMCPTools(gomock.Any()).
|
|
DoAndReturn(func(ctx context.Context) (workspacesdk.ListMCPToolsResponse, error) {
|
|
select {
|
|
case <-time.After(slowAgentMCPListDelay):
|
|
return workspaceToolsResp, nil
|
|
case <-ctx.Done():
|
|
return workspacesdk.ListMCPToolsResponse{}, ctx.Err()
|
|
}
|
|
}).AnyTimes()
|
|
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{}, nil).AnyTimes()
|
|
mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes()
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
})
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "workspace-mcp-slow-agent",
|
|
ModelConfigID: model.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("List the workspace MCP tools."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat failed", "last_error=%q",
|
|
chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
|
|
|
|
requestsMu.Lock()
|
|
recorded := append([]recordedOpenAIRequest(nil), requests...)
|
|
requestsMu.Unlock()
|
|
require.Len(t, recorded, 1, "expected exactly one streamed model call")
|
|
require.Contains(t, recorded[0].Tools, workspaceToolName,
|
|
"workspace MCP tool should reach the LLM once chatd's discovery "+
|
|
"timeout exceeds the agent's MCP reload time")
|
|
}
|
|
|
|
// TestRunChat_WorkspaceMCPDiscoveryAfterMidTurnCreateWorkspace guards the
|
|
// regression where chats that bound their workspace mid-turn (via
|
|
// create_workspace) never saw workspace MCP tools on the same turn. The
|
|
// chatloop tool list was frozen at the top of the turn, so the first
|
|
// post-create_workspace step had no workspace MCP tools and the model
|
|
// fell back to bash. See PrepareTools wiring in runChat.
|
|
func TestRunChat_WorkspaceMCPDiscoveryAfterMidTurnCreateWorkspace(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var (
|
|
requestsMu sync.Mutex
|
|
requests []recordedOpenAIRequest
|
|
)
|
|
|
|
workspaceToolName := "workspace-midturn-mcp__echo"
|
|
workspaceCreateToolArgsJSON := ""
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
requestsMu.Lock()
|
|
requests = append(requests, recordOpenAIRequest(req))
|
|
callIdx := len(requests)
|
|
requestsMu.Unlock()
|
|
|
|
if callIdx == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("create_workspace", workspaceCreateToolArgsJSON),
|
|
)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
// Seed a workspace+agent for create_workspace to bind to.
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user.ID,
|
|
OrganizationID: org.ID,
|
|
ActiveVersionID: tv.ID,
|
|
})
|
|
workspaceCreateToolArgsJSON = fmt.Sprintf(`{"template_id":%q}`, tpl.ID.String())
|
|
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
TemplateID: tpl.ID,
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
InitiatorID: user.ID,
|
|
OrganizationID: org.ID,
|
|
CompletedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
})
|
|
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
TemplateVersionID: tv.ID,
|
|
WorkspaceID: ws.ID,
|
|
JobID: pj.ID,
|
|
})
|
|
res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
Transition: database.WorkspaceTransitionStart,
|
|
JobID: pj.ID,
|
|
})
|
|
now := dbtime.Now()
|
|
dbAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: res.ID,
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
|
StartedAt: sql.NullTime{Time: now, Valid: true},
|
|
ReadyAt: sql.NullTime{Time: now, Valid: true},
|
|
FirstConnectedAt: sql.NullTime{Time: now, Valid: true},
|
|
LastConnectedAt: sql.NullTime{Time: now, Valid: true},
|
|
})
|
|
|
|
workspaceToolsResp := workspacesdk.ListMCPToolsResponse{
|
|
Tools: []workspacesdk.MCPToolInfo{{
|
|
ServerName: "workspace-midturn-mcp",
|
|
Name: workspaceToolName,
|
|
Description: "workspace echo tool",
|
|
Schema: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
}},
|
|
}
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
|
mockConn.EXPECT().ListMCPTools(gomock.Any()).
|
|
Return(workspaceToolsResp, nil).AnyTimes()
|
|
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{}, nil).AnyTimes()
|
|
mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes()
|
|
mockConn.EXPECT().AwaitReachable(gomock.Any()).Return(true).AnyTimes()
|
|
|
|
createFn := func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
|
return codersdk.Workspace{
|
|
ID: ws.ID,
|
|
Name: req.Name,
|
|
OwnerName: user.Username,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
LatestBuild: codersdk.WorkspaceBuild{
|
|
ID: build.ID,
|
|
Status: codersdk.WorkspaceStatusRunning,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
cfg.CreateWorkspace = createFn
|
|
})
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "workspace-mcp-midturn",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Create a workspace and call the workspace MCP tool."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat failed", "last_error=%q",
|
|
chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
|
|
|
|
requestsMu.Lock()
|
|
recorded := append([]recordedOpenAIRequest(nil), requests...)
|
|
requestsMu.Unlock()
|
|
require.GreaterOrEqual(t, len(recorded), 2,
|
|
"expected at least two streamed model calls (create_workspace + follow-up)")
|
|
require.NotContains(t, recorded[0].Tools, workspaceToolName,
|
|
"first call should not advertise workspace MCP tools because the chat has no workspace yet")
|
|
require.Contains(t, recorded[1].Tools, workspaceToolName,
|
|
"second call (after create_workspace) must advertise the workspace MCP tool: "+
|
|
"this is the fix for mid-turn workspace MCP discovery")
|
|
}
|
|
|
|
// TestRunChat_PrepareToolsRetriesAfterEmptyDiscovery guards the
|
|
// regression on the workspaceMCPDiscovered flag flip: the prior
|
|
// implementation set the flag to true before calling
|
|
// discoverWorkspaceMCPTools, so a single empty result permanently
|
|
// blocked retries within the turn. The fix sets the flag to true
|
|
// only after a non-empty discovery, so subsequent PrepareTools
|
|
// invocations keep retrying until tools appear.
|
|
//
|
|
// Scenario: create_workspace binds a workspace mid-turn. The first
|
|
// few ListMCPTools calls return empty (simulating the agent's MCP
|
|
// Connect still racing with agent startup); a later call returns
|
|
// the workspace MCP tool. The chat takes multiple steps before
|
|
// finishing, and we assert that one of the post-create_workspace
|
|
// streamed model calls advertises the workspace tool.
|
|
func TestRunChat_PrepareToolsRetriesAfterEmptyDiscovery(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var (
|
|
requestsMu sync.Mutex
|
|
requests []recordedOpenAIRequest
|
|
)
|
|
|
|
workspaceToolName := "workspace-empty-retry-mcp__echo"
|
|
workspaceCreateToolArgsJSON := ""
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
|
|
requestsMu.Lock()
|
|
requests = append(requests, recordOpenAIRequest(req))
|
|
callIdx := len(requests)
|
|
requestsMu.Unlock()
|
|
|
|
// Step 1: trigger create_workspace.
|
|
if callIdx == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("create_workspace", workspaceCreateToolArgsJSON),
|
|
)
|
|
}
|
|
// Step 2..N-1: emit empty text to keep the chatloop running so
|
|
// PrepareTools fires on each step. The chatloop ends a turn
|
|
// when the model returns a non-empty assistant message with no
|
|
// tool calls; an empty text chunk would terminate the turn, so
|
|
// we attach a dummy tool call to force another step. Use the
|
|
// LS tool because it exists for all workspaces and is cheap.
|
|
if callIdx < 6 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("ls", `{"path":"/tmp"}`),
|
|
)
|
|
}
|
|
// Final step: finish the chat.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL)
|
|
|
|
// Seed a workspace+agent for create_workspace to bind to.
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user.ID,
|
|
OrganizationID: org.ID,
|
|
ActiveVersionID: tv.ID,
|
|
})
|
|
workspaceCreateToolArgsJSON = fmt.Sprintf(`{"template_id":%q}`, tpl.ID.String())
|
|
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
TemplateID: tpl.ID,
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
InitiatorID: user.ID,
|
|
OrganizationID: org.ID,
|
|
CompletedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
})
|
|
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
TemplateVersionID: tv.ID,
|
|
WorkspaceID: ws.ID,
|
|
JobID: pj.ID,
|
|
})
|
|
res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
Transition: database.WorkspaceTransitionStart,
|
|
JobID: pj.ID,
|
|
})
|
|
now := dbtime.Now()
|
|
dbAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: res.ID,
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
|
StartedAt: sql.NullTime{Time: now, Valid: true},
|
|
ReadyAt: sql.NullTime{Time: now, Valid: true},
|
|
FirstConnectedAt: sql.NullTime{Time: now, Valid: true},
|
|
LastConnectedAt: sql.NullTime{Time: now, Valid: true},
|
|
})
|
|
|
|
workspaceToolsResp := workspacesdk.ListMCPToolsResponse{
|
|
Tools: []workspacesdk.MCPToolInfo{{
|
|
ServerName: "workspace-empty-retry-mcp",
|
|
Name: workspaceToolName,
|
|
Description: "workspace echo tool",
|
|
Schema: map[string]any{
|
|
"input": map[string]any{"type": "string"},
|
|
},
|
|
Required: []string{"input"},
|
|
}},
|
|
}
|
|
|
|
// First two ListMCPTools calls return empty (no error). One is the
|
|
// primer goroutine's only attempt before its retry timer fires;
|
|
// the other is PrepareTools on the first post-create_workspace
|
|
// step. The third and later calls return the workspace tool. The
|
|
// assertion below requires that a post-create_workspace step
|
|
// eventually advertises the tool, which can only happen if the
|
|
// PrepareTools callback retries discovery on subsequent steps.
|
|
var listCalls atomic.Int32
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
mockConn.EXPECT().ContextConfig(gomock.Any()).
|
|
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
|
|
mockConn.EXPECT().ListMCPTools(gomock.Any()).DoAndReturn(
|
|
func(context.Context) (workspacesdk.ListMCPToolsResponse, error) {
|
|
n := listCalls.Add(1)
|
|
if n <= 2 {
|
|
return workspacesdk.ListMCPToolsResponse{}, nil
|
|
}
|
|
return workspaceToolsResp, nil
|
|
},
|
|
).AnyTimes()
|
|
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(workspacesdk.LSResponse{}, nil).AnyTimes()
|
|
mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
|
Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes()
|
|
mockConn.EXPECT().AwaitReachable(gomock.Any()).Return(true).AnyTimes()
|
|
|
|
createFn := func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
|
return codersdk.Workspace{
|
|
ID: ws.ID,
|
|
Name: req.Name,
|
|
OwnerName: user.Username,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
LatestBuild: codersdk.WorkspaceBuild{
|
|
ID: build.ID,
|
|
Status: codersdk.WorkspaceStatusRunning,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
|
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, dbAgent.ID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
cfg.CreateWorkspace = createFn
|
|
})
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Title: "workspace-mcp-empty-retry",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText("Create a workspace and call the workspace MCP tool."),
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
|
if chatResult.Status == database.ChatStatusError {
|
|
require.FailNowf(t, "chat failed", "last_error=%q",
|
|
chatLastErrorMessage(chatResult.LastError))
|
|
}
|
|
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
|
|
|
|
requestsMu.Lock()
|
|
recorded := append([]recordedOpenAIRequest(nil), requests...)
|
|
requestsMu.Unlock()
|
|
require.GreaterOrEqual(t, len(recorded), 3,
|
|
"expected at least three streamed model calls; chat must run past the empty discovery")
|
|
|
|
// The first call has no workspace yet; the second call is the
|
|
// first post-create_workspace step which sees an empty
|
|
// ListMCPTools result. By the third (or later) call PrepareTools
|
|
// must have retried discovery, so at least one post-step request
|
|
// must advertise the workspace tool. Without the
|
|
// workspaceMCPDiscovered flag-flip fix the flag would have been
|
|
// set true on the failed first attempt and no subsequent step
|
|
// would have re-attempted discovery.
|
|
sawWorkspaceTool := false
|
|
for i := 2; i < len(recorded); i++ {
|
|
if slices.Contains(recorded[i].Tools, workspaceToolName) {
|
|
sawWorkspaceTool = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, sawWorkspaceTool,
|
|
"PrepareTools must retry workspace MCP discovery on subsequent "+
|
|
"steps; without the fix the first empty result would "+
|
|
"permanently block retries within the turn")
|
|
}
|