mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
edee917d88
feat: add AI chat system with agent tools and chat UI Introduce the chatd subsystem and Agents UI for AI-powered chat within Coder workspaces. - Add chatd package with chat loop, message compaction, prompt management, and LLM provider integration (OpenAI, Anthropic) - Add agent tools: create workspace, list/read templates, read/write/ edit files, execute commands - Add chat API endpoints with streaming, message editing, and durable reconnection - Add database schema and migrations for chats, chat messages, chat providers, and chat model configs - Add RBAC policies and dbauthz enforcement for chat resources - Add Agents UI pages with conversation timeline, queued messages list, diff viewer, and model configuration panel - Add comprehensive test coverage including coderd integration tests, chatd unit tests, and Storybook stories - Gate feature behind experiments flag --------- Co-authored-by: Cian Johnston <cian@coder.com> Co-authored-by: Danielle Maywood <danielle@themaywoods.com> Co-authored-by: Jeremy Ruppel <jeremy@coder.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
569 lines
16 KiB
Go
569 lines
16 KiB
Go
package db2sdk_test
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
fantasyopenai "charm.land/fantasy/providers/openai"
|
|
"github.com/google/uuid"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/provisionersdk/proto"
|
|
)
|
|
|
|
func TestProvisionerJobStatus(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
name string
|
|
job database.ProvisionerJob
|
|
status codersdk.ProvisionerJobStatus
|
|
}{
|
|
{
|
|
name: "canceling",
|
|
job: database.ProvisionerJob{
|
|
CanceledAt: sql.NullTime{
|
|
Time: dbtime.Now().Add(-time.Minute),
|
|
Valid: true,
|
|
},
|
|
},
|
|
status: codersdk.ProvisionerJobCanceling,
|
|
},
|
|
{
|
|
name: "canceled",
|
|
job: database.ProvisionerJob{
|
|
CanceledAt: sql.NullTime{
|
|
Time: dbtime.Now().Add(-time.Minute),
|
|
Valid: true,
|
|
},
|
|
CompletedAt: sql.NullTime{
|
|
Time: dbtime.Now().Add(-30 * time.Second),
|
|
Valid: true,
|
|
},
|
|
},
|
|
status: codersdk.ProvisionerJobCanceled,
|
|
},
|
|
{
|
|
name: "canceled_failed",
|
|
job: database.ProvisionerJob{
|
|
CanceledAt: sql.NullTime{
|
|
Time: dbtime.Now().Add(-time.Minute),
|
|
Valid: true,
|
|
},
|
|
CompletedAt: sql.NullTime{
|
|
Time: dbtime.Now().Add(-30 * time.Second),
|
|
Valid: true,
|
|
},
|
|
Error: sql.NullString{String: "badness", Valid: true},
|
|
},
|
|
status: codersdk.ProvisionerJobFailed,
|
|
},
|
|
{
|
|
name: "pending",
|
|
job: database.ProvisionerJob{},
|
|
status: codersdk.ProvisionerJobPending,
|
|
},
|
|
{
|
|
name: "succeeded",
|
|
job: database.ProvisionerJob{
|
|
StartedAt: sql.NullTime{
|
|
Time: dbtime.Now().Add(-time.Minute),
|
|
Valid: true,
|
|
},
|
|
CompletedAt: sql.NullTime{
|
|
Time: dbtime.Now().Add(-30 * time.Second),
|
|
Valid: true,
|
|
},
|
|
},
|
|
status: codersdk.ProvisionerJobSucceeded,
|
|
},
|
|
{
|
|
name: "completed_failed",
|
|
job: database.ProvisionerJob{
|
|
StartedAt: sql.NullTime{
|
|
Time: dbtime.Now().Add(-time.Minute),
|
|
Valid: true,
|
|
},
|
|
CompletedAt: sql.NullTime{
|
|
Time: dbtime.Now().Add(-30 * time.Second),
|
|
Valid: true,
|
|
},
|
|
Error: sql.NullString{String: "badness", Valid: true},
|
|
},
|
|
status: codersdk.ProvisionerJobFailed,
|
|
},
|
|
{
|
|
name: "updated",
|
|
job: database.ProvisionerJob{
|
|
StartedAt: sql.NullTime{
|
|
Time: dbtime.Now().Add(-time.Minute),
|
|
Valid: true,
|
|
},
|
|
UpdatedAt: dbtime.Now(),
|
|
},
|
|
status: codersdk.ProvisionerJobRunning,
|
|
},
|
|
}
|
|
|
|
// Share db for all job inserts.
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
for i, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
// Populate standard fields
|
|
now := dbtime.Now().Round(time.Minute)
|
|
tc.job.ID = uuid.New()
|
|
tc.job.CreatedAt = now
|
|
tc.job.UpdatedAt = now
|
|
tc.job.InitiatorID = org.ID
|
|
tc.job.OrganizationID = org.ID
|
|
tc.job.Input = []byte("{}")
|
|
tc.job.Provisioner = database.ProvisionerTypeEcho
|
|
// Unique tags for each job.
|
|
tc.job.Tags = map[string]string{fmt.Sprintf("%d", i): "true"}
|
|
|
|
inserted := dbgen.ProvisionerJob(t, db, nil, tc.job)
|
|
// Make sure the inserted job has the right values.
|
|
require.Equal(t, tc.job.StartedAt.Time.UTC(), inserted.StartedAt.Time.UTC(), "started at")
|
|
require.Equal(t, tc.job.CompletedAt.Time.UTC(), inserted.CompletedAt.Time.UTC(), "completed at")
|
|
require.Equal(t, tc.job.CanceledAt.Time.UTC(), inserted.CanceledAt.Time.UTC(), "canceled at")
|
|
require.Equal(t, tc.job.Error, inserted.Error, "error")
|
|
require.Equal(t, tc.job.ErrorCode, inserted.ErrorCode, "error code")
|
|
|
|
actual := codersdk.ProvisionerJobStatus(inserted.JobStatus)
|
|
require.Equal(t, tc.status, actual)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTemplateVersionParameter_OK(t *testing.T) {
|
|
t.Parallel()
|
|
req := require.New(t)
|
|
|
|
// In this test we're just going to cover the fields that have to get parsed.
|
|
options := []*proto.RichParameterOption{
|
|
{
|
|
Name: "foo",
|
|
Description: "bar",
|
|
Value: "baz",
|
|
Icon: "David Bowie",
|
|
},
|
|
}
|
|
ob, err := json.Marshal(&options)
|
|
req.NoError(err)
|
|
|
|
db := database.TemplateVersionParameter{
|
|
Options: json.RawMessage(ob),
|
|
Description: "_The Rise and Fall of **Ziggy Stardust** and the Spiders from Mars_",
|
|
}
|
|
sdk, err := db2sdk.TemplateVersionParameter(db)
|
|
req.NoError(err)
|
|
req.Len(sdk.Options, 1)
|
|
req.Equal("foo", sdk.Options[0].Name)
|
|
req.Equal("bar", sdk.Options[0].Description)
|
|
req.Equal("baz", sdk.Options[0].Value)
|
|
req.Equal("David Bowie", sdk.Options[0].Icon)
|
|
req.Equal("The Rise and Fall of Ziggy Stardust and the Spiders from Mars", sdk.DescriptionPlaintext)
|
|
}
|
|
|
|
func TestTemplateVersionParameter_BadOptions(t *testing.T) {
|
|
t.Parallel()
|
|
req := require.New(t)
|
|
|
|
db := database.TemplateVersionParameter{
|
|
Options: json.RawMessage("not really JSON!"),
|
|
Description: "_The Rise and Fall of **Ziggy Stardust** and the Spiders from Mars_",
|
|
}
|
|
_, err := db2sdk.TemplateVersionParameter(db)
|
|
req.Error(err)
|
|
}
|
|
|
|
func TestTemplateVersionParameter_BadDescription(t *testing.T) {
|
|
t.Parallel()
|
|
req := require.New(t)
|
|
desc := make([]byte, 300)
|
|
_, err := rand.Read(desc)
|
|
req.NoError(err)
|
|
|
|
db := database.TemplateVersionParameter{
|
|
Options: json.RawMessage("[]"),
|
|
Description: string(desc),
|
|
}
|
|
sdk, err := db2sdk.TemplateVersionParameter(db)
|
|
// Although the markdown parser can return an error, the way we use it should not, even
|
|
// if we feed it garbage data.
|
|
req.NoError(err)
|
|
req.NotEmpty(sdk.DescriptionPlaintext, "broke the markdown parser with %v", desc)
|
|
}
|
|
|
|
func TestAIBridgeInterception(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
now := dbtime.Now()
|
|
interceptionID := uuid.New()
|
|
initiatorID := uuid.New()
|
|
|
|
cases := []struct {
|
|
name string
|
|
interception database.AIBridgeInterception
|
|
initiator database.VisibleUser
|
|
tokenUsages []database.AIBridgeTokenUsage
|
|
userPrompts []database.AIBridgeUserPrompt
|
|
toolUsages []database.AIBridgeToolUsage
|
|
expected codersdk.AIBridgeInterception
|
|
}{
|
|
{
|
|
name: "all_optional_values_set",
|
|
interception: database.AIBridgeInterception{
|
|
ID: interceptionID,
|
|
InitiatorID: initiatorID,
|
|
Provider: "anthropic",
|
|
Model: "claude-3-opus",
|
|
StartedAt: now,
|
|
Metadata: pqtype.NullRawMessage{
|
|
RawMessage: json.RawMessage(`{"key":"value"}`),
|
|
Valid: true,
|
|
},
|
|
EndedAt: sql.NullTime{
|
|
Time: now.Add(time.Minute),
|
|
Valid: true,
|
|
},
|
|
APIKeyID: sql.NullString{
|
|
String: "api-key-123",
|
|
Valid: true,
|
|
},
|
|
Client: sql.NullString{
|
|
String: "claude-code/1.0.0",
|
|
Valid: true,
|
|
},
|
|
},
|
|
initiator: database.VisibleUser{
|
|
ID: initiatorID,
|
|
Username: "testuser",
|
|
Name: "Test User",
|
|
AvatarURL: "https://example.com/avatar.png",
|
|
},
|
|
tokenUsages: []database.AIBridgeTokenUsage{
|
|
{
|
|
ID: uuid.New(),
|
|
InterceptionID: interceptionID,
|
|
ProviderResponseID: "resp-123",
|
|
InputTokens: 100,
|
|
OutputTokens: 200,
|
|
Metadata: pqtype.NullRawMessage{
|
|
RawMessage: json.RawMessage(`{"cache":"hit"}`),
|
|
Valid: true,
|
|
},
|
|
CreatedAt: now.Add(10 * time.Second),
|
|
},
|
|
},
|
|
userPrompts: []database.AIBridgeUserPrompt{
|
|
{
|
|
ID: uuid.New(),
|
|
InterceptionID: interceptionID,
|
|
ProviderResponseID: "resp-123",
|
|
Prompt: "Hello, world!",
|
|
Metadata: pqtype.NullRawMessage{
|
|
RawMessage: json.RawMessage(`{"role":"user"}`),
|
|
Valid: true,
|
|
},
|
|
CreatedAt: now.Add(5 * time.Second),
|
|
},
|
|
},
|
|
toolUsages: []database.AIBridgeToolUsage{
|
|
{
|
|
ID: uuid.New(),
|
|
InterceptionID: interceptionID,
|
|
ProviderResponseID: "resp-123",
|
|
ServerUrl: sql.NullString{
|
|
String: "https://mcp.example.com",
|
|
Valid: true,
|
|
},
|
|
Tool: "read_file",
|
|
Input: `{"path":"/tmp/test.txt"}`,
|
|
Injected: true,
|
|
InvocationError: sql.NullString{
|
|
String: "file not found",
|
|
Valid: true,
|
|
},
|
|
Metadata: pqtype.NullRawMessage{
|
|
RawMessage: json.RawMessage(`{"duration_ms":50}`),
|
|
Valid: true,
|
|
},
|
|
CreatedAt: now.Add(15 * time.Second),
|
|
},
|
|
},
|
|
expected: codersdk.AIBridgeInterception{
|
|
ID: interceptionID,
|
|
Initiator: codersdk.MinimalUser{
|
|
ID: initiatorID,
|
|
Username: "testuser",
|
|
Name: "Test User",
|
|
AvatarURL: "https://example.com/avatar.png",
|
|
},
|
|
Provider: "anthropic",
|
|
Model: "claude-3-opus",
|
|
Metadata: map[string]any{"key": "value"},
|
|
StartedAt: now,
|
|
},
|
|
},
|
|
{
|
|
name: "no_optional_values_set",
|
|
interception: database.AIBridgeInterception{
|
|
ID: interceptionID,
|
|
InitiatorID: initiatorID,
|
|
Provider: "openai",
|
|
Model: "gpt-4",
|
|
StartedAt: now,
|
|
Metadata: pqtype.NullRawMessage{Valid: false},
|
|
EndedAt: sql.NullTime{Valid: false},
|
|
APIKeyID: sql.NullString{Valid: false},
|
|
Client: sql.NullString{Valid: false},
|
|
},
|
|
initiator: database.VisibleUser{
|
|
ID: initiatorID,
|
|
Username: "minimaluser",
|
|
Name: "",
|
|
AvatarURL: "",
|
|
},
|
|
tokenUsages: nil,
|
|
userPrompts: nil,
|
|
toolUsages: nil,
|
|
expected: codersdk.AIBridgeInterception{
|
|
ID: interceptionID,
|
|
Initiator: codersdk.MinimalUser{
|
|
ID: initiatorID,
|
|
Username: "minimaluser",
|
|
Name: "",
|
|
AvatarURL: "",
|
|
},
|
|
Provider: "openai",
|
|
Model: "gpt-4",
|
|
Metadata: nil,
|
|
StartedAt: now,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result := db2sdk.AIBridgeInterception(
|
|
tc.interception,
|
|
tc.initiator,
|
|
tc.tokenUsages,
|
|
tc.userPrompts,
|
|
tc.toolUsages,
|
|
)
|
|
|
|
// Check basic fields.
|
|
require.Equal(t, tc.expected.ID, result.ID)
|
|
require.Equal(t, tc.expected.Initiator, result.Initiator)
|
|
require.Equal(t, tc.expected.Provider, result.Provider)
|
|
require.Equal(t, tc.expected.Model, result.Model)
|
|
require.Equal(t, tc.expected.StartedAt.UTC(), result.StartedAt.UTC())
|
|
require.Equal(t, tc.expected.Metadata, result.Metadata)
|
|
|
|
// Check optional pointer fields.
|
|
if tc.interception.APIKeyID.Valid {
|
|
require.NotNil(t, result.APIKeyID)
|
|
require.Equal(t, tc.interception.APIKeyID.String, *result.APIKeyID)
|
|
} else {
|
|
require.Nil(t, result.APIKeyID)
|
|
}
|
|
|
|
if tc.interception.EndedAt.Valid {
|
|
require.NotNil(t, result.EndedAt)
|
|
require.Equal(t, tc.interception.EndedAt.Time.UTC(), result.EndedAt.UTC())
|
|
} else {
|
|
require.Nil(t, result.EndedAt)
|
|
}
|
|
|
|
if tc.interception.Client.Valid {
|
|
require.NotNil(t, result.Client)
|
|
require.Equal(t, tc.interception.Client.String, *result.Client)
|
|
} else {
|
|
require.Nil(t, result.Client)
|
|
}
|
|
|
|
// Check slices.
|
|
require.Len(t, result.TokenUsages, len(tc.tokenUsages))
|
|
require.Len(t, result.UserPrompts, len(tc.userPrompts))
|
|
require.Len(t, result.ToolUsages, len(tc.toolUsages))
|
|
|
|
// Verify token usages are converted correctly.
|
|
for i, tu := range tc.tokenUsages {
|
|
require.Equal(t, tu.ID, result.TokenUsages[i].ID)
|
|
require.Equal(t, tu.InterceptionID, result.TokenUsages[i].InterceptionID)
|
|
require.Equal(t, tu.ProviderResponseID, result.TokenUsages[i].ProviderResponseID)
|
|
require.Equal(t, tu.InputTokens, result.TokenUsages[i].InputTokens)
|
|
require.Equal(t, tu.OutputTokens, result.TokenUsages[i].OutputTokens)
|
|
}
|
|
|
|
// Verify user prompts are converted correctly.
|
|
for i, up := range tc.userPrompts {
|
|
require.Equal(t, up.ID, result.UserPrompts[i].ID)
|
|
require.Equal(t, up.InterceptionID, result.UserPrompts[i].InterceptionID)
|
|
require.Equal(t, up.ProviderResponseID, result.UserPrompts[i].ProviderResponseID)
|
|
require.Equal(t, up.Prompt, result.UserPrompts[i].Prompt)
|
|
}
|
|
|
|
// Verify tool usages are converted correctly.
|
|
for i, toolUsage := range tc.toolUsages {
|
|
require.Equal(t, toolUsage.ID, result.ToolUsages[i].ID)
|
|
require.Equal(t, toolUsage.InterceptionID, result.ToolUsages[i].InterceptionID)
|
|
require.Equal(t, toolUsage.ProviderResponseID, result.ToolUsages[i].ProviderResponseID)
|
|
require.Equal(t, toolUsage.ServerUrl.String, result.ToolUsages[i].ServerURL)
|
|
require.Equal(t, toolUsage.Tool, result.ToolUsages[i].Tool)
|
|
require.Equal(t, toolUsage.Input, result.ToolUsages[i].Input)
|
|
require.Equal(t, toolUsage.Injected, result.ToolUsages[i].Injected)
|
|
require.Equal(t, toolUsage.InvocationError.String, result.ToolUsages[i].InvocationError)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestChatMessage_ReasoningPartWithoutPersistedTitleIsEmpty(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assistantContent, err := json.Marshal([]fantasy.Content{
|
|
fantasy.ReasoningContent{
|
|
Text: "Plan migration",
|
|
ProviderMetadata: fantasy.ProviderMetadata{
|
|
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
|
|
ItemID: "reasoning-1",
|
|
Summary: []string{"Plan migration"},
|
|
},
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
message := db2sdk.ChatMessage(database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
Role: string(fantasy.MessageRoleAssistant),
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: assistantContent,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
require.Len(t, message.Content, 1)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
|
|
require.Equal(t, "Plan migration", message.Content[0].Text)
|
|
require.Empty(t, message.Content[0].Title)
|
|
}
|
|
|
|
func TestChatMessage_ReasoningPartPrefersPersistedTitle(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
reasoningContent, err := json.Marshal(fantasy.ReasoningContent{
|
|
Text: "Verify schema updates, then apply changes in order.",
|
|
ProviderMetadata: fantasy.ProviderMetadata{
|
|
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
|
|
ItemID: "reasoning-1",
|
|
Summary: []string{
|
|
"**Metadata-derived title**\n\nLonger explanation.",
|
|
},
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var envelope map[string]any
|
|
require.NoError(t, json.Unmarshal(reasoningContent, &envelope))
|
|
dataValue, ok := envelope["data"].(map[string]any)
|
|
require.True(t, ok)
|
|
dataValue["title"] = "Persisted stream title"
|
|
|
|
encodedReasoning, err := json.Marshal(envelope)
|
|
require.NoError(t, err)
|
|
assistantContent, err := json.Marshal([]json.RawMessage{encodedReasoning})
|
|
require.NoError(t, err)
|
|
|
|
message := db2sdk.ChatMessage(database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
Role: string(fantasy.MessageRoleAssistant),
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: assistantContent,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
require.Len(t, message.Content, 1)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
|
|
require.Equal(t, "Persisted stream title", message.Content[0].Title)
|
|
}
|
|
|
|
func TestChatQueuedMessage_ParsesUserContentParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
rawContent, err := json.Marshal([]fantasy.Content{
|
|
fantasy.TextContent{Text: "queued text"},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
|
|
ID: 1,
|
|
ChatID: uuid.New(),
|
|
Content: rawContent,
|
|
CreatedAt: time.Now(),
|
|
})
|
|
|
|
require.Len(t, queued.Content, 1)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeText, queued.Content[0].Type)
|
|
require.Equal(t, "queued text", queued.Content[0].Text)
|
|
}
|
|
|
|
func TestChatQueuedMessage_FallsBackToTextForLegacyContent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("legacy_string", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
|
|
ID: 1,
|
|
ChatID: uuid.New(),
|
|
Content: json.RawMessage(`"legacy queued text"`),
|
|
CreatedAt: time.Now(),
|
|
})
|
|
|
|
require.Len(t, queued.Content, 1)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeText, queued.Content[0].Type)
|
|
require.Equal(t, "legacy queued text", queued.Content[0].Text)
|
|
})
|
|
|
|
t.Run("malformed_payload", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
raw := json.RawMessage(`{"unexpected":"shape"}`)
|
|
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
|
|
ID: 1,
|
|
ChatID: uuid.New(),
|
|
Content: raw,
|
|
CreatedAt: time.Now(),
|
|
})
|
|
|
|
require.Empty(t, queued.Content)
|
|
})
|
|
}
|