Files
coder/coderd/x/chatd/turn_summary_internal_test.go
2026-05-22 09:50:01 +02:00

196 lines
5.9 KiB
Go

package chatd
import (
"context"
"database/sql"
"encoding/json"
"sync/atomic"
"testing"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// TestUpdateLastTurnSummaryRejectsStaleWrites checks that updateLastTurnSummary
// accepts a write carrying the chat's current updated_at, and that once
// updated_at has been advanced out-of-band, a second write still carrying the
// old updated_at leaves both last_turn_summary and updated_at unchanged.
func TestUpdateLastTurnSummaryRejectsStaleWrites(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitMedium)

	owner := dbgen.User(t, db, database.User{})
	org := dbgen.Organization(t, db, database.Organization{})
	dbgen.OrganizationMember(t, db, database.OrganizationMember{
		UserID:         owner.ID,
		OrganizationID: org.ID,
	})
	provider := dbgen.ChatProvider(t, db, database.ChatProvider{
		Provider:    "openai",
		DisplayName: "OpenAI",
		APIKey:      "test-key",
		Enabled:     true,
	})
	modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
		AIProviderID:         uuid.NullUUID{UUID: provider.ID, Valid: true},
		Provider:             "openai",
		Model:                "test-model",
		DisplayName:          "Test Model",
		CreatedBy:            uuid.NullUUID{UUID: owner.ID, Valid: true},
		UpdatedBy:            uuid.NullUUID{UUID: owner.ID, Valid: true},
		Enabled:              true,
		IsDefault:            true,
		ContextLimit:         128000,
		CompressionThreshold: 80,
		Options:              json.RawMessage(`{}`),
	})
	require.NoError(t, err)
	chat, err := db.InsertChat(ctx, database.InsertChatParams{
		OrganizationID:    org.ID,
		Status:            database.ChatStatusWaiting,
		ClientType:        database.ChatClientTypeUi,
		OwnerID:           owner.ID,
		LastModelConfigID: modelCfg.ID,
		Title:             "summary-chat",
	})
	require.NoError(t, err)

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	server := &Server{db: db}

	// A write carrying the chat's current updated_at is persisted.
	server.updateLastTurnSummary(ctx, chat, chat.UpdatedAt, "fresh summary", logger)
	fetched, err := db.GetChatByID(ctx, chat.ID)
	require.NoError(t, err)
	require.Equal(t, sql.NullString{String: "fresh summary", Valid: true}, fetched.LastTurnSummary)

	// Move updated_at forward out-of-band so the in-memory chat.UpdatedAt
	// used below no longer matches the row.
	advancedUpdatedAt := chat.UpdatedAt.Add(time.Second)
	_, err = db.UpdateChatStatusPreserveUpdatedAt(ctx, database.UpdateChatStatusPreserveUpdatedAtParams{
		ID:        chat.ID,
		Status:    database.ChatStatusRunning,
		UpdatedAt: advancedUpdatedAt,
	})
	require.NoError(t, err)

	// A write carrying the stale updated_at must not change the row.
	server.updateLastTurnSummary(context.WithoutCancel(ctx), chat, chat.UpdatedAt, "stale summary", logger)
	fetched, err = db.GetChatByID(ctx, chat.ID)
	require.NoError(t, err)
	require.Equal(t, sql.NullString{String: "fresh summary", Valid: true}, fetched.LastTurnSummary)
	// Compare instants with time.Time.Equal rather than require.Equal:
	// require.Equal relies on reflect.DeepEqual, which also compares the
	// Location and internal representation and can report a mismatch for
	// the same instant.
	require.Truef(t, advancedUpdatedAt.Equal(fetched.UpdatedAt),
		"stale write must not change updated_at: want %s, got %s",
		advancedUpdatedAt, fetched.UpdatedAt)
}
// TestPendingChatPersistsSummaryButSkipsWebPush finalizes a turn for a chat
// in the pending status and asserts three things: the expected
// last_turn_summary is stored, the status-label model's Generate is never
// invoked, and the web push dispatcher is never called.
func TestPendingChatPersistsSummaryButSkipsWebPush(t *testing.T) {
	t.Parallel()

	const wantSummary = "Still working on request"

	db, _ := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitMedium)

	user := dbgen.User(t, db, database.User{})
	organization := dbgen.Organization(t, db, database.Organization{})
	dbgen.OrganizationMember(t, db, database.OrganizationMember{
		OrganizationID: organization.ID,
		UserID:         user.ID,
	})
	chatProvider := dbgen.ChatProvider(t, db, database.ChatProvider{
		Enabled:     true,
		Provider:    "openai",
		DisplayName: "OpenAI",
		APIKey:      "test-key",
	})
	userRef := uuid.NullUUID{UUID: user.ID, Valid: true}
	cfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
		AIProviderID:         uuid.NullUUID{UUID: chatProvider.ID, Valid: true},
		Provider:             "openai",
		Model:                "test-model",
		DisplayName:          "Test Model",
		CreatedBy:            userRef,
		UpdatedBy:            userRef,
		Enabled:              true,
		IsDefault:            true,
		ContextLimit:         128000,
		CompressionThreshold: 80,
		Options:              json.RawMessage(`{}`),
	})
	require.NoError(t, err)
	pendingChat, err := db.InsertChat(ctx, database.InsertChatParams{
		OrganizationID:    organization.ID,
		OwnerID:           user.ID,
		Status:            database.ChatStatusPending,
		ClientType:        database.ChatClientTypeUi,
		LastModelConfigID: cfg.ID,
		Title:             "summary-pending-chat",
	})
	require.NoError(t, err)

	// Every Generate call is counted; the test expects none.
	var labelCalls atomic.Int32
	labelFn := func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
		labelCalls.Add(1)
		resp := &fantasy.Response{
			Content: fantasy.ResponseContent{
				fantasy.TextContent{Text: "Unexpected label"},
			},
		}
		return resp, nil
	}
	labelModel := &chattest.FakeModel{
		ProviderName: "openai",
		ModelName:    "test-model",
		GenerateFn:   labelFn,
	}

	pushes := &recordingWebpushDispatcher{}
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	srv := &Server{db: db, webpushDispatcher: pushes}

	result := runChatResult{
		FinalAssistantText: "I finished the queued turn.",
		StatusLabelModel:   labelModel,
		FallbackProvider:   labelModel.Provider(),
		FallbackModel:      labelModel.Model(),
	}
	// context.WithoutCancel keeps ctx's values but drops its cancellation
	// and deadline for this call.
	srv.maybeFinalizeTurnStatusLabelAndPush(
		context.WithoutCancel(ctx),
		pendingChat,
		database.ChatStatusPending,
		"",
		result,
		logger,
	)
	srv.drainInflight()

	got, err := db.GetChatByID(ctx, pendingChat.ID)
	require.NoError(t, err)
	require.Equal(t, sql.NullString{String: wantSummary, Valid: true}, got.LastTurnSummary)
	require.Equal(t, int32(0), labelCalls.Load())
	require.Equal(t, int32(0), pushes.dispatchCount.Load())
}
// recordingWebpushDispatcher is a test double assigned to Server's
// webpushDispatcher field. It counts Dispatch calls so a test can assert
// whether a push was attempted. Test and PublicKey are inert stubs.
type recordingWebpushDispatcher struct {
	// dispatchCount is incremented once per Dispatch call.
	dispatchCount atomic.Int32
}

// Dispatch records the call and returns nil without sending anything.
func (r *recordingWebpushDispatcher) Dispatch(context.Context, uuid.UUID, codersdk.WebpushMessage) error {
	r.dispatchCount.Add(1)
	return nil
}

// Test ignores its arguments and always returns nil.
func (*recordingWebpushDispatcher) Test(context.Context, codersdk.WebpushSubscription) error {
	return nil
}

// PublicKey returns a fixed placeholder string.
func (*recordingWebpushDispatcher) PublicKey() string {
	return "test-vapid-public-key"
}