mirror of
https://github.com/coder/coder.git
synced 2026-06-05 22:18:20 +00:00
b7a7683ac0
## Problem Subscribers connecting to a different replica than the one running the chat see full messages appear but no streaming partials (`message_part` events). The relay mechanism that forwards ephemeral parts across replicas had several bugs. ## Root Causes 1. **`openRelay()` blocked the event loop** — The WebSocket dial (TCP + TLS + HTTP upgrade) to the worker replica ran synchronously inside the select loop. While dialing, no events could be processed, channels filled up, and parts were silently dropped. 2. **Relay drops were permanent** — When the relay WebSocket closed mid-stream, `relayParts` was set to nil and never reopened. No status notification would re-trigger it since the chat was still running on the same worker. 3. **`drainInitial` snapshot race** — The `default` case in the initial drain loop caused the snapshot to be empty if the remote hadn't flushed data yet (common immediately after WebSocket connect). 4. **Duplicate event delivery** — The `preloaded` slice caused snapshot events to be sent both in the return value and re-sent through the channel goroutine. ## Fixes ### `coderd/chatd/chatd.go` (Subscribe method) - **Async relay dial**: `openRelayAsync()` spawns a goroutine to dial the remote replica. The result (channel + cancel func) is delivered on a `relayReadyCh` channel that the select loop reads without blocking. - **Relay reconnection**: When the relay channel closes, a 500ms timer fires. The handler re-checks chat status from the DB and reopens the relay if the chat is still running on a remote worker. - **Snapshot parts via channel**: Relay snapshot + live parts are wrapped into a single channel so they flow through the same path, avoiding races with the select loop. ### `enterprise/coderd/chats.go` (newRemotePartsProvider) - **Timer-based drain**: Replaced `default` with a 1-second timer. After the first event, `Reset(0)` switches to non-blocking drain for remaining buffered events. - **Remove preloaded duplication**: The goroutine now only forwards new events; snapshot events are returned to the caller directly. ## Testing All existing tests pass: - `TestInterruptChatBroadcastsStatusAcrossInstances` - `TestSubscribeSnapshotIncludesStatusEvent` - `TestSubscribeNoPubsubNoDuplicateMessageParts` - `TestSubscribeAfterMessageID` - `TestChatStreamRelay/RelayMessagePartsAcrossReplicas`
1455 lines
46 KiB
Go
1455 lines
46 KiB
Go
package chatd_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/agent/agenttest"
|
|
"github.com/coder/coder/v2/coderd/chatd"
|
|
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
|
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
|
"github.com/coder/coder/v2/coderd/util/slice"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/provisioner/echo"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replicaA := newTestServer(t, db, ps, uuid.New())
|
|
replicaB := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replicaA.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "interrupt-me",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
runningWorker := uuid.New()
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: runningWorker, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, events, cancel, ok := replicaB.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
updated := replicaA.InterruptChat(ctx, chat)
|
|
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
|
require.False(t, updated.WorkerID.Valid)
|
|
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case event := <-events:
|
|
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
|
return false
|
|
}
|
|
return event.Status.Status == codersdk.ChatStatusWaiting
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "db-transition",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
updated := replica.InterruptChat(ctx, chat)
|
|
require.Equal(t, database.ChatStatusWaiting, updated.Status)
|
|
require.False(t, updated.WorkerID.Valid)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
|
require.False(t, fromDB.WorkerID.Valid)
|
|
}
|
|
|
|
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "heartbeat-ownership",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
workerID := uuid.New()
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
|
ID: chat.ID,
|
|
WorkerID: uuid.New(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(0), rows)
|
|
|
|
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
|
ID: chat.ID,
|
|
WorkerID: workerID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), rows)
|
|
}
|
|
|
|
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "queue-when-busy",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
workerID := uuid.New()
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "queued"}},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, result.Queued)
|
|
require.NotNil(t, result.QueuedMessage)
|
|
require.Equal(t, database.ChatStatusRunning, result.Chat.Status)
|
|
require.Equal(t, workerID, result.Chat.WorkerID.UUID)
|
|
require.True(t, result.Chat.WorkerID.Valid)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, 1)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 1)
|
|
}
|
|
|
|
func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "interrupt-when-busy",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "interrupt"}},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, result.Queued)
|
|
require.Equal(t, database.ChatStatusPending, result.Chat.Status)
|
|
require.False(t, result.Chat.WorkerID.Valid)
|
|
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, fromDB.Status)
|
|
require.False(t, fromDB.WorkerID.Valid)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, 0)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 2)
|
|
require.Equal(t, messages[len(messages)-1].ID, result.Message.ID)
|
|
}
|
|
|
|
func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "edit-message",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "original"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
initialMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, initialMessages, 1)
|
|
editedMessageID := initialMessages[0].ID
|
|
|
|
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "follow-up"}},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
|
ChatID: chat.ID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "another"}},
|
|
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: chat.ID,
|
|
Content: json.RawMessage(`"queued"`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
editResult, err := replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: editedMessageID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, editedMessageID, editResult.Message.ID)
|
|
require.Equal(t, database.ChatStatusPending, editResult.Chat.Status)
|
|
require.False(t, editResult.Chat.WorkerID.Valid)
|
|
|
|
editedSDK := db2sdk.ChatMessage(editResult.Message)
|
|
require.Len(t, editedSDK.Content, 1)
|
|
require.Equal(t, "edited", editedSDK.Content[0].Text)
|
|
|
|
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, messages, 1)
|
|
require.Equal(t, editedMessageID, messages[0].ID)
|
|
onlyMessage := db2sdk.ChatMessage(messages[0])
|
|
require.Len(t, onlyMessage.Content, 1)
|
|
require.Equal(t, "edited", onlyMessage.Content[0].Text)
|
|
|
|
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, 0)
|
|
|
|
chatFromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, chatFromDB.Status)
|
|
require.False(t, chatFromDB.WorkerID.Valid)
|
|
}
|
|
|
|
func TestEditMessageRejectsMissingMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "missing-edited-message",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: 999999,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotFound))
|
|
}
|
|
|
|
func TestEditMessageRejectsNonUserMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "non-user-edited-message",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
assistantMessage, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: "assistant",
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: json.RawMessage(`"assistant"`),
|
|
Valid: true,
|
|
},
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
|
ChatID: chat.ID,
|
|
EditedMessageID: assistantMessage.ID,
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotUser))
|
|
}
|
|
|
|
func TestRecoverStaleChatsPeriodically(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Use a very short stale threshold so the periodic recovery
|
|
// kicks in quickly during the test.
|
|
staleAfter := 500 * time.Millisecond
|
|
|
|
// Create a chat and simulate a dead worker by setting the chat
|
|
// to running with a heartbeat in the past.
|
|
deadWorkerID := uuid.New()
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: user.ID,
|
|
Title: "stale-recovery-periodic",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: deadWorkerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Start a new replica. Its startup recovery will reset the
|
|
// chat (since the heartbeat is old), but the key point is that
|
|
// the periodic loop also recovers newly-stale chats.
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitSuperLong,
|
|
InFlightChatStaleAfter: staleAfter,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
// The startup recovery should have already reset our stale
|
|
// chat.
|
|
require.Eventually(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
// Now simulate a second stale chat appearing AFTER startup.
|
|
// This tests the periodic recovery, not just the startup one.
|
|
deadWorkerID2 := uuid.New()
|
|
chat2, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: user.ID,
|
|
Title: "stale-recovery-periodic-2",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat2.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: deadWorkerID2, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// The periodic stale recovery loop (running at staleAfter/5 =
|
|
// 100ms intervals) should pick this up without a restart.
|
|
require.Eventually(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat2.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Simulate a chat left running by a dead replica with a stale
|
|
// heartbeat (well beyond the stale threshold).
|
|
deadReplicaID := uuid.New()
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: user.ID,
|
|
Title: "orphaned-chat",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Set the heartbeat far in the past so it's definitely stale.
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: deadReplicaID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Start a new replica — it should recover the stale chat on
|
|
// startup.
|
|
newReplica := newTestServer(t, db, ps, uuid.New())
|
|
_ = newReplica
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending &&
|
|
!fromDB.WorkerID.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestWaitingChatsAreNotRecoveredAsStale(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Create a chat in waiting status — this should NOT be touched
|
|
// by stale recovery.
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: user.ID,
|
|
Title: "waiting-chat",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Start a replica with a short stale threshold.
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitSuperLong,
|
|
InFlightChatStaleAfter: 500 * time.Millisecond,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
|
|
// Wait long enough for multiple periodic recovery cycles to
|
|
// run (staleAfter/5 = 100ms intervals).
|
|
require.Never(t, func() bool {
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status != database.ChatStatusWaiting
|
|
}, time.Second, testutil.IntervalFast,
|
|
"waiting chat should not be modified by stale recovery")
|
|
}
|
|
|
|
func TestUpdateChatStatusPersistsLastError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
_ = newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: user.ID,
|
|
Title: "error-persisted",
|
|
LastModelConfigID: model.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Simulate a chat that failed with an error.
|
|
errorMessage := "stream response: status 500: internal server error"
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusError,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{String: errorMessage, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusError, chat.Status)
|
|
require.Equal(t, sql.NullString{String: errorMessage, Valid: true}, chat.LastError)
|
|
|
|
// Verify the error is persisted when re-read from the database.
|
|
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusError, fromDB.Status)
|
|
require.Equal(t, sql.NullString{String: errorMessage, Valid: true}, fromDB.LastError)
|
|
|
|
// Verify the error is cleared when the chat transitions to a
|
|
// non-error status (e.g. pending after a retry).
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusPending,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ChatStatusPending, chat.Status)
|
|
require.False(t, chat.LastError.Valid)
|
|
|
|
fromDB, err = db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, fromDB.LastError.Valid)
|
|
}
|
|
|
|
func TestSubscribeSnapshotIncludesStatusEvent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "status-snapshot",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
snapshot, _, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
// The first event in the snapshot must be a status event.
|
|
require.NotEmpty(t, snapshot)
|
|
require.Equal(t, codersdk.ChatStreamEventTypeStatus, snapshot[0].Type)
|
|
require.NotNil(t, snapshot[0].Status)
|
|
require.Equal(t, codersdk.ChatStatusPending, snapshot[0].Status.Status)
|
|
}
|
|
|
|
func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Use nil pubsub to force the no-pubsub path.
|
|
db, _ := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, nil, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "no-dup-parts",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
snapshot, events, cancel, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
// Snapshot should have events (at minimum: status + message).
|
|
require.NotEmpty(t, snapshot)
|
|
|
|
// The events channel should NOT immediately produce any
|
|
// events — the snapshot already contained everything. Before
|
|
// the fix, localSnapshot was replayed into the channel,
|
|
// causing duplicates.
|
|
select {
|
|
case event, ok := <-events:
|
|
if ok {
|
|
t.Fatalf("unexpected event from channel (would be a duplicate): type=%s", event.Type)
|
|
}
|
|
// Channel closed without events is fine.
|
|
case <-time.After(200 * time.Millisecond):
|
|
// No events — correct behavior.
|
|
}
|
|
}
|
|
|
|
func TestSubscribeAfterMessageID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
replica := newTestServer(t, db, ps, uuid.New())
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Create a chat — this inserts one initial "user" message.
|
|
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "after-id-test",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "first"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert two more messages so we have three total visible
|
|
// messages (the initial user message plus these two).
|
|
msg2, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: "assistant",
|
|
Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"second"`), Valid: true},
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
|
Role: "user",
|
|
Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"third"`), Valid: true},
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Control: Subscribe with afterMessageID=0 returns ALL messages.
|
|
allSnapshot, _, cancelAll, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
cancelAll()
|
|
|
|
allMessages := filterMessageEvents(allSnapshot)
|
|
require.Len(t, allMessages, 3, "afterMessageID=0 should return all three messages")
|
|
|
|
// Subscribe with afterMessageID set to the second message's ID.
|
|
// Only the third message (inserted after msg2) should appear.
|
|
partialSnapshot, _, cancelPartial, ok := replica.Subscribe(ctx, chat.ID, nil, msg2.ID)
|
|
require.True(t, ok)
|
|
cancelPartial()
|
|
|
|
partialMessages := filterMessageEvents(partialSnapshot)
|
|
require.Len(t, partialMessages, 1, "afterMessageID=msg2.ID should return only messages after msg2")
|
|
require.Equal(t, "user", partialMessages[0].Message.Role)
|
|
}
|
|
|
|
// filterMessageEvents returns only the Message-type events from a
|
|
// snapshot slice, which is useful for ignoring status / queue events.
|
|
func filterMessageEvents(events []codersdk.ChatStreamEvent) []codersdk.ChatStreamEvent {
|
|
return slice.Filter(events, func(e codersdk.ChatStreamEvent) bool {
|
|
return e.Type == codersdk.ChatStreamEventTypeMessage
|
|
})
|
|
}
|
|
|
|
func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
deploymentValues := coderdtest.DeploymentValues(t)
|
|
deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)}
|
|
client := coderdtest.New(t, &coderdtest.Options{
|
|
DeploymentValues: deploymentValues,
|
|
IncludeProvisionerDaemon: true,
|
|
})
|
|
user := coderdtest.CreateFirstUser(t, client)
|
|
|
|
agentToken := uuid.NewString()
|
|
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
|
Parse: echo.ParseComplete,
|
|
ProvisionPlan: echo.PlanComplete,
|
|
ProvisionApply: echo.ApplyComplete,
|
|
ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken),
|
|
})
|
|
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
|
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
|
|
|
// Start the test workspace agent so create_workspace can wait for
|
|
// the agent to become reachable before returning.
|
|
_ = agenttest.New(t, client.URL, agentToken)
|
|
|
|
workspaceName := "chat-ws-" + strings.ReplaceAll(uuid.NewString(), "-", "")[:8]
|
|
createWorkspaceArgs := fmt.Sprintf(
|
|
`{"template_id":%q,"name":%q}`,
|
|
template.ID.String(),
|
|
workspaceName,
|
|
)
|
|
|
|
var streamedCallCount atomic.Int32
|
|
var streamedCallsMu sync.Mutex
|
|
streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
|
|
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if !req.Stream {
|
|
return chattest.OpenAINonStreamingResponse("Create workspace test")
|
|
}
|
|
|
|
streamedCallsMu.Lock()
|
|
streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
|
|
streamedCallsMu.Unlock()
|
|
|
|
if streamedCallCount.Add(1) == 1 {
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAIToolCallChunk("create_workspace", createWorkspaceArgs),
|
|
)
|
|
}
|
|
return chattest.OpenAIStreamingResponse(
|
|
chattest.OpenAITextChunks("Workspace created and ready.")...,
|
|
)
|
|
})
|
|
|
|
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
|
Provider: "openai-compat",
|
|
APIKey: "test-api-key",
|
|
BaseURL: openAIURL,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
contextLimit := int64(4096)
|
|
isDefault := true
|
|
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
|
Provider: "openai-compat",
|
|
Model: "gpt-4o-mini",
|
|
ContextLimit: &contextLimit,
|
|
IsDefault: &isDefault,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
|
Content: []codersdk.ChatInputPart{
|
|
{
|
|
Type: codersdk.ChatInputPartTypeText,
|
|
Text: "Create a workspace from the template and continue.",
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var chatWithMessages codersdk.ChatWithMessages
|
|
require.Eventually(t, func() bool {
|
|
got, getErr := client.GetChat(ctx, chat.ID)
|
|
if getErr != nil {
|
|
return false
|
|
}
|
|
chatWithMessages = got
|
|
return got.Chat.Status == codersdk.ChatStatusWaiting || got.Chat.Status == codersdk.ChatStatusError
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
if chatWithMessages.Chat.Status == codersdk.ChatStatusError {
|
|
lastError := ""
|
|
if chatWithMessages.Chat.LastError != nil {
|
|
lastError = *chatWithMessages.Chat.LastError
|
|
}
|
|
require.FailNowf(t, "chat run failed", "last_error=%q", lastError)
|
|
}
|
|
|
|
require.NotNil(t, chatWithMessages.Chat.WorkspaceID)
|
|
workspaceID := *chatWithMessages.Chat.WorkspaceID
|
|
workspace, err := client.Workspace(ctx, workspaceID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, workspaceName, workspace.Name)
|
|
|
|
var foundCreateWorkspaceResult bool
|
|
for _, message := range chatWithMessages.Messages {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
for _, part := range message.Content {
|
|
if part.Type != codersdk.ChatMessagePartTypeToolResult || part.ToolName != "create_workspace" {
|
|
continue
|
|
}
|
|
var result map[string]any
|
|
require.NoError(t, json.Unmarshal(part.Result, &result))
|
|
created, ok := result["created"].(bool)
|
|
require.True(t, ok)
|
|
require.True(t, created)
|
|
foundCreateWorkspaceResult = true
|
|
}
|
|
}
|
|
require.True(t, foundCreateWorkspaceResult, "expected create_workspace tool result message")
|
|
|
|
require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
|
|
streamedCallsMu.Lock()
|
|
recordedStreamCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
|
|
streamedCallsMu.Unlock()
|
|
require.GreaterOrEqual(t, len(recordedStreamCalls), 2)
|
|
|
|
var foundToolResultInSecondCall bool
|
|
for _, message := range recordedStreamCalls[1] {
|
|
if message.Role != "tool" {
|
|
continue
|
|
}
|
|
if !json.Valid([]byte(message.Content)) {
|
|
continue
|
|
}
|
|
var result map[string]any
|
|
if err := json.Unmarshal([]byte(message.Content), &result); err != nil {
|
|
continue
|
|
}
|
|
created, ok := result["created"].(bool)
|
|
if ok && created {
|
|
foundToolResultInSecondCall = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include create_workspace tool output")
|
|
}
|
|
|
|
func newTestServer(
|
|
t *testing.T,
|
|
db database.Store,
|
|
ps dbpubsub.Pubsub,
|
|
replicaID uuid.UUID,
|
|
) *chatd.Server {
|
|
t.Helper()
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: replicaID,
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: testutil.WaitSuperLong,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
return server
|
|
}
|
|
|
|
func newTestServerWithRelay(
|
|
t *testing.T,
|
|
db database.Store,
|
|
ps dbpubsub.Pubsub,
|
|
replicaID uuid.UUID,
|
|
provider chatd.RemotePartsProvider,
|
|
) *chatd.Server {
|
|
t.Helper()
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := chatd.New(chatd.Config{
|
|
Logger: logger,
|
|
Database: db,
|
|
ReplicaID: replicaID,
|
|
Pubsub: ps,
|
|
RemotePartsProvider: provider,
|
|
PendingChatAcquireInterval: testutil.WaitSuperLong,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, server.Close())
|
|
})
|
|
return server
|
|
}
|
|
|
|
func seedChatDependencies(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
) (database.User, database.ChatModelConfig) {
|
|
t.Helper()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
BaseUrl: "",
|
|
ApiKeyKeyID: sql.NullString{},
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "gpt-4o-mini",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 70,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
return user, model
|
|
}
|
|
|
|
func setOpenAIProviderBaseURL(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
baseURL string,
|
|
) {
|
|
t.Helper()
|
|
|
|
provider, err := db.GetChatProviderByProvider(ctx, "openai")
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{
|
|
ID: provider.ID,
|
|
DisplayName: provider.DisplayName,
|
|
APIKey: provider.APIKey,
|
|
BaseUrl: baseURL,
|
|
ApiKeyKeyID: provider.ApiKeyKeyID,
|
|
Enabled: provider.Enabled,
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestSubscribeRelayReconnectsOnDrop(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
workerID := uuid.New()
|
|
subscriberID := uuid.New()
|
|
|
|
var callCount atomic.Int32
|
|
|
|
provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
|
|
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
|
|
) {
|
|
call := callCount.Add(1)
|
|
ch := make(chan codersdk.ChatStreamEvent, 10)
|
|
if call == 1 {
|
|
// First relay: send a part then close to simulate a drop.
|
|
ch <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: "assistant",
|
|
Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "first-relay"},
|
|
},
|
|
}
|
|
close(ch)
|
|
} else {
|
|
// Second relay: send a different part, keep open.
|
|
ch <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: "assistant",
|
|
Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "second-relay"},
|
|
},
|
|
}
|
|
// Don't close — keep alive so the subscriber stays connected.
|
|
}
|
|
return nil, ch, func() {}, nil
|
|
}
|
|
|
|
subscriber := newTestServerWithRelay(t, db, ps, subscriberID, provider)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Create a chat and mark it as running on a remote worker.
|
|
chat, err := subscriber.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "relay-reconnect",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
// Should get the first relay part.
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case event := <-events:
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
|
|
event.MessagePart != nil &&
|
|
event.MessagePart.Part.Text == "first-relay" {
|
|
return true
|
|
}
|
|
return false
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
// After the first relay closes, a reconnection should happen and
|
|
// deliver the second relay part.
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case event := <-events:
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
|
|
event.MessagePart != nil &&
|
|
event.MessagePart.Part.Text == "second-relay" {
|
|
return true
|
|
}
|
|
return false
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
require.GreaterOrEqual(t, int(callCount.Load()), 2)
|
|
}
|
|
|
|
func TestSubscribeRelayAsyncDoesNotBlock(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
workerID := uuid.New()
|
|
subscriberID := uuid.New()
|
|
|
|
dialStarted := make(chan struct{})
|
|
dialContinue := make(chan struct{})
|
|
|
|
provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
|
|
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
|
|
) {
|
|
// Signal that the dial has started, then block until released.
|
|
select {
|
|
case <-dialStarted:
|
|
default:
|
|
close(dialStarted)
|
|
}
|
|
select {
|
|
case <-dialContinue:
|
|
case <-ctx.Done():
|
|
return nil, nil, nil, ctx.Err()
|
|
}
|
|
ch := make(chan codersdk.ChatStreamEvent, 10)
|
|
return nil, ch, func() {}, nil
|
|
}
|
|
|
|
subscriber := newTestServerWithRelay(t, db, ps, subscriberID, provider)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Create a chat in pending status.
|
|
chat, err := subscriber.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "relay-async-nonblock",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Subscribe before the chat is marked running so the relay opens
|
|
// via pubsub notification (openRelayAsync path).
|
|
_, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
// Now mark the chat as running on a remote worker. This publishes
|
|
// a status notification which triggers openRelayAsync on the
|
|
// subscriber.
|
|
notify := coderdpubsub.ChatStreamNotifyMessage{
|
|
Status: string(database.ChatStatusRunning),
|
|
WorkerID: workerID.String(),
|
|
}
|
|
payload, err := json.Marshal(notify)
|
|
require.NoError(t, err)
|
|
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload)
|
|
require.NoError(t, err)
|
|
|
|
// Wait for the relay dial to actually start (blocking in the
|
|
// provider).
|
|
select {
|
|
case <-dialStarted:
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for relay dial to start")
|
|
}
|
|
|
|
// While the relay is still dialing (provider is blocked), publish
|
|
// another status change. If openRelayAsync blocked the select loop
|
|
// this event would never arrive.
|
|
statusNotify := coderdpubsub.ChatStreamNotifyMessage{
|
|
Status: string(database.ChatStatusWaiting),
|
|
}
|
|
statusPayload, err := json.Marshal(statusNotify)
|
|
require.NoError(t, err)
|
|
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), statusPayload)
|
|
require.NoError(t, err)
|
|
|
|
// The waiting status event should arrive promptly despite the
|
|
// relay still dialing.
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case event := <-events:
|
|
return event.Type == codersdk.ChatStreamEventTypeStatus &&
|
|
event.Status != nil &&
|
|
event.Status.Status == codersdk.ChatStatusWaiting
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
|
|
// Unblock the relay dial so the test can clean up.
|
|
close(dialContinue)
|
|
}
|
|
|
|
func TestSubscribeRelaySnapshotDelivered(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
workerID := uuid.New()
|
|
subscriberID := uuid.New()
|
|
|
|
provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) (
|
|
[]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error,
|
|
) {
|
|
// Return a non-empty snapshot with two parts.
|
|
snapshot := []codersdk.ChatStreamEvent{
|
|
{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: "assistant",
|
|
Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "snap-one"},
|
|
},
|
|
},
|
|
{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: "assistant",
|
|
Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "snap-two"},
|
|
},
|
|
},
|
|
}
|
|
ch := make(chan codersdk.ChatStreamEvent, 10)
|
|
// Also send a live part after the snapshot.
|
|
ch <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: "assistant",
|
|
Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "live-part"},
|
|
},
|
|
}
|
|
return snapshot, ch, func() {}, nil
|
|
}
|
|
|
|
subscriber := newTestServerWithRelay(t, db, ps, subscriberID, provider)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
|
|
// Create a chat already running on a remote worker.
|
|
chat, err := subscriber.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "relay-snapshot",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
initialSnapshot, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0)
|
|
require.True(t, ok)
|
|
t.Cleanup(cancel)
|
|
|
|
// The initial snapshot should contain the two relay snapshot parts.
|
|
var snapshotTexts []string
|
|
for _, event := range initialSnapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart && event.MessagePart != nil {
|
|
snapshotTexts = append(snapshotTexts, event.MessagePart.Part.Text)
|
|
}
|
|
}
|
|
require.Contains(t, snapshotTexts, "snap-one")
|
|
require.Contains(t, snapshotTexts, "snap-two")
|
|
|
|
// The live part should arrive on the events channel.
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case event := <-events:
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart &&
|
|
event.MessagePart != nil &&
|
|
event.MessagePart.Part.Text == "live-part" {
|
|
return true
|
|
}
|
|
return false
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
var requestCount atomic.Int32
|
|
streamStarted := make(chan struct{})
|
|
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
if requestCount.Add(1) == 1 {
|
|
chunks := make(chan chattest.OpenAIChunk, 1)
|
|
go func() {
|
|
defer close(chunks)
|
|
chunks <- chattest.OpenAITextChunks("partial")[0]
|
|
select {
|
|
case <-streamStarted:
|
|
default:
|
|
close(streamStarted)
|
|
}
|
|
<-req.Context().Done()
|
|
}()
|
|
return chattest.OpenAIResponse{StreamingChunks: chunks}
|
|
}
|
|
return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("retry", " complete")...)
|
|
})
|
|
|
|
loggerA := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
serverA := chatd.New(chatd.Config{
|
|
Logger: loggerA,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, serverA.Close())
|
|
})
|
|
|
|
user, model := seedChatDependencies(ctx, t, db)
|
|
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
|
|
|
chat, err := serverA.CreateChat(ctx, chatd.CreateOptions{
|
|
OwnerID: user.ID,
|
|
Title: "shutdown-retry",
|
|
ModelConfigID: model.ID,
|
|
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
require.Eventually(t, func() bool {
|
|
select {
|
|
case <-streamStarted:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
require.NoError(t, serverA.Close())
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusPending &&
|
|
!fromDB.WorkerID.Valid &&
|
|
!fromDB.LastError.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
loggerB := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
serverB := chatd.New(chatd.Config{
|
|
Logger: loggerB,
|
|
Database: db,
|
|
ReplicaID: uuid.New(),
|
|
Pubsub: ps,
|
|
PendingChatAcquireInterval: 10 * time.Millisecond,
|
|
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
|
})
|
|
t.Cleanup(func() {
|
|
require.NoError(t, serverB.Close())
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
return requestCount.Load() >= 2
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
|
|
require.Eventually(t, func() bool {
|
|
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
|
if dbErr != nil {
|
|
return false
|
|
}
|
|
return fromDB.Status == database.ChatStatusWaiting &&
|
|
!fromDB.WorkerID.Valid &&
|
|
!fromDB.LastError.Valid
|
|
}, testutil.WaitMedium, testutil.IntervalFast)
|
|
}
|