mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
1993 lines
64 KiB
Go
1993 lines
64 KiB
Go
package chatd_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/agent/agenttest"
|
|
"github.com/coder/coder/v2/coderd/chatd"
|
|
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
|
"github.com/coder/coder/v2/coderd/util/slice"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
|
"github.com/coder/coder/v2/provisioner/echo"
|
|
proto "github.com/coder/coder/v2/provisionersdk/proto"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replicaA := newTestServer(t, db, ps, uuid.New())
|
|
replicaB := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replicaA.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "interrupt-me",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
runningWorker := uuid.New()
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: runningWorker, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
updated := replicaA.InterruptChat(ctx, chat)
|
|
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
|
require.False(t, updated.WorkerID.Valid)
|
|
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case event := <-events:
|
|
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
|
return false
|
|
}
|
|
return event.Status.Status == codersdk.ChatStatusWaiting
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
deploymentValues := coderdtest.DeploymentValues(t)
|
|
deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)}
|
|
client := coderdtest.New(t, &coderdtest.Options{
|
|
DeploymentValues: deploymentValues,
|
|
IncludeProvisionerDaemon: true,
|
|
})
|
|
user := coderdtest.CreateFirstUser(t, client)
|
|
|
|
agentToken := uuid.NewString()
|
|
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
|
Parse: echo.ParseComplete,
|
|
ProvisionPlan: echo.PlanComplete,
|
|
ProvisionApply: echo.ApplyComplete,
|
|
ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken),
|
|
})
|
|
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
|
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
|
|
|
_ = agenttest.New(t, client.URL, agentToken)
|
|
|
|
// Track tools sent in LLM requests. The first call is for the
|
|
// root chat which spawns a subagent; the second call is for the
|
|
// subagent itself.
|
|
var toolsMu sync.Mutex
|
|
toolsByCall := make([][]string, 0, 2)
|
|
|
|
var callCount atomic.Int32
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("ok")
|
|
}
|
|
|
|
names := make([]string, 0, len(req.Tools))
|
|
for _, tool := range req.Tools {
|
|
names = append(names, tool.Function.Name)
|
|
}
|
|
toolsMu.Lock()
|
|
toolsByCall = append(toolsByCall, names)
|
|
toolsMu.Unlock()
|
|
|
|
if callCount.Add(1) == 1 {
|
|
// Root chat: model calls spawn_agent.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("spawn_agent", `{"prompt":"do the thing","title":"sub"}`),
|
|
)
|
|
}
|
|
// Subsequent calls (including the subagent): just reply.
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Done.")...,
|
|
)
|
|
})
|
|
|
|
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
|
Provider: "openai-compat",
|
|
APIKey: "test-api-key",
|
|
BaseURL: openAIURL,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
contextLimit := int64(4096)
|
|
isDefault := true
|
|
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
|
Provider: "openai-compat",
|
|
Model: "gpt-4o-mini",
|
|
ContextLimit: &contextLimit,
|
|
IsDefault: &isDefault,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Create a root chat whose first model call will spawn a subagent.
|
|
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
|
Content: []codersdk.ChatInputPart{
|
|
{
|
|
Type: codersdk.ChatInputPartTypeText,
|
|
Text: "Spawn a subagent to do the thing.",
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the root chat AND the subagent to finish.
|
|
// The root chat finishes first, then the chatd server
|
|
// picks up and runs the child (subagent) chat.
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := client.GetChat(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
if got.Chat.Status != codersdk.ChatStatusWaiting && got.Chat.Status != codersdk.ChatStatusError {
|
|
return false
|
|
}
|
|
// Also ensure the subagent LLM call has been made.
|
|
toolsMu.Lock()
|
|
n := len(toolsByCall)
|
|
toolsMu.Unlock()
|
|
// Expect at least 3 calls: root-1 (spawn_agent), child-1, root-2.
|
|
return n >= 3
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
// There should be at least two streamed calls: one for the root
|
|
// chat and one for the subagent child chat.
|
|
toolsMu.Lock()
|
|
recorded := append([][]string(nil), toolsByCall...)
|
|
toolsMu.Unlock()
|
|
|
|
require.GreaterOrEqual(t, len(recorded), 2,
|
|
"expected at least 2 streamed LLM calls (root + subagent)")
|
|
|
|
workspaceTools := []string{"list_templates", "read_template", "create_workspace"}
|
|
subagentTools := []string{"spawn_agent", "wait_agent", "message_agent", "close_agent"}
|
|
|
|
// Identify root and subagent calls. Root chat calls include
|
|
// spawn_agent; the subagent call does not. Because the root chat
|
|
// makes multiple LLM calls (before and after spawn_agent), we
|
|
// find exactly one call that lacks spawn_agent — that's the
|
|
// subagent.
|
|
var rootCalls, childCalls [][]string
|
|
for _, tools := range recorded {
|
|
hasSpawnAgent := slice.Contains(tools, "spawn_agent")
|
|
if hasSpawnAgent {
|
|
rootCalls = append(rootCalls, tools)
|
|
} else {
|
|
childCalls = append(childCalls, tools)
|
|
}
|
|
}
|
|
|
|
require.NotEmpty(t, rootCalls, "expected at least one root chat LLM call")
|
|
require.NotEmpty(t, childCalls, "expected at least one subagent LLM call")
|
|
|
|
// Root chat calls must include workspace and subagent tools.
|
|
for _, tool := range workspaceTools {
|
|
require.Contains(t, rootCalls[0], tool,
|
|
"root chat should have workspace tool %q", tool)
|
|
}
|
|
for _, tool := range subagentTools {
|
|
require.Contains(t, rootCalls[0], tool,
|
|
"root chat should have subagent tool %q", tool)
|
|
}
|
|
|
|
// Subagent calls must NOT include workspace or subagent tools.
|
|
for _, tool := range workspaceTools {
|
|
require.NotContains(t, childCalls[0], tool,
|
|
"subagent chat should NOT have workspace tool %q", tool)
|
|
}
|
|
for _, tool := range subagentTools {
|
|
require.NotContains(t, childCalls[0], tool,
|
|
"subagent chat should NOT have subagent tool %q", tool)
|
|
}
|
|
}
|
|
|
|
func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "db-transition",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
updated := replica.InterruptChat(ctx, chat)
|
|
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
|
require.False(t, updated.WorkerID.Valid)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
|
require.False(t, fromDB.WorkerID.Valid)
|
|
}
|
|
|
|
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "heartbeat-ownership",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
workerID := uuid.New()
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
|
ID: chat.ID,
|
|
WorkerID: uuid.New(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(0), rows)
|
|
|
|
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
|
ID: chat.ID,
|
|
WorkerID: workerID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), rows)
|
|
}
|
|
|
|
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "queue-when-busy",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
workerID := uuid.New()
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "queued"}},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, result.Queued)
|
|
require.NotNil(t, result.QueuedMessage)
|
|
require.Equal(t, database.ChatStatusRunning, result.Chat.Status)
|
|
require.Equal(t, workerID, result.Chat.WorkerID.UUID)
|
|
require.True(t, result.Chat.WorkerID.Valid)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, 1)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 1)
|
|
}
|
|
|
|
func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "interrupt-when-busy",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "interrupt"}},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, result.Queued)
|
|
require.Equal(t, database.ChatStatusPending, result.Chat.Status)
|
|
require.False(t, result.Chat.WorkerID.Valid)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, fromDB.Status)
|
|
require.False(t, fromDB.WorkerID.Valid)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, 0)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 2)
|
|
require.Equal(t, messages[len(messages)-1].ID, result.Message.ID)
|
|
}
|
|
|
|
func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "edit-message",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "original"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
initialMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, initialMessages, 1)
|
|
editedMessageID := initialMessages[0].ID
|
|
|
|
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "follow-up"}},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "another"}},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: json.RawMessage(`"queued"`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
editResult, err := replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: editedMessageID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, editedMessageID, editResult.Message.ID)
|
|
require.Equal(t, database.ChatStatusPending, editResult.Chat.Status)
|
|
require.False(t, editResult.Chat.WorkerID.Valid)
|
|
|
|
editedSDK := db2sdk.ChatMessage(editResult.Message)
|
|
require.Len(t, editedSDK.Content, 1)
|
|
require.Equal(t, "edited", editedSDK.Content[0].Text)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 1)
|
|
require.Equal(t, editedMessageID, messages[0].ID)
|
|
onlyMessage := db2sdk.ChatMessage(messages[0])
|
|
require.Len(t, onlyMessage.Content, 1)
|
|
require.Equal(t, "edited", onlyMessage.Content[0].Text)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, 0)
|
|
|
|
chatFromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, chatFromDB.Status)
|
|
require.False(t, chatFromDB.WorkerID.Valid)
|
|
}
|
|
|
|
func TestEditMessageRejectsMissingMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "missing-edited-message",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: 999999,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotFound))
|
|
}
|
|
|
|
func TestEditMessageRejectsNonUserMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "non-user-edited-message",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
assistantMessage, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: "assistant",
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: json.RawMessage(`"assistant"`),
|
|
Valid: true,
|
|
},
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: assistantMessage.ID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotUser))
|
|
}
|
|
|
|
func TestRecoverStaleChatsPeriodically(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Use a very short stale threshold so the periodic recovery
|
|
// kicks in quickly during the test.
|
|
staleAfter := 500 * time.Millisecond
|
|
|
|
// Create a chat and simulate a dead worker by setting the chat
|
|
// to running with a heartbeat in the past.
|
|
deadWorkerID := uuid.New()
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: user.ID,
|
|
Title: "stale-recovery-periodic",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: deadWorkerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Start a new replica. Its startup recovery will reset the
|
|
// chat (since the heartbeat is old), but the key point is that
|
|
// the periodic loop also recovers newly-stale chats.
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
InFlightChatStaleAfter: staleAfter,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
// The startup recovery should have already reset our stale
|
|
// chat.
|
|
require.Eventually(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
// Now simulate a second stale chat appearing AFTER startup.
|
|
// This tests the periodic recovery, not just the startup one.
|
|
deadWorkerID2 := uuid.New()
|
|
chat2, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: user.ID,
|
|
Title: "stale-recovery-periodic-2",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat2.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: deadWorkerID2, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// The periodic stale recovery loop (running at staleAfter/5 =
|
|
// 100ms intervals) should pick this up without a restart.
|
|
require.Eventually(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat2.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Simulate a chat left running by a dead replica with a stale
|
|
// heartbeat (well beyond the stale threshold).
|
|
deadReplicaID := uuid.New()
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: user.ID,
|
|
Title: "orphaned-chat",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Set the heartbeat far in the past so it's definitely stale.
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: deadReplicaID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Start a new replica — it should recover the stale chat on
|
|
// startup.
|
|
newReplica := newTestServer(t, db, ps, uuid.New())
|
|
_ = newReplica
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending &&
|
|
!fromDB.WorkerID.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestWaitingChatsAreNotRecoveredAsStale(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Create a chat in waiting status — this should NOT be touched
|
|
// by stale recovery.
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: user.ID,
|
|
Title: "waiting-chat",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Start a replica with a short stale threshold.
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
InFlightChatStaleAfter: 500 * time.Millisecond,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
// Wait long enough for multiple periodic recovery cycles to
|
|
// run (staleAfter/5 = 100ms intervals).
|
|
require.Never(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status != database.ChatStatusWaiting
|
|
}, time.Second, testutil.IntervalFast,
|
|
"waiting chat should not be modified by stale recovery")
|
|
}
|
|
|
|
func TestUpdateChatStatusPersistsLastError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
_ = newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: user.ID,
|
|
Title: "error-persisted",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Simulate a chat that failed with an error.
|
|
errorMessage := "stream response: status 500: internal server error"
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusError,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{String: errorMessage, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusError, chat.Status)
|
|
require.Equal(t, sql.NullString{String: errorMessage, Valid: true}, chat.LastError)
|
|
|
|
// Verify the error is persisted when re-read from the database.
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusError, fromDB.Status)
|
|
require.Equal(t, sql.NullString{String: errorMessage, Valid: true}, fromDB.LastError)
|
|
|
|
// Verify the error is cleared when the chat transitions to a
|
|
// non-error status (e.g. pending after a retry).
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusPending,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, chat.Status)
|
|
require.False(t, chat.LastError.Valid)
|
|
|
|
fromDB, err = db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, fromDB.LastError.Valid)
|
|
}
|
|
|
|
func TestSubscribeSnapshotIncludesStatusEvent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "status-snapshot",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
snapshot, _, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
// The first event in the snapshot must be a status event.
|
|
require.NotEmpty(t, snapshot)
|
|
require.Equal(t, codersdk.ChatStreamEventTypeStatus, snapshot[0].Type)
|
|
require.NotNil(t, snapshot[0].Status)
|
|
require.Equal(t, codersdk.ChatStatusPending, snapshot[0].Status.Status)
|
|
}
|
|
|
|
func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Use nil pubsub to force the no-pubsub path.
|
|
db, _ := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, nil, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "no-dup-parts",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
snapshot, events, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
// Snapshot should have events (at minimum: status + message).
|
|
require.NotEmpty(t, snapshot)
|
|
|
|
// The events channel should NOT immediately produce any
|
|
// events — the snapshot already contained everything. Before
|
|
// the fix, localSnapshot was replayed into the channel,
|
|
// causing duplicates.
|
|
select {
|
|
case event, ok := <-events:
|
|
if ok {
|
|
t.Fatalf("unexpected event from channel (would be a duplicate): type=%s", event.Type)
|
|
}
|
|
// Channel closed without events is fine.
|
|
case <-time.After(200 * time.Millisecond):
|
|
// No events — correct behavior.
|
|
}
|
|
}
|
|
|
|
func TestSubscribeAfterMessageID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Create a chat — this inserts one initial "user" message.
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "after-id-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "first"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert two more messages so we have three total visible
|
|
// messages (the initial user message plus these two).
|
|
msg2, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: "assistant",
|
|
Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"second"`), Valid: true},
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: "user",
|
|
Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"third"`), Valid: true},
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Control: Subscribe with afterMessageID=0 returns ALL messages.
|
|
allSnapshot, _, cancelAll, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
cancelAll()
|
|
|
|
allMessages := filterMessageEvents(allSnapshot)
|
|
require.Len(t, allMessages, 3, "afterMessageID=0 should return all three messages")
|
|
|
|
// Subscribe with afterMessageID set to the second message's ID.
|
|
// Only the third message (inserted after msg2) should appear.
|
|
partialSnapshot, _, cancelPartial, ok := replica.Subscribe(ctx, chat.ID, nil, msg2.ID)
|
|
require.True(t, ok)
|
|
cancelPartial()
|
|
|
|
partialMessages := filterMessageEvents(partialSnapshot)
|
|
require.Len(t, partialMessages, 1, "afterMessageID=msg2.ID should return only messages after msg2")
|
|
require.Equal(t, "user", partialMessages[0].Message.Role)
|
|
}
|
|
|
|
// filterMessageEvents returns only the Message-type events from a
|
|
// snapshot slice, which is useful for ignoring status / queue events.
|
|
func filterMessageEvents(events []codersdk.ChatStreamEvent) []codersdk.ChatStreamEvent {
|
|
return slice.Filter(events, func(e codersdk.ChatStreamEvent) bool {
|
|
return e.Type == codersdk.ChatStreamEventTypeMessage
|
|
})
|
|
}
|
|
|
|
func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
deploymentValues := coderdtest.DeploymentValues(t)
|
|
deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)}
|
|
client := coderdtest.New(t, &coderdtest.Options{
|
|
DeploymentValues: deploymentValues,
|
|
IncludeProvisionerDaemon: true,
|
|
})
|
|
user := coderdtest.CreateFirstUser(t, client)
|
|
|
|
agentToken := uuid.NewString()
|
|
// Add a startup script so the agent spends time in the
|
|
// "starting" lifecycle state. This lets us verify that
|
|
// create_workspace waits for scripts to finish.
|
|
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
|
Parse: echo.ParseComplete,
|
|
ProvisionPlan: echo.PlanComplete,
|
|
ProvisionApply: echo.ApplyComplete,
|
|
ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken, func(g *proto.GraphComplete) {
|
|
g.Resources[0].Agents[0].Scripts = []*proto.Script{{
|
|
DisplayName: "setup",
|
|
Script: "sleep 5",
|
|
RunOnStart: true,
|
|
}}
|
|
}),
|
|
})
|
|
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
|
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
|
|
|
// Start the test workspace agent so create_workspace can wait for
|
|
// the agent to become reachable before returning.
|
|
_ = agenttest.New(t, client.URL, agentToken)
|
|
|
|
workspaceName := "chat-ws-" + strings.ReplaceAll(uuid.NewString(), "-", "")[:8]
|
|
createWorkspaceArgs := fmt.Sprintf(
|
|
`{"template_id":%q,"name":%q}`,
|
|
template.ID.String(),
|
|
workspaceName,
|
|
)
|
|
|
|
var streamedCallCount atomic.Int32
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Create workspace test")
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
|
|
streamedCallsMu.Unlock()
|
|
|
|
if streamedCallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("create_workspace", createWorkspaceArgs),
|
|
)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Workspace created and ready.")...,
|
|
)
|
|
})
|
|
|
|
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
|
Provider: "openai-compat",
|
|
APIKey: "test-api-key",
|
|
BaseURL: openAIURL,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
contextLimit := int64(4096)
|
|
isDefault := true
|
|
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
|
Provider: "openai-compat",
|
|
Model: "gpt-4o-mini",
|
|
ContextLimit: &contextLimit,
|
|
IsDefault: &isDefault,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
|
Content: []codersdk.ChatInputPart{
|
|
{
|
|
Type: codersdk.ChatInputPartTypeText,
|
|
Text: "Create a workspace from the template and continue.",
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var chatWithMessages codersdk.ChatWithMessages
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := client.GetChat(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatWithMessages = got
|
|
return got.Chat.Status == codersdk.ChatStatusWaiting || got.Chat.Status == codersdk.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatWithMessages.Chat.Status == codersdk.ChatStatusError {
|
|
lastError := ""
|
|
if chatWithMessages.Chat.LastError != nil {
|
|
lastError = *chatWithMessages.Chat.LastError
|
|
}
|
|
require.FailNowf(t, "chat run failed", "last_error=%q", lastError)
|
|
}
|
|
|
|
require.NotNil(t, chatWithMessages.Chat.WorkspaceID)
|
|
workspaceID := *chatWithMessages.Chat.WorkspaceID
|
|
workspace, err := client.Workspace(ctx, workspaceID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, workspaceName, workspace.Name)
|
|
|
|
var foundCreateWorkspaceResult bool
|
|
for _, message := range chatWithMessages.Messages {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
for _, part := range message.Content {
|
|
if part.Type != codersdk.ChatMessagePartTypeToolResult || part.ToolName != "create_workspace" {
|
|
continue
|
|
}
|
|
var result map[string]any
|
|
require.NoError(t, json.Unmarshal(part.Result, &result))
|
|
created, ok := result["created"].(bool)
|
|
require.True(t, ok)
|
|
require.True(t, created)
|
|
foundCreateWorkspaceResult = true
|
|
}
|
|
}
|
|
require.True(t, foundCreateWorkspaceResult, "expected create_workspace tool result message")
|
|
|
|
// Verify that the tool waited for startup scripts to
|
|
// complete. The agent should be in "ready" state by the
|
|
// time create_workspace returns its result.
|
|
workspace, err = client.Workspace(ctx, workspaceID)
|
|
require.NoError(t, err)
|
|
var agentLifecycle codersdk.WorkspaceAgentLifecycle
|
|
for _, res := range workspace.LatestBuild.Resources {
|
|
for _, agt := range res.Agents {
|
|
agentLifecycle = agt.LifecycleState
|
|
}
|
|
}
|
|
require.Equal(t, codersdk.WorkspaceAgentLifecycleReady, agentLifecycle,
|
|
"agent should be ready after create_workspace returns; startup scripts were not awaited")
|
|
|
|
require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
|
|
streamedCallsMu.Lock()
|
|
recordedStreamCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
|
|
streamedCallsMu.Unlock()
|
|
require.GreaterOrEqual(t, len(recordedStreamCalls), 2)
|
|
|
|
var foundToolResultInSecondCall bool
|
|
for _, message := range recordedStreamCalls[1] {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
if !json.Valid([]byte(message.Content)) {
|
|
continue
|
|
}
|
|
var result map[string]any
|
|
if err := json.Unmarshal([]byte(message.Content), &result); err != nil {
|
|
continue
|
|
}
|
|
created, ok := result["created"].(bool)
|
|
if ok && created {
|
|
foundToolResultInSecondCall = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include create_workspace tool output")
|
|
}
|
|
|
|
func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
|
deploymentValues := coderdtest.DeploymentValues(t)
|
|
deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)}
|
|
client := coderdtest.New(t, &coderdtest.Options{
|
|
DeploymentValues: deploymentValues,
|
|
IncludeProvisionerDaemon: true,
|
|
})
|
|
user := coderdtest.CreateFirstUser(t, client)
|
|
|
|
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
|
Parse: echo.ParseComplete,
|
|
ProvisionPlan: echo.PlanComplete,
|
|
ProvisionApply: echo.ApplyComplete,
|
|
})
|
|
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
|
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
|
|
|
// Create a workspace, then stop it so start_workspace has
|
|
// something to start. We intentionally skip starting a test
|
|
// agent — the echo provisioner creates new agent rows for each
|
|
// build, so an agent started for build 1 cannot serve build 3.
|
|
// The tool handles the no-agent case gracefully.
|
|
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
|
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
|
workspace = coderdtest.MustTransitionWorkspace(
|
|
t, client, workspace.ID,
|
|
codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop,
|
|
)
|
|
|
|
var streamedCallCount atomic.Int32
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Start workspace test")
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
|
|
streamedCallsMu.Unlock()
|
|
|
|
if streamedCallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("start_workspace", "{}"),
|
|
)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Workspace started and ready.")...,
|
|
)
|
|
})
|
|
|
|
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
|
Provider: "openai-compat",
|
|
APIKey: "test-api-key",
|
|
BaseURL: openAIURL,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
contextLimit := int64(4096)
|
|
isDefault := true
|
|
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
|
Provider: "openai-compat",
|
|
Model: "gpt-4o-mini",
|
|
ContextLimit: &contextLimit,
|
|
IsDefault: &isDefault,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Create a chat with the stopped workspace pre-associated.
|
|
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
|
Content: []codersdk.ChatInputPart{
|
|
{
|
|
Type: codersdk.ChatInputPartTypeText,
|
|
Text: "Start the workspace.",
|
|
},
|
|
},
|
|
WorkspaceID: &workspace.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var chatWithMessages codersdk.ChatWithMessages
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := client.GetChat(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatWithMessages = got
|
|
return got.Chat.Status == codersdk.ChatStatusWaiting || got.Chat.Status == codersdk.ChatStatusError
|
|
}, testutil.WaitSuperLong, testutil.IntervalFast)
|
|
|
|
if chatWithMessages.Chat.Status == codersdk.ChatStatusError {
|
|
lastError := ""
|
|
if chatWithMessages.Chat.LastError != nil {
|
|
lastError = *chatWithMessages.Chat.LastError
|
|
}
|
|
require.FailNowf(t, "chat run failed", "last_error=%q", lastError)
|
|
}
|
|
|
|
// Verify the workspace was started.
|
|
require.NotNil(t, chatWithMessages.Chat.WorkspaceID)
|
|
updatedWorkspace, err := client.Workspace(ctx, workspace.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, codersdk.WorkspaceTransitionStart, updatedWorkspace.LatestBuild.Transition)
|
|
|
|
// Verify start_workspace tool result exists in the chat messages.
|
|
var foundStartWorkspaceResult bool
|
|
for _, message := range chatWithMessages.Messages {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
for _, part := range message.Content {
|
|
if part.Type != codersdk.ChatMessagePartTypeToolResult || part.ToolName != "start_workspace" {
|
|
continue
|
|
}
|
|
var result map[string]any
|
|
require.NoError(t, json.Unmarshal(part.Result, &result))
|
|
started, ok := result["started"].(bool)
|
|
require.True(t, ok)
|
|
require.True(t, started)
|
|
foundStartWorkspaceResult = true
|
|
}
|
|
}
|
|
require.True(t, foundStartWorkspaceResult, "expected start_workspace tool result message")
|
|
|
|
// Verify the LLM received the tool result in its second call.
|
|
require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
|
|
streamedCallsMu.Lock()
|
|
recordedStreamCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
|
|
streamedCallsMu.Unlock()
|
|
require.GreaterOrEqual(t, len(recordedStreamCalls), 2)
|
|
|
|
var foundToolResultInSecondCall bool
|
|
for _, message := range recordedStreamCalls[1] {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
if !json.Valid([]byte(message.Content)) {
|
|
continue
|
|
}
|
|
var result map[string]any
|
|
if err := json.Unmarshal([]byte(message.Content), &result); err != nil {
|
|
continue
|
|
}
|
|
started, ok := result["started"].(bool)
|
|
if ok && started {
|
|
foundToolResultInSecondCall = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include start_workspace tool output")
|
|
}
|
|
|
|
func newTestServer(
|
|
t *testing.T,
|
|
db database.Store,
|
|
ps dbpubsub.Pubsub,
|
|
replicaID uuid.UUID,
|
|
) *chatd.Server {
|
|
t.Helper()
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: replicaID,
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitLong,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
return server
|
|
}
|
|
|
|
func seedChatDependencies(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
) (database.User, database.ChatModelConfig) {
|
|
t.Helper()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
BaseUrl: "",
|
|
ApiKeyKeyID: sql.NullString{},
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "gpt-4o-mini",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 70,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
return user, model
|
|
}
|
|
|
|
func setOpenAIProviderBaseURL(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
baseURL string,
|
|
) {
|
|
t.Helper()
|
|
|
|
provider, err := db.GetChatProviderByProvider(ctx, "openai")
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
|
ID: provider.ID,
|
|
DisplayName: provider.DisplayName,
|
|
APIKey: provider.APIKey,
|
|
BaseUrl: baseURL,
|
|
ApiKeyKeyID: provider.ApiKeyKeyID,
|
|
Enabled: provider.Enabled,
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Set up a mock OpenAI that blocks until the request context is
|
|
// canceled (i.e. until the chat is interrupted).
|
|
streamStarted := make(chan struct{})
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("partial")[0]
|
|
select {
|
|
case <-streamStarted:
|
|
default:
|
|
close(streamStarted)
|
|
}
|
|
// Block until the chat context is canceled by the interrupt.
|
|
<-req.Context().Done()
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
})
|
|
|
|
// Mock webpush dispatcher that records calls.
|
|
mockPush := &mockWebpushDispatcher{}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
WebpushDispatcher: mockPush,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "interrupt-no-push",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the chat to be picked up and start streaming.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
|
}, testutil.IntervalFast)
|
|
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
select {
|
|
case <-streamStarted:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.IntervalFast)
|
|
|
|
// Interrupt the chat.
|
|
updated := server.InterruptChat(ctx, chat)
|
|
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
|
|
|
// Wait for the chat to finish processing and return to waiting.
|
|
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusWaiting && !fromDB.WorkerID.Valid
|
|
}, testutil.IntervalFast)
|
|
|
|
// Verify no web push notification was dispatched.
|
|
require.Equal(t, int32(0), mockPush.dispatchCount.Load(),
|
|
"expected no web push dispatch for an interrupted chat")
|
|
}
|
|
|
|
// mockWebpushDispatcher implements webpush.Dispatcher and records Dispatch calls.
|
|
type mockWebpushDispatcher struct {
|
|
dispatchCount atomic.Int32
|
|
mu sync.Mutex
|
|
lastMessage codersdk.WebpushMessage
|
|
lastUserID uuid.UUID
|
|
}
|
|
|
|
func (m *mockWebpushDispatcher) Dispatch(_ context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
|
|
m.dispatchCount.Add(1)
|
|
m.mu.Lock()
|
|
m.lastMessage = msg
|
|
m.lastUserID = userID
|
|
m.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (*mockWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error {
|
|
return nil
|
|
}
|
|
|
|
func (*mockWebpushDispatcher) PublicKey() string {
|
|
return "test-vapid-public-key"
|
|
}
|
|
|
|
func TestSuccessfulChatSendsWebPushWithNavigationData(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Set up a mock OpenAI that returns a simple successful response.
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
|
|
// Mock webpush dispatcher that captures the dispatched message.
|
|
mockPush := &mockWebpushDispatcher{}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
WebpushDispatcher: mockPush,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "push-nav-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for a web push notification to be dispatched. The dispatch
|
|
// happens asynchronously after the DB status is updated, so we need
|
|
// to poll rather than assert immediately.
|
|
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
|
return mockPush.dispatchCount.Load() >= 1
|
|
}, testutil.IntervalFast)
|
|
|
|
require.Equal(t, int32(1), mockPush.dispatchCount.Load(),
|
|
"expected exactly one web push dispatch for a completed chat")
|
|
|
|
// Verify the notification was sent to the correct user.
|
|
mockPush.mu.Lock()
|
|
capturedMsg := mockPush.lastMessage
|
|
capturedUserID := mockPush.lastUserID
|
|
mockPush.mu.Unlock()
|
|
|
|
require.Equal(t, user.ID, capturedUserID,
|
|
"web push should be dispatched to the chat owner")
|
|
|
|
// Verify the Data field contains the correct navigation URL.
|
|
expectedURL := fmt.Sprintf("/agents/%s", chat.ID)
|
|
require.Equal(t, expectedURL, capturedMsg.Data["url"],
|
|
"web push Data should contain the chat navigation URL")
|
|
}
|
|
|
|
func TestSuccessfulChatSendsWebPushWithTag(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Set up a mock OpenAI that returns a simple streaming response.
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("done")...)
|
|
})
|
|
|
|
// Mock webpush dispatcher that captures calls.
|
|
mockPush := &mockWebpushDispatcher{}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
WebpushDispatcher: mockPush,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "push-tag-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the web push notification to be dispatched.
|
|
// We poll dispatchCount rather than DB status because the
|
|
// push fires after the status update, creating a small race
|
|
// window.
|
|
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
|
return mockPush.dispatchCount.Load() >= 1
|
|
}, testutil.IntervalFast)
|
|
|
|
require.Equal(t, int32(1), mockPush.dispatchCount.Load(),
|
|
"expected exactly one web push dispatch for a completed chat")
|
|
|
|
// Verify the push notification tag is set to the chat ID for dedup.
|
|
mockPush.mu.Lock()
|
|
capturedMsg := mockPush.lastMessage
|
|
capturedUser := mockPush.lastUserID
|
|
mockPush.mu.Unlock()
|
|
|
|
require.Equal(t, chat.ID.String(), capturedMsg.Tag,
|
|
"push notification tag should equal the chat ID for deduplication")
|
|
require.Equal(t, user.ID, capturedUser,
|
|
"push notification should be dispatched to the chat owner")
|
|
require.Equal(t, "push-tag-test", capturedMsg.Title,
|
|
"push notification title should match the chat title")
|
|
require.Equal(t, "Agent has finished running.", capturedMsg.Body,
|
|
"push notification body should indicate the agent finished")
|
|
}
|
|
|
|
func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var requestCount atomic.Int32
|
|
streamStarted := make(chan struct{})
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if requestCount.Add(1) == 1 {
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("partial")[0]
|
|
select {
|
|
case <-streamStarted:
|
|
default:
|
|
close(streamStarted)
|
|
}
|
|
<-req.Context().Done()
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
}
|
|
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("retry", " complete")...)
|
|
})
|
|
|
|
loggerA := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
serverA := chatd.New(chatd.Config{
|
|
Logger: loggerA,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitLong,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, serverA.Close())
|
|
})
|
|
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := serverA.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "shutdown-retry",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case <-streamStarted:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
require.NoError(t, serverA.Close())
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending &&
|
|
!fromDB.WorkerID.Valid &&
|
|
!fromDB.LastError.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
loggerB := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
serverB := chatd.New(chatd.Config{
|
|
Logger: loggerB,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitLong,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, serverB.Close())
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
return requestCount.Load() >= 2
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusWaiting &&
|
|
!fromDB.WorkerID.Valid &&
|
|
!fromDB.LastError.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestHeaderInjection(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// seedWorkspaceAgent creates the DB entities needed so that
|
|
// GetWorkspaceAgentsInLatestBuildByWorkspaceID returns an
|
|
// agent for the given workspace.
|
|
seedWorkspaceAgent := func(
|
|
t *testing.T,
|
|
db database.Store,
|
|
ps dbpubsub.Pubsub,
|
|
ownerID uuid.UUID,
|
|
orgID uuid.UUID,
|
|
) (workspaceID uuid.UUID, agentID uuid.UUID) {
|
|
t.Helper()
|
|
|
|
// TemplateVersion needs its own provisioner job.
|
|
versionJob := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
|
OrganizationID: orgID,
|
|
InitiatorID: ownerID,
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
})
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: orgID,
|
|
CreatedBy: ownerID,
|
|
JobID: versionJob.ID,
|
|
})
|
|
templ := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: orgID,
|
|
CreatedBy: ownerID,
|
|
ActiveVersionID: tv.ID,
|
|
})
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: ownerID,
|
|
OrganizationID: orgID,
|
|
TemplateID: templ.ID,
|
|
})
|
|
buildJob := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
|
OrganizationID: orgID,
|
|
InitiatorID: ownerID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
})
|
|
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: ws.ID,
|
|
JobID: buildJob.ID,
|
|
BuildNumber: 1,
|
|
InitiatorID: ownerID,
|
|
TemplateVersionID: tv.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: build.JobID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
return ws.ID, agent.ID
|
|
}
|
|
|
|
t.Run("WithParentChat", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
org, err := db.GetDefaultOrganization(ctx)
|
|
require.NoError(t, err)
|
|
|
|
workspaceID, expectedAgentID := seedWorkspaceAgent(t, db, ps, user.ID, org.ID)
|
|
|
|
// Set up the mock OpenAI to return a simple text response
|
|
// so the chat finishes cleanly.
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("title")
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("done")...,
|
|
)
|
|
})
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
// Wire up the mock agent connection so we can capture
|
|
// the headers passed to SetExtraHeaders.
|
|
ctrl := gomock.NewController(t)
|
|
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
|
|
var capturedHeaders http.Header
|
|
headersCaptured := make(chan struct{})
|
|
|
|
// SetExtraHeaders is called once when the connection
|
|
// is first established.
|
|
mockConn.EXPECT().SetExtraHeaders(gomock.Any()).Do(func(h http.Header) {
|
|
capturedHeaders = h
|
|
close(headersCaptured)
|
|
})
|
|
// resolveInstructions calls LS to look for instruction
|
|
// files; return an error so it skips gracefully.
|
|
mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).Return(
|
|
workspacesdk.LSResponse{}, xerrors.New("not found"),
|
|
).AnyTimes()
|
|
// The connection is closed when the chat finishes.
|
|
mockConn.EXPECT().Close().Return(nil).AnyTimes()
|
|
|
|
agentConnFn := func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
require.Equal(t, expectedAgentID, agentID)
|
|
return mockConn, func() {}, nil
|
|
}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
AgentConn: agentConnFn,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
// Create a real parent chat so the FK constraint is
|
|
// satisfied.
|
|
parentChat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "parent-chat",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{
|
|
fantasy.TextContent{Text: "parent"},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
|
|
ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true},
|
|
Title: "header-injection-parent",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{
|
|
fantasy.TextContent{Text: "hello"},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the chat to be processed and headers to be
|
|
// captured.
|
|
select {
|
|
case <-headersCaptured:
|
|
case <-ctx.Done():
|
|
require.FailNow(t, "timed out waiting for SetExtraHeaders")
|
|
}
|
|
|
|
require.Equal(t,
|
|
chat.ID.String(),
|
|
capturedHeaders.Get(workspacesdk.CoderChatIDHeader),
|
|
)
|
|
|
|
ancestorJSON := capturedHeaders.Get(workspacesdk.CoderAncestorChatIDsHeader)
|
|
var ancestorIDs []string
|
|
err = json.Unmarshal([]byte(ancestorJSON), &ancestorIDs)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []string{parentChat.ID.String()}, ancestorIDs)
|
|
})
|
|
|
|
t.Run("WithoutParentChat", func(t *testing.T) {
	t.Parallel()

	// Verifies the extra headers set on the agent connection for a
	// chat that has no parent: the chat ID header must carry the
	// chat's own ID and the ancestor header must decode to an empty
	// list.
	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)
	user, model := seedChatDependencies(ctx, t, db)

	org, err := db.GetDefaultOrganization(ctx)
	require.NoError(t, err)

	workspaceID, expectedAgentID := seedWorkspaceAgent(t, db, ps, user.ID, org.ID)

	// Stub OpenAI backend: non-streaming requests are answered with
	// "title", streaming requests with the text "done".
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("title")
		}
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAITextChunks("done")...,
		)
	})
	setOpenAIProviderBaseURL(ctx, t, db, openAIURL)

	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)

	// capturedHeaders is written by the mock callback and only read
	// after headersCaptured is closed, so the channel close orders
	// the write before the read.
	var capturedHeaders http.Header
	headersCaptured := make(chan struct{})

	// Record the headers handed to the connection and signal the
	// test that they are available.
	mockConn.EXPECT().SetExtraHeaders(gomock.Any()).Do(func(h http.Header) {
		capturedHeaders = h
		close(headersCaptured)
	})
	// LS always fails with "not found"; any number of calls is
	// tolerated.
	mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).Return(
		workspacesdk.LSResponse{}, xerrors.New("not found"),
	).AnyTimes()
	// Close may be called any number of times.
	mockConn.EXPECT().Close().Return(nil).AnyTimes()

	agentConnFn := func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		// This callback is invoked by the chat server and is
		// expected to run outside the test goroutine (the test
		// itself only blocks on headersCaptured). require.* ends in
		// t.FailNow, which must only be called from the test
		// goroutine, so report a mismatch with t.Errorf and return
		// an error instead.
		if agentID != expectedAgentID {
			t.Errorf("unexpected agent ID: expected %s, got %s", expectedAgentID, agentID)
			return nil, func() {}, xerrors.Errorf("unexpected agent ID %s", agentID)
		}
		return mockConn, func() {}, nil
	}

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	server := chatd.New(chatd.Config{
		Logger:                     logger,
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		AgentConn:                  agentConnFn,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitSuperLong,
	})
	t.Cleanup(func() {
		require.NoError(t, server.Close())
	})

	// Create a chat without a parent (ParentChatID is left unset).
	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID:       user.ID,
		WorkspaceID:   uuid.NullUUID{UUID: workspaceID, Valid: true},
		Title:         "header-injection-no-parent",
		ModelConfigID: model.ID,
		InitialUserContent: []fantasy.Content{
			fantasy.TextContent{Text: "hello"},
		},
	})
	require.NoError(t, err)

	// Wait for the chat to be processed and the headers to be
	// captured, or fail once the test context expires.
	select {
	case <-headersCaptured:
	case <-ctx.Done():
		require.FailNow(t, "timed out waiting for SetExtraHeaders")
	}

	require.Equal(t,
		chat.ID.String(),
		capturedHeaders.Get(workspacesdk.CoderChatIDHeader),
	)

	// With no parent there are no ancestor IDs. The header value
	// must still be valid JSON for a string list: both "null" (what
	// marshaling a nil []string yields) and "[]" decode to an empty
	// slice and satisfy require.Empty, whereas a missing header
	// (empty string) fails json.Unmarshal.
	ancestorJSON := capturedHeaders.Get(workspacesdk.CoderAncestorChatIDsHeader)
	var ancestorIDs []string
	err = json.Unmarshal([]byte(ancestorJSON), &ancestorIDs)
	require.NoError(t, err)
	require.Empty(t, ancestorIDs)
})
|
|
}
|