mirror of
https://github.com/coder/coder.git
synced 2026-06-03 21:18:24 +00:00
2f855904be
- Adds chat-related dbgen generators covering defaults, overrides, and message field mapping. - Replaces raw single-row chat, message, provider, and model-config setup in tests with dbgen helpers. - Simplifies chat seed helpers after moving fixture setup into dbgen. > Generated with [Coder Agents](https://coder.com/agents).
452 lines
16 KiB
Go
452 lines
16 KiB
Go
package dbgen_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"testing"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
|
)
|
|
|
|
func TestGenerator(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("AuditLog", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
_ = dbgen.AuditLog(t, db, database.AuditLog{})
|
|
logs := must(db.GetAuditLogsOffset(context.Background(), database.GetAuditLogsOffsetParams{LimitOpt: 1}))
|
|
require.Len(t, logs, 1)
|
|
})
|
|
|
|
t.Run("APIKey", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp, _ := dbgen.APIKey(t, db, database.APIKey{})
|
|
require.Equal(t, exp, must(db.GetAPIKeyByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("File", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
exp := dbgen.File(t, db, database.File{})
|
|
require.Equal(t, exp, must(db.GetFileByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("UserLink", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
u := dbgen.User(t, db, database.User{})
|
|
exp := dbgen.UserLink(t, db, database.UserLink{UserID: u.ID})
|
|
require.Equal(t, exp, must(db.GetUserLinkByLinkedID(context.Background(), exp.LinkedID)))
|
|
})
|
|
|
|
t.Run("GitAuthLink", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
exp := dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{})
|
|
require.Equal(t, exp, must(db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
|
|
ProviderID: exp.ProviderID,
|
|
UserID: exp.UserID,
|
|
})))
|
|
})
|
|
|
|
t.Run("WorkspaceResource", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{})
|
|
require.Equal(t, exp, must(db.GetWorkspaceResourceByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("WorkspaceApp", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{})
|
|
require.Equal(t, exp, must(db.GetWorkspaceAppsByAgentID(context.Background(), exp.AgentID))[0])
|
|
})
|
|
|
|
t.Run("WorkspaceResourceMetadata", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.WorkspaceResourceMetadatums(t, db, database.WorkspaceResourceMetadatum{})
|
|
require.Equal(t, exp, must(db.GetWorkspaceResourceMetadataByResourceIDs(context.Background(), []uuid.UUID{exp[0].WorkspaceResourceID})))
|
|
})
|
|
|
|
t.Run("WorkspaceProxy", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
exp, secret := dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{})
|
|
require.Len(t, secret, 64)
|
|
require.Equal(t, exp, must(db.GetWorkspaceProxyByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("Job", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
exp := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{})
|
|
require.Equal(t, exp, must(db.GetProvisionerJobByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("Group", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.Group(t, db, database.Group{})
|
|
require.Equal(t, exp, must(db.GetGroupByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("GroupMember", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
g := dbgen.Group(t, db, database.Group{})
|
|
u := dbgen.User(t, db, database.User{})
|
|
gm := dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: g.ID, UserID: u.ID})
|
|
exp := []database.GroupMember{gm}
|
|
|
|
require.Equal(t, exp, must(db.GetGroupMembersByGroupID(context.Background(), database.GetGroupMembersByGroupIDParams{
|
|
GroupID: g.ID,
|
|
IncludeSystem: false,
|
|
})))
|
|
})
|
|
|
|
t.Run("Organization", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
exp := dbgen.Organization(t, db, database.Organization{})
|
|
require.Equal(t, exp, must(db.GetOrganizationByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("OrganizationMember", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
o := dbgen.Organization(t, db, database.Organization{})
|
|
u := dbgen.User(t, db, database.User{})
|
|
exp := dbgen.OrganizationMember(t, db, database.OrganizationMember{OrganizationID: o.ID, UserID: u.ID})
|
|
require.Equal(t, exp, must(database.ExpectOne(db.OrganizationMembers(context.Background(), database.OrganizationMembersParams{
|
|
OrganizationID: exp.OrganizationID,
|
|
UserID: exp.UserID,
|
|
}))).OrganizationMember)
|
|
})
|
|
|
|
t.Run("Workspace", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
u := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: u.ID,
|
|
})
|
|
exp := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: u.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
w := must(db.GetWorkspaceByID(context.Background(), exp.ID))
|
|
table := database.WorkspaceTable{
|
|
ID: w.ID,
|
|
CreatedAt: w.CreatedAt,
|
|
UpdatedAt: w.UpdatedAt,
|
|
OwnerID: w.OwnerID,
|
|
OrganizationID: w.OrganizationID,
|
|
TemplateID: w.TemplateID,
|
|
Deleted: w.Deleted,
|
|
Name: w.Name,
|
|
AutostartSchedule: w.AutostartSchedule,
|
|
Ttl: w.Ttl,
|
|
LastUsedAt: w.LastUsedAt,
|
|
DormantAt: w.DormantAt,
|
|
DeletingAt: w.DeletingAt,
|
|
AutomaticUpdates: w.AutomaticUpdates,
|
|
Favorite: w.Favorite,
|
|
GroupACL: database.WorkspaceACL{},
|
|
UserACL: database.WorkspaceACL{},
|
|
}
|
|
require.Equal(t, exp, table)
|
|
})
|
|
|
|
t.Run("WorkspaceAgent", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{})
|
|
require.Equal(t, exp, must(db.GetWorkspaceAgentByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("Template", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.Template(t, db, database.Template{})
|
|
require.Equal(t, exp, must(db.GetTemplateByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("TemplateVersion", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
|
|
require.Equal(t, exp, must(db.GetTemplateVersionByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("WorkspaceBuild", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{})
|
|
require.Equal(t, exp, must(db.GetWorkspaceBuildByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("User", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
exp := dbgen.User(t, db, database.User{})
|
|
require.Equal(t, exp, must(db.GetUserByID(context.Background(), exp.ID)))
|
|
})
|
|
|
|
t.Run("ServiceAccountUser", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
user := dbgen.User(t, db, database.User{
|
|
IsServiceAccount: true,
|
|
Email: "should-be-overridden@coder.com",
|
|
LoginType: database.LoginTypePassword,
|
|
})
|
|
require.True(t, user.IsServiceAccount)
|
|
require.Empty(t, user.Email)
|
|
require.Equal(t, database.LoginTypeNone, user.LoginType)
|
|
require.Equal(t, user, must(db.GetUserByID(context.Background(), user.ID)))
|
|
})
|
|
|
|
t.Run("SSHKey", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.GitSSHKey(t, db, database.GitSSHKey{})
|
|
require.Equal(t, exp, must(db.GetGitSSHKey(context.Background(), exp.UserID)))
|
|
})
|
|
|
|
t.Run("WorkspaceBuildParameters", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.WorkspaceBuildParameters(t, db, []database.WorkspaceBuildParameter{{Name: "name1", Value: "value1"}, {Name: "name2", Value: "value2"}, {Name: "name3", Value: "value3"}})
|
|
require.Equal(t, exp, must(db.GetWorkspaceBuildParameters(context.Background(), exp[0].WorkspaceBuildID)))
|
|
})
|
|
|
|
t.Run("TemplateVersionParameter", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
dbtestutil.DisableForeignKeysAndTriggers(t, db)
|
|
exp := dbgen.TemplateVersionParameter(t, db, database.TemplateVersionParameter{})
|
|
actual := must(db.GetTemplateVersionParameters(context.Background(), exp.TemplateVersionID))
|
|
require.Len(t, actual, 1)
|
|
require.Equal(t, exp, actual[0])
|
|
})
|
|
|
|
t.Run("ChatProvider", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
// Defaults.
|
|
p := dbgen.ChatProvider(t, db, database.ChatProvider{})
|
|
require.NotEqual(t, uuid.Nil, p.ID)
|
|
require.Equal(t, "openai", p.Provider)
|
|
require.Equal(t, "openai", p.DisplayName)
|
|
require.True(t, p.Enabled)
|
|
require.True(t, p.CentralApiKeyEnabled)
|
|
require.Equal(t, "test-key", p.APIKey)
|
|
|
|
// Overrides.
|
|
p2 := dbgen.ChatProvider(t, db, database.ChatProvider{
|
|
Provider: "anthropic",
|
|
DisplayName: "Claude",
|
|
APIKey: "sk-custom",
|
|
})
|
|
require.Equal(t, "anthropic", p2.Provider)
|
|
require.Equal(t, "Claude", p2.DisplayName)
|
|
require.Equal(t, "sk-custom", p2.APIKey)
|
|
|
|
p3 := dbgen.ChatProvider(t, db, database.ChatProvider{
|
|
Provider: "openrouter",
|
|
}, func(params *database.InsertChatProviderParams) {
|
|
params.APIKey = ""
|
|
})
|
|
require.Empty(t, p3.APIKey)
|
|
})
|
|
|
|
t.Run("ChatModelConfig", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
_ = dbgen.ChatProvider(t, db, database.ChatProvider{})
|
|
|
|
// Defaults.
|
|
cfg := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
|
|
require.NotEqual(t, uuid.Nil, cfg.ID)
|
|
require.Equal(t, "openai", cfg.Provider)
|
|
require.Equal(t, "gpt-4o-mini", cfg.Model)
|
|
require.Equal(t, "Test Model", cfg.DisplayName)
|
|
require.True(t, cfg.Enabled)
|
|
require.Equal(t, int64(128000), cfg.ContextLimit)
|
|
require.Equal(t, int32(70), cfg.CompressionThreshold)
|
|
|
|
// Overrides.
|
|
_ = dbgen.ChatProvider(t, db, database.ChatProvider{Provider: "anthropic"})
|
|
cfg2 := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
|
Provider: "anthropic",
|
|
Model: "claude-4",
|
|
ContextLimit: 200000,
|
|
})
|
|
require.Equal(t, "anthropic", cfg2.Provider)
|
|
require.Equal(t, "claude-4", cfg2.Model)
|
|
require.Equal(t, int64(200000), cfg2.ContextLimit)
|
|
})
|
|
|
|
t.Run("Chat", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
u := dbgen.User(t, db, database.User{})
|
|
o := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: u.ID,
|
|
OrganizationID: o.ID,
|
|
})
|
|
p := dbgen.ChatProvider(t, db, database.ChatProvider{})
|
|
m := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{Provider: p.Provider})
|
|
|
|
// Defaults.
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OwnerID: u.ID,
|
|
OrganizationID: o.ID,
|
|
LastModelConfigID: m.ID,
|
|
})
|
|
require.NotEqual(t, uuid.Nil, chat.ID)
|
|
require.Equal(t, database.ChatStatusWaiting, chat.Status)
|
|
require.Equal(t, database.ChatClientTypeUi, chat.ClientType)
|
|
require.NotEmpty(t, chat.Title)
|
|
|
|
// Overrides.
|
|
chat2 := dbgen.Chat(t, db, database.Chat{
|
|
OwnerID: u.ID,
|
|
OrganizationID: o.ID,
|
|
LastModelConfigID: m.ID,
|
|
Title: "custom-title",
|
|
Status: database.ChatStatusRunning,
|
|
})
|
|
require.Equal(t, "custom-title", chat2.Title)
|
|
require.Equal(t, database.ChatStatusRunning, chat2.Status)
|
|
})
|
|
|
|
t.Run("ChatMessage", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
u := dbgen.User(t, db, database.User{})
|
|
o := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: u.ID,
|
|
OrganizationID: o.ID,
|
|
})
|
|
p := dbgen.ChatProvider(t, db, database.ChatProvider{})
|
|
m := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{Provider: p.Provider})
|
|
chat := dbgen.Chat(t, db, database.Chat{
|
|
OwnerID: u.ID,
|
|
OrganizationID: o.ID,
|
|
LastModelConfigID: m.ID,
|
|
})
|
|
|
|
// Defaults.
|
|
msg := dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: chat.ID,
|
|
})
|
|
require.NotZero(t, msg.ID)
|
|
require.Equal(t, database.ChatMessageRoleUser, msg.Role)
|
|
require.Equal(t, database.ChatMessageVisibilityBoth, msg.Visibility)
|
|
require.Equal(t, chatprompt.CurrentContentVersion, msg.ContentVersion)
|
|
|
|
// Overrides.
|
|
rawContent := json.RawMessage(`[{"type":"text","text":"hello"}]`)
|
|
msg2 := dbgen.ChatMessage(t, db, database.ChatMessage{
|
|
ChatID: chat.ID,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: rawContent,
|
|
Valid: true,
|
|
},
|
|
InputTokens: sql.NullInt64{Int64: 11, Valid: true},
|
|
OutputTokens: sql.NullInt64{Int64: 22, Valid: true},
|
|
TotalTokens: sql.NullInt64{Int64: 33, Valid: true},
|
|
ReasoningTokens: sql.NullInt64{Int64: 44, Valid: true},
|
|
CacheCreationTokens: sql.NullInt64{Int64: 55, Valid: true},
|
|
CacheReadTokens: sql.NullInt64{Int64: 66, Valid: true},
|
|
ContextLimit: sql.NullInt64{Int64: 77, Valid: true},
|
|
Compressed: true,
|
|
TotalCostMicros: sql.NullInt64{Int64: 88, Valid: true},
|
|
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
|
|
})
|
|
require.Equal(t, database.ChatMessageRoleAssistant, msg2.Role)
|
|
require.True(t, msg2.Content.Valid)
|
|
require.JSONEq(t, string(rawContent), string(msg2.Content.RawMessage))
|
|
require.Equal(t, sql.NullInt64{Int64: 11, Valid: true}, msg2.InputTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 22, Valid: true}, msg2.OutputTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 33, Valid: true}, msg2.TotalTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 44, Valid: true}, msg2.ReasoningTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 55, Valid: true}, msg2.CacheCreationTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 66, Valid: true}, msg2.CacheReadTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 77, Valid: true}, msg2.ContextLimit)
|
|
require.True(t, msg2.Compressed)
|
|
require.Equal(t, sql.NullInt64{Int64: 88, Valid: true}, msg2.TotalCostMicros)
|
|
require.Equal(t, sql.NullString{String: "resp-123", Valid: true}, msg2.ProviderResponseID)
|
|
})
|
|
|
|
t.Run("MCPServerConfig", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
// Defaults.
|
|
cfg := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{})
|
|
require.NotEqual(t, uuid.Nil, cfg.ID)
|
|
require.Equal(t, "streamable_http", cfg.Transport)
|
|
require.Equal(t, "none", cfg.AuthType)
|
|
require.Equal(t, "default_off", cfg.Availability)
|
|
require.True(t, cfg.Enabled)
|
|
require.Empty(t, cfg.ToolAllowList)
|
|
require.Empty(t, cfg.ToolDenyList)
|
|
require.NotEmpty(t, cfg.Slug)
|
|
require.NotEmpty(t, cfg.Url)
|
|
|
|
// Overrides.
|
|
cfg2 := dbgen.MCPServerConfig(t, db, database.MCPServerConfig{
|
|
DisplayName: "Custom MCP",
|
|
Slug: "custom-mcp",
|
|
Url: "https://custom.example.com",
|
|
AuthType: "oauth2",
|
|
AllowInPlanMode: true,
|
|
})
|
|
require.Equal(t, "Custom MCP", cfg2.DisplayName)
|
|
require.Equal(t, "custom-mcp", cfg2.Slug)
|
|
require.Equal(t, "https://custom.example.com", cfg2.Url)
|
|
require.Equal(t, "oauth2", cfg2.AuthType)
|
|
require.True(t, cfg2.AllowInPlanMode)
|
|
})
|
|
}
|
|
|
|
// must unwraps a (value, error) pair: it hands back value when err is nil and
// panics with err otherwise. It lets a two-result call be used inline as a
// single expression.
func must[T any](value T, err error) T {
	if err == nil {
		return value
	}
	panic(err)
}
|