mirror of
https://github.com/coder/coder.git
synced 2026-06-03 04:58:23 +00:00
13768 lines
461 KiB
Go
13768 lines
461 KiB
Go
package database_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"slices"
|
|
"sort"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/lib/pq"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/database/dbfake"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
|
"github.com/coder/coder/v2/coderd/database/migrations"
|
|
"github.com/coder/coder/v2/coderd/httpmw"
|
|
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
|
"github.com/coder/coder/v2/coderd/rbac"
|
|
"github.com/coder/coder/v2/coderd/rbac/policy"
|
|
"github.com/coder/coder/v2/coderd/util/slice"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/provisionersdk"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestGetDeploymentWorkspaceAgentStats(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
t.Run("Aggregates", func(t *testing.T) {
|
|
t.Parallel()
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 2,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
stats, err := db.GetDeploymentWorkspaceAgentStats(ctx, dbtime.Now().Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, int64(2), stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(2), stats.WorkspaceRxBytes)
|
|
require.Equal(t, 1.5, stats.WorkspaceConnectionLatency50)
|
|
require.Equal(t, 1.95, stats.WorkspaceConnectionLatency95)
|
|
require.Equal(t, int64(2), stats.SessionCountVSCode)
|
|
})
|
|
|
|
t.Run("GroupsByAgentID", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
agentID := uuid.New()
|
|
insertTime := dbtime.Now()
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Second),
|
|
AgentID: agentID,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
// Ensure this stat is newer!
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 2,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
stats, err := db.GetDeploymentWorkspaceAgentStats(ctx, dbtime.Now().Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, int64(2), stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(2), stats.WorkspaceRxBytes)
|
|
require.Equal(t, 1.5, stats.WorkspaceConnectionLatency50)
|
|
require.Equal(t, 1.95, stats.WorkspaceConnectionLatency95)
|
|
require.Equal(t, int64(1), stats.SessionCountVSCode)
|
|
})
|
|
}
|
|
|
|
func TestGetDeploymentWorkspaceAgentUsageStats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Aggregates", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
db = dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
ctx := context.Background()
|
|
agentID := uuid.New()
|
|
// Since the queries exclude the current minute
|
|
insertTime := dbtime.Now().Add(-time.Minute)
|
|
|
|
// Old stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agentID,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountSSH: 4,
|
|
SessionCountVSCode: 3,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agentID,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agentID,
|
|
SessionCountReconnectingPTY: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
// Latest stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 2,
|
|
// Should be ignored
|
|
SessionCountSSH: 3,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
SessionCountSSH: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
stats, err := db.GetDeploymentWorkspaceAgentUsageStats(ctx, dbtime.Now().Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, int64(2), stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(2), stats.WorkspaceRxBytes)
|
|
require.Equal(t, 1.5, stats.WorkspaceConnectionLatency50)
|
|
require.Equal(t, 1.95, stats.WorkspaceConnectionLatency95)
|
|
require.Equal(t, int64(1), stats.SessionCountVSCode)
|
|
require.Equal(t, int64(1), stats.SessionCountSSH)
|
|
require.Equal(t, int64(0), stats.SessionCountReconnectingPTY)
|
|
require.Equal(t, int64(0), stats.SessionCountJetBrains)
|
|
})
|
|
|
|
t.Run("NoUsage", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
db = dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
ctx := context.Background()
|
|
agentID := uuid.New()
|
|
// Since the queries exclude the current minute
|
|
insertTime := dbtime.Now().Add(-time.Minute)
|
|
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
TxBytes: 3,
|
|
RxBytes: 4,
|
|
ConnectionMedianLatencyMS: 2,
|
|
// Should be ignored
|
|
SessionCountSSH: 3,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
|
|
stats, err := db.GetDeploymentWorkspaceAgentUsageStats(ctx, dbtime.Now().Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, int64(3), stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(4), stats.WorkspaceRxBytes)
|
|
require.Equal(t, int64(0), stats.SessionCountVSCode)
|
|
require.Equal(t, int64(0), stats.SessionCountSSH)
|
|
require.Equal(t, int64(0), stats.SessionCountReconnectingPTY)
|
|
require.Equal(t, int64(0), stats.SessionCountJetBrains)
|
|
})
|
|
}
|
|
|
|
func TestGetEligibleProvisionerDaemonsByProvisionerJobIDs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("NoJobsReturnsEmpty", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
daemons, err := db.GetEligibleProvisionerDaemonsByProvisionerJobIDs(context.Background(), []uuid.UUID{})
|
|
require.NoError(t, err)
|
|
require.Empty(t, daemons)
|
|
})
|
|
|
|
t.Run("MatchesProvisionerType", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
matchingDaemon := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "matching-daemon",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "non-matching-daemon",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeTerraform},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetEligibleProvisionerDaemonsByProvisionerJobIDs(context.Background(), []uuid.UUID{job.ID})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 1)
|
|
require.Equal(t, matchingDaemon.ID, daemons[0].ProvisionerDaemon.ID)
|
|
})
|
|
|
|
t.Run("MatchesOrganizationScope", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
provisionersdk.TagOwner: "",
|
|
},
|
|
})
|
|
|
|
orgDaemon := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "org-daemon",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
provisionersdk.TagOwner: "",
|
|
},
|
|
})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "user-daemon",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeUser,
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetEligibleProvisionerDaemonsByProvisionerJobIDs(context.Background(), []uuid.UUID{job.ID})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 1)
|
|
require.Equal(t, orgDaemon.ID, daemons[0].ProvisionerDaemon.ID)
|
|
})
|
|
|
|
t.Run("MatchesMultipleProvisioners", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
daemon1 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "daemon-1",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
daemon2 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "daemon-2",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "daemon-3",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeTerraform},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetEligibleProvisionerDaemonsByProvisionerJobIDs(context.Background(), []uuid.UUID{job.ID})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 2)
|
|
|
|
daemonIDs := []uuid.UUID{daemons[0].ProvisionerDaemon.ID, daemons[1].ProvisionerDaemon.ID}
|
|
require.ElementsMatch(t, []uuid.UUID{daemon1.ID, daemon2.ID}, daemonIDs)
|
|
})
|
|
}
|
|
|
|
func TestGetProvisionerDaemonsWithStatusByOrganization(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("NoDaemonsInOrgReturnsEmpty", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
otherOrg := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "non-matching-daemon",
|
|
OrganizationID: otherOrg.ID,
|
|
})
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, daemons)
|
|
})
|
|
|
|
t.Run("MatchesProvisionerIDs", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
matchingDaemon0 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "matching-daemon0",
|
|
OrganizationID: org.ID,
|
|
})
|
|
matchingDaemon1 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "matching-daemon1",
|
|
OrganizationID: org.ID,
|
|
})
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "non-matching-daemon",
|
|
OrganizationID: org.ID,
|
|
})
|
|
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
IDs: []uuid.UUID{matchingDaemon0.ID, matchingDaemon1.ID},
|
|
Offline: sql.NullBool{Bool: true, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 2)
|
|
if daemons[0].ProvisionerDaemon.ID != matchingDaemon0.ID {
|
|
daemons[0], daemons[1] = daemons[1], daemons[0]
|
|
}
|
|
require.Equal(t, matchingDaemon0.ID, daemons[0].ProvisionerDaemon.ID)
|
|
require.Equal(t, matchingDaemon1.ID, daemons[1].ProvisionerDaemon.ID)
|
|
})
|
|
|
|
t.Run("MatchesTags", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
fooDaemon := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "foo-daemon",
|
|
OrganizationID: org.ID,
|
|
Tags: database.StringMap{
|
|
"foo": "bar",
|
|
},
|
|
})
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "baz-daemon",
|
|
OrganizationID: org.ID,
|
|
Tags: database.StringMap{
|
|
"baz": "qux",
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
Tags: database.StringMap{"foo": "bar"},
|
|
Offline: sql.NullBool{Bool: true, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 1)
|
|
require.Equal(t, fooDaemon.ID, daemons[0].ProvisionerDaemon.ID)
|
|
})
|
|
|
|
t.Run("UsesStaleInterval", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
daemon1 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "stale-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-time.Hour),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-time.Hour),
|
|
},
|
|
})
|
|
daemon2 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "idle-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(30 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(30 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
StaleIntervalMS: 45 * time.Minute.Milliseconds(),
|
|
Offline: sql.NullBool{Bool: true, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 2)
|
|
|
|
if daemons[0].ProvisionerDaemon.ID != daemon1.ID {
|
|
daemons[0], daemons[1] = daemons[1], daemons[0]
|
|
}
|
|
require.Equal(t, daemon1.ID, daemons[0].ProvisionerDaemon.ID)
|
|
require.Equal(t, daemon2.ID, daemons[1].ProvisionerDaemon.ID)
|
|
require.Equal(t, database.ProvisionerDaemonStatusOffline, daemons[0].Status)
|
|
require.Equal(t, database.ProvisionerDaemonStatusIdle, daemons[1].Status)
|
|
})
|
|
|
|
t.Run("ExcludeOffline", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "offline-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-time.Hour),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-time.Hour),
|
|
},
|
|
})
|
|
fooDaemon := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "foo-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(30 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(30 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
StaleIntervalMS: 45 * time.Minute.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 1)
|
|
|
|
require.Equal(t, fooDaemon.ID, daemons[0].ProvisionerDaemon.ID)
|
|
require.Equal(t, database.ProvisionerDaemonStatusIdle, daemons[0].Status)
|
|
})
|
|
|
|
t.Run("IncludeOffline", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "offline-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-time.Hour),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-time.Hour),
|
|
},
|
|
})
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "foo-daemon",
|
|
OrganizationID: org.ID,
|
|
Tags: database.StringMap{
|
|
"foo": "bar",
|
|
},
|
|
})
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "bar-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(30 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(30 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
StaleIntervalMS: 45 * time.Minute.Milliseconds(),
|
|
Offline: sql.NullBool{Bool: true, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 3)
|
|
|
|
statusCounts := make(map[database.ProvisionerDaemonStatus]int)
|
|
for _, daemon := range daemons {
|
|
statusCounts[daemon.Status]++
|
|
}
|
|
|
|
require.Equal(t, 2, statusCounts[database.ProvisionerDaemonStatusIdle])
|
|
require.Equal(t, 1, statusCounts[database.ProvisionerDaemonStatusOffline])
|
|
})
|
|
|
|
t.Run("MatchesStatuses", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "offline-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-time.Hour),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-time.Hour),
|
|
},
|
|
})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "foo-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(30 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(30 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
type testCase struct {
|
|
name string
|
|
statuses []database.ProvisionerDaemonStatus
|
|
expectedNum int
|
|
}
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "Get idle and offline",
|
|
statuses: []database.ProvisionerDaemonStatus{
|
|
database.ProvisionerDaemonStatusOffline,
|
|
database.ProvisionerDaemonStatusIdle,
|
|
},
|
|
expectedNum: 2,
|
|
},
|
|
{
|
|
name: "Get offline",
|
|
statuses: []database.ProvisionerDaemonStatus{
|
|
database.ProvisionerDaemonStatusOffline,
|
|
},
|
|
expectedNum: 1,
|
|
},
|
|
// Offline daemons should not be included without Offline param
|
|
{
|
|
name: "Get idle - empty statuses",
|
|
statuses: []database.ProvisionerDaemonStatus{},
|
|
expectedNum: 1,
|
|
},
|
|
{
|
|
name: "Get idle - nil statuses",
|
|
statuses: nil,
|
|
expectedNum: 1,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
//nolint:tparallel,paralleltest
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
StaleIntervalMS: 45 * time.Minute.Milliseconds(),
|
|
Statuses: tc.statuses,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, tc.expectedNum)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("FilterByMaxAge", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "foo-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(45 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(45 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "bar-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(25 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(25 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
type testCase struct {
|
|
name string
|
|
maxAge sql.NullInt64
|
|
expectedNum int
|
|
}
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "Max age 1 hour",
|
|
maxAge: sql.NullInt64{Int64: time.Hour.Milliseconds(), Valid: true},
|
|
expectedNum: 2,
|
|
},
|
|
{
|
|
name: "Max age 30 minutes",
|
|
maxAge: sql.NullInt64{Int64: (30 * time.Minute).Milliseconds(), Valid: true},
|
|
expectedNum: 1,
|
|
},
|
|
{
|
|
name: "Max age 15 minutes",
|
|
maxAge: sql.NullInt64{Int64: (15 * time.Minute).Milliseconds(), Valid: true},
|
|
expectedNum: 0,
|
|
},
|
|
{
|
|
name: "No max age",
|
|
maxAge: sql.NullInt64{Valid: false},
|
|
expectedNum: 2,
|
|
},
|
|
}
|
|
for _, tc := range tests {
|
|
//nolint:tparallel,paralleltest
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
StaleIntervalMS: 60 * time.Minute.Milliseconds(),
|
|
MaxAgeMs: tc.maxAge,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, tc.expectedNum)
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestGetWorkspaceAgentUsageStats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Aggregates", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
db = dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
ctx := context.Background()
|
|
// Since the queries exclude the current minute
|
|
insertTime := dbtime.Now().Add(-time.Minute)
|
|
|
|
agentID1 := uuid.New()
|
|
agentID2 := uuid.New()
|
|
workspaceID1 := uuid.New()
|
|
workspaceID2 := uuid.New()
|
|
templateID1 := uuid.New()
|
|
templateID2 := uuid.New()
|
|
userID1 := uuid.New()
|
|
userID2 := uuid.New()
|
|
|
|
// Old workspace 1 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agentID1,
|
|
WorkspaceID: workspaceID1,
|
|
TemplateID: templateID1,
|
|
UserID: userID1,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 3,
|
|
SessionCountSSH: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agentID1,
|
|
WorkspaceID: workspaceID1,
|
|
TemplateID: templateID1,
|
|
UserID: userID1,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
// Latest workspace 1 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID1,
|
|
WorkspaceID: workspaceID1,
|
|
TemplateID: templateID1,
|
|
UserID: userID1,
|
|
TxBytes: 2,
|
|
RxBytes: 2,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 3,
|
|
SessionCountSSH: 4,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID1,
|
|
WorkspaceID: workspaceID1,
|
|
TemplateID: templateID1,
|
|
UserID: userID1,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID1,
|
|
WorkspaceID: workspaceID1,
|
|
TemplateID: templateID1,
|
|
UserID: userID1,
|
|
SessionCountJetBrains: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
// Latest workspace 2 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID2,
|
|
WorkspaceID: workspaceID2,
|
|
TemplateID: templateID2,
|
|
UserID: userID2,
|
|
TxBytes: 4,
|
|
RxBytes: 8,
|
|
ConnectionMedianLatencyMS: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID2,
|
|
WorkspaceID: workspaceID2,
|
|
TemplateID: templateID2,
|
|
UserID: userID2,
|
|
TxBytes: 2,
|
|
RxBytes: 3,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 3,
|
|
SessionCountSSH: 4,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID2,
|
|
WorkspaceID: workspaceID2,
|
|
TemplateID: templateID2,
|
|
UserID: userID2,
|
|
SessionCountSSH: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID2,
|
|
WorkspaceID: workspaceID2,
|
|
TemplateID: templateID2,
|
|
UserID: userID2,
|
|
SessionCountJetBrains: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
reqTime := dbtime.Now().Add(-time.Hour)
|
|
stats, err := db.GetWorkspaceAgentUsageStats(ctx, reqTime)
|
|
require.NoError(t, err)
|
|
|
|
ws1Stats, ws2Stats := stats[0], stats[1]
|
|
if ws1Stats.WorkspaceID != workspaceID1 {
|
|
ws1Stats, ws2Stats = ws2Stats, ws1Stats
|
|
}
|
|
require.Equal(t, int64(3), ws1Stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(3), ws1Stats.WorkspaceRxBytes)
|
|
require.Equal(t, int64(1), ws1Stats.SessionCountVSCode)
|
|
require.Equal(t, int64(1), ws1Stats.SessionCountJetBrains)
|
|
require.Equal(t, int64(0), ws1Stats.SessionCountSSH)
|
|
require.Equal(t, int64(0), ws1Stats.SessionCountReconnectingPTY)
|
|
|
|
require.Equal(t, int64(6), ws2Stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(11), ws2Stats.WorkspaceRxBytes)
|
|
require.Equal(t, int64(1), ws2Stats.SessionCountSSH)
|
|
require.Equal(t, int64(1), ws2Stats.SessionCountJetBrains)
|
|
require.Equal(t, int64(0), ws2Stats.SessionCountVSCode)
|
|
require.Equal(t, int64(0), ws2Stats.SessionCountReconnectingPTY)
|
|
})
|
|
|
|
t.Run("NoUsage", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
db = dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
ctx := context.Background()
|
|
// Since the queries exclude the current minute
|
|
insertTime := dbtime.Now().Add(-time.Minute)
|
|
|
|
agentID := uuid.New()
|
|
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
TxBytes: 3,
|
|
RxBytes: 4,
|
|
ConnectionMedianLatencyMS: 2,
|
|
// Should be ignored
|
|
SessionCountSSH: 3,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
|
|
stats, err := db.GetWorkspaceAgentUsageStats(ctx, dbtime.Now().Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, stats, 1)
|
|
require.Equal(t, int64(3), stats[0].WorkspaceTxBytes)
|
|
require.Equal(t, int64(4), stats[0].WorkspaceRxBytes)
|
|
require.Equal(t, int64(0), stats[0].SessionCountVSCode)
|
|
require.Equal(t, int64(0), stats[0].SessionCountSSH)
|
|
require.Equal(t, int64(0), stats[0].SessionCountReconnectingPTY)
|
|
require.Equal(t, int64(0), stats[0].SessionCountJetBrains)
|
|
})
|
|
}
|
|
|
|
func TestGetWorkspaceAgentUsageStatsAndLabels(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Aggregates", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
insertTime := dbtime.Now()
|
|
|
|
// Insert user, agent, template, workspace
|
|
user1 := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
job1 := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource1 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job1.ID,
|
|
})
|
|
agent1 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource1.ID,
|
|
})
|
|
template1 := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user1.ID,
|
|
})
|
|
workspace1 := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user1.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: template1.ID,
|
|
})
|
|
user2 := dbgen.User(t, db, database.User{})
|
|
job2 := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource2 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job2.ID,
|
|
})
|
|
agent2 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource2.ID,
|
|
})
|
|
template2 := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user1.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
workspace2 := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user2.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: template2.ID,
|
|
})
|
|
|
|
// Old workspace 1 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agent1.ID,
|
|
WorkspaceID: workspace1.ID,
|
|
TemplateID: template1.ID,
|
|
UserID: user1.ID,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 3,
|
|
SessionCountSSH: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agent1.ID,
|
|
WorkspaceID: workspace1.ID,
|
|
TemplateID: template1.ID,
|
|
UserID: user1.ID,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
// Latest workspace 1 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent1.ID,
|
|
WorkspaceID: workspace1.ID,
|
|
TemplateID: template1.ID,
|
|
UserID: user1.ID,
|
|
TxBytes: 2,
|
|
RxBytes: 2,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 4,
|
|
SessionCountSSH: 3,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent1.ID,
|
|
WorkspaceID: workspace1.ID,
|
|
TemplateID: template1.ID,
|
|
UserID: user1.ID,
|
|
SessionCountJetBrains: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent1.ID,
|
|
WorkspaceID: workspace1.ID,
|
|
TemplateID: template1.ID,
|
|
UserID: user1.ID,
|
|
SessionCountReconnectingPTY: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
// Latest workspace 2 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent2.ID,
|
|
WorkspaceID: workspace2.ID,
|
|
TemplateID: template2.ID,
|
|
UserID: user2.ID,
|
|
TxBytes: 4,
|
|
RxBytes: 8,
|
|
ConnectionMedianLatencyMS: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent2.ID,
|
|
WorkspaceID: workspace2.ID,
|
|
TemplateID: template2.ID,
|
|
UserID: user2.ID,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent2.ID,
|
|
WorkspaceID: workspace2.ID,
|
|
TemplateID: template2.ID,
|
|
UserID: user2.ID,
|
|
SessionCountSSH: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
stats, err := db.GetWorkspaceAgentUsageStatsAndLabels(ctx, insertTime.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, stats, 2)
|
|
require.Contains(t, stats, database.GetWorkspaceAgentUsageStatsAndLabelsRow{
|
|
Username: user1.Username,
|
|
AgentName: agent1.Name,
|
|
WorkspaceName: workspace1.Name,
|
|
TxBytes: 3,
|
|
RxBytes: 3,
|
|
SessionCountJetBrains: 1,
|
|
SessionCountReconnectingPTY: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
})
|
|
|
|
require.Contains(t, stats, database.GetWorkspaceAgentUsageStatsAndLabelsRow{
|
|
Username: user2.Username,
|
|
AgentName: agent2.Name,
|
|
WorkspaceName: workspace2.Name,
|
|
RxBytes: 8,
|
|
TxBytes: 4,
|
|
SessionCountVSCode: 1,
|
|
SessionCountSSH: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
})
|
|
})
|
|
|
|
t.Run("NoUsage", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
insertTime := dbtime.Now()
|
|
// Insert user, agent, template, workspace
|
|
user := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agent.ID,
|
|
WorkspaceID: workspace.ID,
|
|
TemplateID: template.ID,
|
|
UserID: user.ID,
|
|
RxBytes: 4,
|
|
TxBytes: 5,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 3,
|
|
SessionCountSSH: 1,
|
|
})
|
|
|
|
stats, err := db.GetWorkspaceAgentUsageStatsAndLabels(ctx, insertTime.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, stats, 1)
|
|
require.Contains(t, stats, database.GetWorkspaceAgentUsageStatsAndLabelsRow{
|
|
Username: user.Username,
|
|
AgentName: agent.Name,
|
|
WorkspaceName: workspace.Name,
|
|
RxBytes: 4,
|
|
TxBytes: 5,
|
|
ConnectionMedianLatencyMS: 1,
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestGetAuthorizedWorkspacesAndAgentsByOwnerID(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
owner := dbgen.User(t, db, database.User{
|
|
RBACRoles: []string{rbac.RoleOwner().String()},
|
|
})
|
|
user := dbgen.User(t, db, database.User{})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: owner.ID,
|
|
})
|
|
|
|
pendingID := uuid.New()
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusPending,
|
|
CreateWorkspace: true,
|
|
WorkspaceID: pendingID,
|
|
CreateAgent: true,
|
|
})
|
|
failedID := uuid.New()
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusFailed,
|
|
CreateWorkspace: true,
|
|
CreateAgent: true,
|
|
WorkspaceID: failedID,
|
|
})
|
|
succeededID := uuid.New()
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
WorkspaceTransition: database.WorkspaceTransitionStart,
|
|
CreateWorkspace: true,
|
|
WorkspaceID: succeededID,
|
|
CreateAgent: true,
|
|
ExtraAgents: 1,
|
|
ExtraBuilds: 2,
|
|
})
|
|
deletedID := uuid.New()
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
WorkspaceTransition: database.WorkspaceTransitionDelete,
|
|
CreateWorkspace: true,
|
|
WorkspaceID: deletedID,
|
|
CreateAgent: false,
|
|
})
|
|
|
|
ownerCheckFn := func(ownerRows []database.GetWorkspacesAndAgentsByOwnerIDRow) {
|
|
require.Len(t, ownerRows, 4)
|
|
for _, row := range ownerRows {
|
|
switch row.ID {
|
|
case pendingID:
|
|
require.Len(t, row.Agents, 1)
|
|
require.Equal(t, database.ProvisionerJobStatusPending, row.JobStatus)
|
|
case failedID:
|
|
require.Len(t, row.Agents, 1)
|
|
require.Equal(t, database.ProvisionerJobStatusFailed, row.JobStatus)
|
|
case succeededID:
|
|
require.Len(t, row.Agents, 2)
|
|
require.Equal(t, database.ProvisionerJobStatusSucceeded, row.JobStatus)
|
|
require.Equal(t, database.WorkspaceTransitionStart, row.Transition)
|
|
case deletedID:
|
|
require.Len(t, row.Agents, 0)
|
|
require.Equal(t, database.ProvisionerJobStatusSucceeded, row.JobStatus)
|
|
require.Equal(t, database.WorkspaceTransitionDelete, row.Transition)
|
|
default:
|
|
t.Fatalf("unexpected workspace ID: %s", row.ID)
|
|
}
|
|
}
|
|
}
|
|
t.Run("sqlQuerier", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
userSubject, _, err := httpmw.UserRBACSubject(ctx, db, user.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedUser, err := authorizer.Prepare(ctx, userSubject, policy.ActionRead, rbac.ResourceWorkspace.Type)
|
|
require.NoError(t, err)
|
|
userCtx := dbauthz.As(ctx, userSubject)
|
|
userRows, err := db.GetAuthorizedWorkspacesAndAgentsByOwnerID(userCtx, owner.ID, preparedUser)
|
|
require.NoError(t, err)
|
|
require.Len(t, userRows, 0)
|
|
|
|
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, db, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedOwner, err := authorizer.Prepare(ctx, ownerSubject, policy.ActionRead, rbac.ResourceWorkspace.Type)
|
|
require.NoError(t, err)
|
|
ownerCtx := dbauthz.As(ctx, ownerSubject)
|
|
ownerRows, err := db.GetAuthorizedWorkspacesAndAgentsByOwnerID(ownerCtx, owner.ID, preparedOwner)
|
|
require.NoError(t, err)
|
|
ownerCheckFn(ownerRows)
|
|
})
|
|
|
|
t.Run("dbauthz", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
authzdb := dbauthz.New(db, authorizer, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
|
|
userSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, user.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
userCtx := dbauthz.As(ctx, userSubject)
|
|
|
|
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
ownerCtx := dbauthz.As(ctx, ownerSubject)
|
|
|
|
userRows, err := authzdb.GetWorkspacesAndAgentsByOwnerID(userCtx, owner.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, userRows, 0)
|
|
|
|
ownerRows, err := authzdb.GetWorkspacesAndAgentsByOwnerID(ownerCtx, owner.ID)
|
|
require.NoError(t, err)
|
|
ownerCheckFn(ownerRows)
|
|
})
|
|
}
|
|
|
|
func TestGetAuthorizedChats(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
|
|
|
// Create users with different roles.
|
|
owner := dbgen.User(t, db, database.User{
|
|
RBACRoles: []string{rbac.RoleOwner().String()},
|
|
})
|
|
member := dbgen.User(t, db, database.User{})
|
|
secondMember := dbgen.User(t, db, database.User{})
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: member.ID, OrganizationID: org.ID, Roles: []string{rbac.RoleAgentsAccess()}})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: secondMember.ID, OrganizationID: org.ID, Roles: []string{rbac.RoleAgentsAccess()}})
|
|
|
|
// Create FK dependencies: a chat provider and model config.
|
|
_ = dbgen.ChatProvider(t, db, database.ChatProvider{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
})
|
|
modelCfg := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
IsDefault: true,
|
|
CompressionThreshold: 80,
|
|
})
|
|
|
|
// Create 3 chats owned by owner.
|
|
for i := range 3 {
|
|
dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: fmt.Sprintf("owner chat %d", i+1),
|
|
})
|
|
}
|
|
|
|
// Create 2 chats owned by member.
|
|
for i := range 2 {
|
|
dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: member.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: fmt.Sprintf("member chat %d", i+1),
|
|
})
|
|
}
|
|
|
|
t.Run("sqlQuerier", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// Member should only see their own 2 chats.
|
|
memberSubject, _, err := httpmw.UserRBACSubject(ctx, db, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedMember, err := authorizer.Prepare(ctx, memberSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
memberRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedMember)
|
|
require.NoError(t, err)
|
|
require.Len(t, memberRows, 2)
|
|
for _, row := range memberRows {
|
|
require.Equal(t, member.ID, row.Chat.OwnerID, "member should only see own chats")
|
|
}
|
|
|
|
// Owner should see at least the 5 pre-created chats (site-wide
|
|
// access). Parallel subtests may add more, so use GreaterOrEqual.
|
|
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, db, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedOwner, err := authorizer.Prepare(ctx, ownerSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
ownerRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOwner)
|
|
require.NoError(t, err)
|
|
require.GreaterOrEqual(t, len(ownerRows), 5)
|
|
|
|
// secondMember has no chats and should see 0.
|
|
secondSubject, _, err := httpmw.UserRBACSubject(ctx, db, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedSecond, err := authorizer.Prepare(ctx, secondSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
secondRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedSecond)
|
|
require.NoError(t, err)
|
|
require.Len(t, secondRows, 0)
|
|
|
|
// Org admin should NOT see other users' chats when they are
|
|
// in a different org than the chat owner.
|
|
orgs, err := db.GetOrganizations(ctx, database.GetOrganizationsParams{})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, orgs)
|
|
orgAdmin := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: orgAdmin.ID,
|
|
OrganizationID: orgs[0].ID,
|
|
Roles: []string{rbac.RoleOrgAdmin()},
|
|
})
|
|
orgAdminSubject, _, err := httpmw.UserRBACSubject(ctx, db, orgAdmin.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedOrgAdmin, err := authorizer.Prepare(ctx, orgAdminSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
orgAdminRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOrgAdmin)
|
|
require.NoError(t, err)
|
|
require.Len(t, orgAdminRows, 0, "org admin with no chats should see 0 chats")
|
|
|
|
// Org admin in SAME org should see all chats in that org.
|
|
sameOrgAdmin := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: sameOrgAdmin.ID,
|
|
OrganizationID: org.ID,
|
|
Roles: []string{rbac.RoleOrgAdmin()},
|
|
})
|
|
sameOrgAdminSubject, _, err := httpmw.UserRBACSubject(ctx, db, sameOrgAdmin.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedSameOrgAdmin, err := authorizer.Prepare(ctx, sameOrgAdminSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
sameOrgAdminRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedSameOrgAdmin)
|
|
require.NoError(t, err)
|
|
require.GreaterOrEqual(t, len(sameOrgAdminRows), 5, "same-org admin should see all chats in their org")
|
|
|
|
// OwnerID filter: member queries their own chats.
|
|
memberFilterSelf, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
|
OwnerID: member.ID,
|
|
}, preparedMember)
|
|
require.NoError(t, err)
|
|
require.Len(t, memberFilterSelf, 2)
|
|
|
|
// OwnerID filter: member queries owner's chats → sees 0.
|
|
memberFilterOwner, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
|
OwnerID: owner.ID,
|
|
}, preparedMember)
|
|
require.NoError(t, err)
|
|
require.Len(t, memberFilterOwner, 0)
|
|
})
|
|
|
|
t.Run("dbauthz", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
authzdb := dbauthz.New(db, authorizer, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
|
|
// As member: should see only own 2 chats.
|
|
memberSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
memberCtx := dbauthz.As(ctx, memberSubject)
|
|
memberRows, err := authzdb.GetChats(memberCtx, database.GetChatsParams{})
|
|
require.NoError(t, err)
|
|
require.Len(t, memberRows, 2)
|
|
for _, row := range memberRows {
|
|
require.Equal(t, member.ID, row.Chat.OwnerID, "member should only see own chats")
|
|
}
|
|
|
|
// As owner: should see at least the 5 pre-created chats.
|
|
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
ownerCtx := dbauthz.As(ctx, ownerSubject)
|
|
ownerRows, err := authzdb.GetChats(ownerCtx, database.GetChatsParams{})
|
|
require.NoError(t, err)
|
|
require.GreaterOrEqual(t, len(ownerRows), 5)
|
|
|
|
// As secondMember: should see 0 chats.
|
|
secondSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
secondCtx := dbauthz.As(ctx, secondSubject)
|
|
secondRows, err := authzdb.GetChats(secondCtx, database.GetChatsParams{})
|
|
require.NoError(t, err)
|
|
require.Len(t, secondRows, 0)
|
|
})
|
|
|
|
t.Run("pagination", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// Use a dedicated user for pagination to avoid interference
|
|
// with the other parallel subtests.
|
|
paginationUser := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: paginationUser.ID, OrganizationID: org.ID, Roles: []string{rbac.RoleAgentsAccess()}})
|
|
for i := range 7 {
|
|
dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: paginationUser.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: fmt.Sprintf("pagination chat %d", i+1),
|
|
})
|
|
}
|
|
|
|
pagUserSubject, _, err := httpmw.UserRBACSubject(ctx, db, paginationUser.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedMember, err := authorizer.Prepare(ctx, pagUserSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
|
|
// Fetch first page with limit=2.
|
|
page1, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
|
LimitOpt: 2,
|
|
}, preparedMember)
|
|
require.NoError(t, err)
|
|
require.Len(t, page1, 2)
|
|
for _, row := range page1 {
|
|
require.Equal(t, paginationUser.ID, row.Chat.OwnerID, "paginated results must belong to pagination user")
|
|
}
|
|
|
|
// Fetch remaining pages and collect all chat IDs.
|
|
allIDs := make(map[uuid.UUID]struct{})
|
|
for _, row := range page1 {
|
|
allIDs[row.Chat.ID] = struct{}{}
|
|
}
|
|
offset := int32(2)
|
|
for {
|
|
page, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
|
LimitOpt: 2,
|
|
OffsetOpt: offset,
|
|
}, preparedMember)
|
|
require.NoError(t, err)
|
|
for _, row := range page {
|
|
require.Equal(t, paginationUser.ID, row.Chat.OwnerID, "paginated results must belong to pagination user")
|
|
allIDs[row.Chat.ID] = struct{}{}
|
|
}
|
|
if len(page) < 2 {
|
|
break
|
|
}
|
|
offset += int32(len(page)) //nolint:gosec // Test code, pagination values are small.
|
|
}
|
|
|
|
// All 7 member chats should be accounted for with no leakage.
|
|
require.Len(t, allIDs, 7, "pagination should return all member chats exactly once")
|
|
})
|
|
}
|
|
|
|
func TestInsertWorkspaceAgentLogs(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
sqlDB := testSQLDB(t)
|
|
ctx := context.Background()
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
source := dbgen.WorkspaceAgentLogSource(t, db, database.WorkspaceAgentLogSource{
|
|
WorkspaceAgentID: agent.ID,
|
|
})
|
|
logs, err := db.InsertWorkspaceAgentLogs(ctx, database.InsertWorkspaceAgentLogsParams{
|
|
AgentID: agent.ID,
|
|
CreatedAt: dbtime.Now(),
|
|
Output: []string{"first"},
|
|
Level: []database.LogLevel{database.LogLevelInfo},
|
|
LogSourceID: source.ID,
|
|
// 1 MB is the max
|
|
OutputLength: 1 << 20,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), logs[0].ID)
|
|
|
|
_, err = db.InsertWorkspaceAgentLogs(ctx, database.InsertWorkspaceAgentLogsParams{
|
|
AgentID: agent.ID,
|
|
CreatedAt: dbtime.Now(),
|
|
Output: []string{"second"},
|
|
Level: []database.LogLevel{database.LogLevelInfo},
|
|
LogSourceID: source.ID,
|
|
OutputLength: 1,
|
|
})
|
|
require.True(t, database.IsWorkspaceAgentLogsLimitError(err))
|
|
}
|
|
|
|
func TestProxyByHostname(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
// Insert a bunch of different proxies.
|
|
proxies := []struct {
|
|
name string
|
|
accessURL string
|
|
wildcardHostname string
|
|
}{
|
|
{
|
|
name: "one",
|
|
accessURL: "https://one.coder.com",
|
|
wildcardHostname: "*.wildcard.one.coder.com",
|
|
},
|
|
{
|
|
name: "two",
|
|
accessURL: "https://two.coder.com",
|
|
wildcardHostname: "*--suffix.two.coder.com",
|
|
},
|
|
}
|
|
for _, p := range proxies {
|
|
dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{
|
|
Name: p.name,
|
|
Url: p.accessURL,
|
|
WildcardHostname: p.wildcardHostname,
|
|
})
|
|
}
|
|
|
|
cases := []struct {
|
|
name string
|
|
testHostname string
|
|
allowAccessURL bool
|
|
allowWildcardHost bool
|
|
matchProxyName string
|
|
}{
|
|
{
|
|
name: "NoMatch",
|
|
testHostname: "test.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "",
|
|
},
|
|
{
|
|
name: "MatchAccessURL",
|
|
testHostname: "one.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "one",
|
|
},
|
|
{
|
|
name: "MatchWildcard",
|
|
testHostname: "something.wildcard.one.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "one",
|
|
},
|
|
{
|
|
name: "MatchSuffix",
|
|
testHostname: "something--suffix.two.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "two",
|
|
},
|
|
{
|
|
name: "ValidateHostname/1",
|
|
testHostname: ".*ne.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "",
|
|
},
|
|
{
|
|
name: "ValidateHostname/2",
|
|
testHostname: "https://one.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "",
|
|
},
|
|
{
|
|
name: "ValidateHostname/3",
|
|
testHostname: "one.coder.com:8080/hello",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "",
|
|
},
|
|
{
|
|
name: "IgnoreAccessURLMatch",
|
|
testHostname: "one.coder.com",
|
|
allowAccessURL: false,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "",
|
|
},
|
|
{
|
|
name: "IgnoreWildcardMatch",
|
|
testHostname: "hi.wildcard.one.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: false,
|
|
matchProxyName: "",
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
proxy, err := db.GetWorkspaceProxyByHostname(context.Background(), database.GetWorkspaceProxyByHostnameParams{
|
|
Hostname: c.testHostname,
|
|
AllowAccessUrl: c.allowAccessURL,
|
|
AllowWildcardHostname: c.allowWildcardHost,
|
|
})
|
|
if c.matchProxyName == "" {
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
require.Empty(t, proxy)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, proxy)
|
|
require.Equal(t, c.matchProxyName, proxy.Name)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefaultProxy(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
depID := uuid.NewString()
|
|
err = db.InsertDeploymentID(ctx, depID)
|
|
require.NoError(t, err, "insert deployment id")
|
|
|
|
// Fetch empty proxy values
|
|
defProxy, err := db.GetDefaultProxyConfig(ctx)
|
|
require.NoError(t, err, "get def proxy")
|
|
|
|
require.Equal(t, defProxy.DisplayName, "Default")
|
|
require.Equal(t, defProxy.IconURL, "/emojis/1f3e1.png")
|
|
|
|
// Set the proxy values
|
|
args := database.UpsertDefaultProxyParams{
|
|
DisplayName: "displayname",
|
|
IconURL: "/icon.png",
|
|
}
|
|
err = db.UpsertDefaultProxy(ctx, args)
|
|
require.NoError(t, err, "insert def proxy")
|
|
|
|
defProxy, err = db.GetDefaultProxyConfig(ctx)
|
|
require.NoError(t, err, "get def proxy")
|
|
require.Equal(t, defProxy.DisplayName, args.DisplayName)
|
|
require.Equal(t, defProxy.IconURL, args.IconURL)
|
|
|
|
// Upsert values
|
|
args = database.UpsertDefaultProxyParams{
|
|
DisplayName: "newdisplayname",
|
|
IconURL: "/newicon.png",
|
|
}
|
|
err = db.UpsertDefaultProxy(ctx, args)
|
|
require.NoError(t, err, "upsert def proxy")
|
|
|
|
defProxy, err = db.GetDefaultProxyConfig(ctx)
|
|
require.NoError(t, err, "get def proxy")
|
|
require.Equal(t, defProxy.DisplayName, args.DisplayName)
|
|
require.Equal(t, defProxy.IconURL, args.IconURL)
|
|
|
|
// Ensure other site configs are the same
|
|
found, err := db.GetDeploymentID(ctx)
|
|
require.NoError(t, err, "get deployment id")
|
|
require.Equal(t, depID, found)
|
|
}
|
|
|
|
func TestQueuePosition(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
jobCount := 10
|
|
jobs := []database.ProvisionerJob{}
|
|
jobIDs := []uuid.UUID{}
|
|
for i := 0; i < jobCount; i++ {
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Tags: database.StringMap{},
|
|
})
|
|
jobs = append(jobs, job)
|
|
jobIDs = append(jobIDs, job.ID)
|
|
|
|
// We need a slight amount of time between each insertion to ensure that
|
|
// the queue position is correct... it's sorted by `created_at`.
|
|
time.Sleep(time.Millisecond)
|
|
}
|
|
|
|
// Create default provisioner daemon:
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "default_provisioner",
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
// Ensure the `tags` field is NOT NULL for the default provisioner;
|
|
// otherwise, it won't be able to pick up any jobs.
|
|
Tags: database.StringMap{},
|
|
})
|
|
|
|
queued, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: jobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, jobCount)
|
|
sort.Slice(queued, func(i, j int) bool {
|
|
return queued[i].QueuePosition < queued[j].QueuePosition
|
|
})
|
|
// Ensure that the queue positions are correct based on insertion ID!
|
|
for index, job := range queued {
|
|
require.Equal(t, job.QueuePosition, int64(index+1))
|
|
require.Equal(t, job.ProvisionerJob.ID, jobs[index].ID)
|
|
}
|
|
|
|
job, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
OrganizationID: org.ID,
|
|
StartedAt: sql.NullTime{
|
|
Time: dbtime.Now(),
|
|
Valid: true,
|
|
},
|
|
Types: database.AllProvisionerTypeValues(),
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
ProvisionerTags: json.RawMessage("{}"),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, jobs[0].ID, job.ID)
|
|
|
|
queued, err = db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: jobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, jobCount)
|
|
sort.Slice(queued, func(i, j int) bool {
|
|
return queued[i].QueuePosition < queued[j].QueuePosition
|
|
})
|
|
// Ensure that queue positions are updated now that the first job has been acquired!
|
|
for index, job := range queued {
|
|
if index == 0 {
|
|
require.Equal(t, job.QueuePosition, int64(0))
|
|
continue
|
|
}
|
|
require.Equal(t, job.QueuePosition, int64(index))
|
|
require.Equal(t, job.ProvisionerJob.ID, jobs[index].ID)
|
|
}
|
|
}
|
|
|
|
func TestAcquireProvisionerJob(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("HumanInitiatedJobsFirst", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db, _ = dbtestutil.NewDB(t)
|
|
ctx = testutil.Context(t, testutil.WaitMedium)
|
|
org = dbgen.Organization(t, db, database.Organization{})
|
|
_ = dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{}) // Required for queue position
|
|
now = dbtime.Now()
|
|
numJobs = 10
|
|
humanIDs = make([]uuid.UUID, 0, numJobs/2)
|
|
prebuildIDs = make([]uuid.UUID, 0, numJobs/2)
|
|
)
|
|
|
|
// Given: a number of jobs in the queue, with prebuilds and non-prebuilds interleaved
|
|
for idx := range numJobs {
|
|
var initiator uuid.UUID
|
|
if idx%2 == 0 {
|
|
initiator = database.PrebuildsSystemUserID
|
|
} else {
|
|
initiator = uuid.MustParse("c0dec0de-c0de-c0de-c0de-c0dec0dec0de")
|
|
}
|
|
pj, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.MustParse(fmt.Sprintf("00000000-0000-0000-0000-00000000000%x", idx+1)),
|
|
CreatedAt: time.Now().Add(-time.Second * time.Duration(idx)),
|
|
UpdatedAt: time.Now().Add(-time.Second * time.Duration(idx)),
|
|
InitiatorID: initiator,
|
|
OrganizationID: org.ID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
FileID: uuid.New(),
|
|
Input: json.RawMessage(`{}`),
|
|
Tags: database.StringMap{},
|
|
TraceMetadata: pqtype.NullRawMessage{},
|
|
})
|
|
require.NoError(t, err)
|
|
// We expected prebuilds to be acquired after human-initiated jobs.
|
|
if initiator == database.PrebuildsSystemUserID {
|
|
prebuildIDs = append([]uuid.UUID{pj.ID}, prebuildIDs...)
|
|
} else {
|
|
humanIDs = append([]uuid.UUID{pj.ID}, humanIDs...)
|
|
}
|
|
t.Logf("created job id=%q initiator=%q created_at=%q", pj.ID.String(), pj.InitiatorID.String(), pj.CreatedAt.String())
|
|
}
|
|
|
|
expectedIDs := append(humanIDs, prebuildIDs...) //nolint:gocritic // not the same slice
|
|
|
|
// When: we query the queue positions for the jobs
|
|
qjs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: expectedIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, qjs, numJobs)
|
|
// Ensure the jobs are sorted by queue position.
|
|
sort.Slice(qjs, func(i, j int) bool {
|
|
return qjs[i].QueuePosition < qjs[j].QueuePosition
|
|
})
|
|
|
|
// Then: the queue positions for the jobs should indicate the order in which
|
|
// they will be acquired, with human-initiated jobs first.
|
|
for idx, qj := range qjs {
|
|
t.Logf("queued job %d/%d id=%q initiator=%q created_at=%q queue_position=%d", idx+1, numJobs, qj.ProvisionerJob.ID.String(), qj.ProvisionerJob.InitiatorID.String(), qj.ProvisionerJob.CreatedAt.String(), qj.QueuePosition)
|
|
require.Equal(t, expectedIDs[idx].String(), qj.ProvisionerJob.ID.String(), "job %d/%d should match expected id", idx+1, numJobs)
|
|
require.Equal(t, int64(idx+1), qj.QueuePosition, "job %d/%d should have queue position %d", idx+1, numJobs, idx+1)
|
|
}
|
|
|
|
// When: the jobs are acquired
|
|
// Then: human-initiated jobs are prioritized first.
|
|
for idx := range numJobs {
|
|
acquired, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
OrganizationID: org.ID,
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
ProvisionerTags: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, expectedIDs[idx].String(), acquired.ID.String(), "acquired job %d/%d with initiator %q", idx+1, numJobs, acquired.InitiatorID.String())
|
|
t.Logf("acquired job id=%q initiator=%q created_at=%q", acquired.ID.String(), acquired.InitiatorID.String(), acquired.CreatedAt.String())
|
|
err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
|
ID: acquired.ID,
|
|
UpdatedAt: now,
|
|
CompletedAt: sql.NullTime{Time: now, Valid: true},
|
|
Error: sql.NullString{},
|
|
ErrorCode: sql.NullString{},
|
|
})
|
|
require.NoError(t, err, "mark job %d/%d as complete", idx+1, numJobs)
|
|
}
|
|
})
|
|
|
|
t.Run("SkipsCanceledPendingJobs", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db, _ = dbtestutil.NewDB(t)
|
|
ctx = testutil.Context(t, testutil.WaitMedium)
|
|
org = dbgen.Organization(t, db, database.Organization{})
|
|
now = dbtime.Now()
|
|
)
|
|
|
|
// Insert a pending job (started_at is NULL).
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.New(),
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
InitiatorID: uuid.New(),
|
|
OrganizationID: org.ID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
FileID: uuid.New(),
|
|
Input: json.RawMessage(`{}`),
|
|
Tags: database.StringMap{},
|
|
TraceMetadata: pqtype.NullRawMessage{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Cancel it while still pending. In production (workspacebuilds.go), canceling
|
|
// a pending build sets completed_at but leaves started_at NULL since no
|
|
// provisioner ever started the job.
|
|
err = db.UpdateProvisionerJobWithCancelByID(ctx, database.UpdateProvisionerJobWithCancelByIDParams{
|
|
ID: job.ID,
|
|
CanceledAt: sql.NullTime{Time: now, Valid: true},
|
|
CompletedAt: sql.NullTime{Time: now, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// AcquireProvisionerJob should skip this job since it's already completed.
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
OrganizationID: org.ID,
|
|
StartedAt: sql.NullTime{Time: now, Valid: true},
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
ProvisionerTags: json.RawMessage(`{}`),
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
})
|
|
}
|
|
|
|
func TestUserLastSeenFilter(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
t.Run("Before", func(t *testing.T) {
|
|
t.Parallel()
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
now := dbtime.Now()
|
|
|
|
yesterday := dbgen.User(t, db, database.User{
|
|
LastSeenAt: now.Add(time.Hour * -25),
|
|
})
|
|
today := dbgen.User(t, db, database.User{
|
|
LastSeenAt: now,
|
|
})
|
|
lastWeek := dbgen.User(t, db, database.User{
|
|
LastSeenAt: now.Add((time.Hour * -24 * 7) + (-1 * time.Hour)),
|
|
})
|
|
|
|
beforeToday, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
LastSeenBefore: now.Add(time.Hour * -24),
|
|
})
|
|
require.NoError(t, err)
|
|
database.ConvertUserRows(beforeToday)
|
|
|
|
requireUsersMatch(t, []database.User{yesterday, lastWeek}, beforeToday, "before today")
|
|
|
|
justYesterday, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
LastSeenBefore: now.Add(time.Hour * -24),
|
|
LastSeenAfter: now.Add(time.Hour * -24 * 2),
|
|
})
|
|
require.NoError(t, err)
|
|
requireUsersMatch(t, []database.User{yesterday}, justYesterday, "just yesterday")
|
|
|
|
all, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
LastSeenBefore: now.Add(time.Hour),
|
|
})
|
|
require.NoError(t, err)
|
|
requireUsersMatch(t, []database.User{today, yesterday, lastWeek}, all, "all")
|
|
|
|
allAfterLastWeek, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
LastSeenAfter: now.Add(time.Hour * -24 * 7),
|
|
})
|
|
require.NoError(t, err)
|
|
requireUsersMatch(t, []database.User{today, yesterday}, allAfterLastWeek, "after last week")
|
|
})
|
|
}
|
|
|
|
func TestGetUsers_IncludeSystem(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
includeSystem bool
|
|
wantSystemUser bool
|
|
}{
|
|
{
|
|
name: "include system users",
|
|
includeSystem: true,
|
|
wantSystemUser: true,
|
|
},
|
|
{
|
|
name: "exclude system users",
|
|
includeSystem: false,
|
|
wantSystemUser: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Given: a system user
|
|
// postgres: introduced by migration coderd/database/migrations/00030*_system_user.up.sql
|
|
db, _ := dbtestutil.NewDB(t)
|
|
other := dbgen.User(t, db, database.User{})
|
|
users, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
IncludeSystem: tt.includeSystem,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Should always find the regular user
|
|
foundRegularUser := false
|
|
foundSystemUser := false
|
|
|
|
for _, u := range users {
|
|
if u.IsSystem {
|
|
foundSystemUser = true
|
|
require.Equal(t, database.PrebuildsSystemUserID, u.ID)
|
|
} else {
|
|
foundRegularUser = true
|
|
require.Equalf(t, other.ID.String(), u.ID.String(), "found unexpected regular user")
|
|
}
|
|
}
|
|
|
|
require.True(t, foundRegularUser, "regular user should always be found")
|
|
require.Equal(t, tt.wantSystemUser, foundSystemUser, "system user presence should match includeSystem setting")
|
|
require.Equal(t, tt.wantSystemUser, len(users) == 2, "should have 2 users when including system user, 1 otherwise")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdateSystemUser(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// TODO (sasswart): We've disabled the protection that prevents updates to system users
|
|
// while we reassess the mechanism to do so. Rather than skip the test, we've just inverted
|
|
// the assertions to ensure that the behavior is as desired.
|
|
// Once we've re-enabeld the system user protection, we'll revert the assertions.
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Given: a system user introduced by migration coderd/database/migrations/00030*_system_user.up.sql
|
|
db, _ := dbtestutil.NewDB(t)
|
|
users, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
IncludeSystem: true,
|
|
})
|
|
require.NoError(t, err)
|
|
var systemUser database.GetUsersRow
|
|
for _, u := range users {
|
|
if u.IsSystem {
|
|
systemUser = u
|
|
}
|
|
}
|
|
require.NotNil(t, systemUser)
|
|
|
|
// When: attempting to update a system user's name.
|
|
_, err = db.UpdateUserProfile(ctx, database.UpdateUserProfileParams{
|
|
ID: systemUser.ID,
|
|
Email: systemUser.Email,
|
|
Username: systemUser.Username,
|
|
AvatarURL: systemUser.AvatarURL,
|
|
Name: "not prebuilds",
|
|
})
|
|
// Then: the attempt is rejected by a postgres trigger.
|
|
// require.ErrorContains(t, err, "Cannot modify or delete system users")
|
|
require.NoError(t, err)
|
|
|
|
// When: attempting to delete a system user.
|
|
err = db.UpdateUserDeletedByID(ctx, systemUser.ID)
|
|
// Then: the attempt is rejected by a postgres trigger.
|
|
// require.ErrorContains(t, err, "Cannot modify or delete system users")
|
|
require.NoError(t, err)
|
|
|
|
// When: attempting to update a user's roles.
|
|
_, err = db.UpdateUserRoles(ctx, database.UpdateUserRolesParams{
|
|
ID: systemUser.ID,
|
|
GrantedRoles: []string{rbac.RoleAuditor().String()},
|
|
})
|
|
// Then: the attempt is rejected by a postgres trigger.
|
|
// require.ErrorContains(t, err, "Cannot modify or delete system users")
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestInsertUserServiceAccountConstraints(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
// Happy path: should succeed.
|
|
t.Run("ServiceAccountWithEmptyEmailAndLoginNone", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, err := db.InsertUser(ctx, database.InsertUserParams{
|
|
Email: "",
|
|
LoginType: database.LoginTypeNone,
|
|
ID: uuid.New(),
|
|
Username: "sa-ok",
|
|
RBACRoles: []string{},
|
|
IsServiceAccount: true,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, user.IsServiceAccount)
|
|
require.Empty(t, user.Email)
|
|
})
|
|
|
|
// Service account with a non-empty email should be rejected
|
|
// by the users_email_not_empty constraint.
|
|
t.Run("ServiceAccountWithNonEmptyEmail", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
_, err := db.InsertUser(ctx, database.InsertUserParams{
|
|
Email: "sa@coder.com",
|
|
LoginType: database.LoginTypeNone,
|
|
ID: uuid.New(),
|
|
Username: "sa-with-email",
|
|
RBACRoles: []string{},
|
|
IsServiceAccount: true,
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckUsersEmailNotEmpty))
|
|
})
|
|
|
|
// A non-service-account with empty email should be rejected
|
|
// by the users_email_not_empty constraint.
|
|
t.Run("RegularUserWithEmptyEmail", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
_, err := db.InsertUser(ctx, database.InsertUserParams{
|
|
Email: "",
|
|
LoginType: database.LoginTypePassword,
|
|
ID: uuid.New(),
|
|
Username: "regular-no-email",
|
|
RBACRoles: []string{},
|
|
IsServiceAccount: false,
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckUsersEmailNotEmpty))
|
|
})
|
|
|
|
// Service account with login_type!=none should be rejected
|
|
// by the users_service_account_login_type constraint.
|
|
t.Run("ServiceAccountWithPasswordLoginType", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
_, err := db.InsertUser(ctx, database.InsertUserParams{
|
|
Email: "",
|
|
LoginType: database.LoginTypePassword,
|
|
ID: uuid.New(),
|
|
Username: "sa-with-password",
|
|
RBACRoles: []string{},
|
|
IsServiceAccount: true,
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckUsersServiceAccountLoginType))
|
|
})
|
|
}
|
|
|
|
func TestGetActiveUserCount(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Seed users: 2 active humans, 1 active service account,
|
|
// 1 dormant, 1 deleted. Only the 2 active humans should
|
|
// be counted for license seat purposes.
|
|
_ = dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusActive,
|
|
})
|
|
_ = dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusActive,
|
|
})
|
|
_ = dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusActive,
|
|
IsServiceAccount: true,
|
|
})
|
|
_ = dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusDormant,
|
|
})
|
|
_ = dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusActive,
|
|
Deleted: true,
|
|
})
|
|
|
|
count, err := db.GetActiveUserCount(ctx, false)
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(2), count)
|
|
}
|
|
|
|
func TestUserChangeLoginType(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
|
|
alice := dbgen.User(t, db, database.User{
|
|
LoginType: database.LoginTypePassword,
|
|
})
|
|
bob := dbgen.User(t, db, database.User{
|
|
LoginType: database.LoginTypePassword,
|
|
})
|
|
bobExpPass := bob.HashedPassword
|
|
require.NotEmpty(t, alice.HashedPassword, "hashed password should not start empty")
|
|
require.NotEmpty(t, bob.HashedPassword, "hashed password should not start empty")
|
|
|
|
alice, err = db.UpdateUserLoginType(ctx, database.UpdateUserLoginTypeParams{
|
|
NewLoginType: database.LoginTypeOIDC,
|
|
UserID: alice.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Empty(t, alice.HashedPassword, "hashed password should be empty")
|
|
|
|
// First check other users are not affected
|
|
bob, err = db.GetUserByID(ctx, bob.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, bobExpPass, bob.HashedPassword, "hashed password should not change")
|
|
|
|
// Then check password -> password is a noop
|
|
bob, err = db.UpdateUserLoginType(ctx, database.UpdateUserLoginTypeParams{
|
|
NewLoginType: database.LoginTypePassword,
|
|
UserID: bob.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
bob, err = db.GetUserByID(ctx, bob.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, bobExpPass, bob.HashedPassword, "hashed password should not change")
|
|
}
|
|
|
|
func TestDefaultOrg(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
|
|
// Should start with the default org
|
|
all, err := db.GetOrganizations(ctx, database.GetOrganizationsParams{})
|
|
require.NoError(t, err)
|
|
require.Len(t, all, 1)
|
|
require.True(t, all[0].IsDefault, "first org should always be default")
|
|
}
|
|
|
|
func TestAuditLogDefaultLimit(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
for i := 0; i < 110; i++ {
|
|
dbgen.AuditLog(t, db, database.AuditLog{})
|
|
}
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
rows, err := db.GetAuditLogsOffset(ctx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// The length should match the default limit of the SQL query.
|
|
// Updating the sql query requires changing the number below to match.
|
|
require.Len(t, rows, 100)
|
|
}
|
|
|
|
func TestAuditLogCount(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
dbgen.AuditLog(t, db, database.AuditLog{})
|
|
|
|
count, err := db.CountAuditLogs(ctx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), count)
|
|
}
|
|
|
|
func TestWorkspaceQuotas(t *testing.T) {
|
|
t.Parallel()
|
|
orgMemberIDs := func(o database.OrganizationMember) uuid.UUID {
|
|
return o.UserID
|
|
}
|
|
groupMemberIDs := func(m database.GroupMember) uuid.UUID {
|
|
return m.UserID
|
|
}
|
|
|
|
t.Run("CorruptedEveryone", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
// Create an extra org as a distraction
|
|
distract := dbgen.Organization(t, db, database.Organization{})
|
|
_, err := db.InsertAllUsersGroup(ctx, distract.ID)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.UpdateGroupByID(ctx, database.UpdateGroupByIDParams{
|
|
QuotaAllowance: 15,
|
|
ID: distract.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Create an org with 2 users
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
everyoneGroup, err := db.InsertAllUsersGroup(ctx, org.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Add a quota to the everyone group
|
|
_, err = db.UpdateGroupByID(ctx, database.UpdateGroupByIDParams{
|
|
QuotaAllowance: 50,
|
|
ID: everyoneGroup.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Add people to the org
|
|
one := dbgen.User(t, db, database.User{})
|
|
two := dbgen.User(t, db, database.User{})
|
|
memOne := dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: one.ID,
|
|
})
|
|
memTwo := dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: two.ID,
|
|
})
|
|
|
|
// Fetch the 'Everyone' group members
|
|
everyoneMembers, err := db.GetGroupMembersByGroupID(ctx, database.GetGroupMembersByGroupIDParams{
|
|
GroupID: everyoneGroup.ID,
|
|
IncludeSystem: false,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.ElementsMatch(t, slice.List(everyoneMembers, groupMemberIDs),
|
|
slice.List([]database.OrganizationMember{memOne, memTwo}, orgMemberIDs))
|
|
|
|
// Check the quota is correct.
|
|
allowance, err := db.GetQuotaAllowanceForUser(ctx, database.GetQuotaAllowanceForUserParams{
|
|
UserID: one.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(50), allowance)
|
|
|
|
// Now try to corrupt the DB
|
|
// Insert rows into the everyone group
|
|
err = db.InsertGroupMember(ctx, database.InsertGroupMemberParams{
|
|
UserID: memOne.UserID,
|
|
GroupID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Ensure allowance remains the same
|
|
allowance, err = db.GetQuotaAllowanceForUser(ctx, database.GetQuotaAllowanceForUserParams{
|
|
UserID: one.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(50), allowance)
|
|
})
|
|
}
|
|
|
|
// TestReadCustomRoles tests the input params returns the correct set of roles.
|
|
func TestReadCustomRoles(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
|
|
db := database.New(sqlDB)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Make a few site roles, and a few org roles
|
|
orgIDs := make([]uuid.UUID, 3)
|
|
for i := range orgIDs {
|
|
orgIDs[i] = uuid.New()
|
|
}
|
|
|
|
allRoles := make([]database.CustomRole, 0)
|
|
siteRoles := make([]database.CustomRole, 0)
|
|
orgRoles := make([]database.CustomRole, 0)
|
|
for i := 0; i < 15; i++ {
|
|
orgID := uuid.NullUUID{
|
|
UUID: orgIDs[i%len(orgIDs)],
|
|
Valid: true,
|
|
}
|
|
if i%4 == 0 {
|
|
// Some should be site wide
|
|
orgID = uuid.NullUUID{}
|
|
}
|
|
|
|
role, err := db.InsertCustomRole(ctx, database.InsertCustomRoleParams{
|
|
Name: fmt.Sprintf("role-%d", i),
|
|
OrganizationID: orgID,
|
|
})
|
|
require.NoError(t, err)
|
|
allRoles = append(allRoles, role)
|
|
if orgID.Valid {
|
|
orgRoles = append(orgRoles, role)
|
|
} else {
|
|
siteRoles = append(siteRoles, role)
|
|
}
|
|
}
|
|
|
|
// normalizedRoleName allows for the simple ElementsMatch to work properly.
|
|
normalizedRoleName := func(role database.CustomRole) string {
|
|
return role.Name + ":" + role.OrganizationID.UUID.String()
|
|
}
|
|
|
|
roleToLookup := func(role database.CustomRole) database.NameOrganizationPair {
|
|
return database.NameOrganizationPair{
|
|
Name: role.Name,
|
|
OrganizationID: role.OrganizationID.UUID,
|
|
}
|
|
}
|
|
|
|
testCases := []struct {
|
|
Name string
|
|
Params database.CustomRolesParams
|
|
Match func(role database.CustomRole) bool
|
|
}{
|
|
{
|
|
Name: "NilRoles",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: nil,
|
|
ExcludeOrgRoles: false,
|
|
OrganizationID: uuid.UUID{},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return true
|
|
},
|
|
},
|
|
{
|
|
// Empty params should return all roles
|
|
Name: "Empty",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{},
|
|
ExcludeOrgRoles: false,
|
|
OrganizationID: uuid.UUID{},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return true
|
|
},
|
|
},
|
|
{
|
|
Name: "Organization",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{},
|
|
ExcludeOrgRoles: false,
|
|
OrganizationID: orgIDs[1],
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return role.OrganizationID.UUID == orgIDs[1]
|
|
},
|
|
},
|
|
{
|
|
Name: "SpecificOrgRole",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: orgRoles[0].Name,
|
|
OrganizationID: orgRoles[0].OrganizationID.UUID,
|
|
},
|
|
},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return role.Name == orgRoles[0].Name && role.OrganizationID.UUID == orgRoles[0].OrganizationID.UUID
|
|
},
|
|
},
|
|
{
|
|
Name: "SpecificSiteRole",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: siteRoles[0].Name,
|
|
OrganizationID: siteRoles[0].OrganizationID.UUID,
|
|
},
|
|
},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return role.Name == siteRoles[0].Name && role.OrganizationID.UUID == siteRoles[0].OrganizationID.UUID
|
|
},
|
|
},
|
|
{
|
|
Name: "FewSpecificRoles",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: orgRoles[0].Name,
|
|
OrganizationID: orgRoles[0].OrganizationID.UUID,
|
|
},
|
|
{
|
|
Name: orgRoles[1].Name,
|
|
OrganizationID: orgRoles[1].OrganizationID.UUID,
|
|
},
|
|
{
|
|
Name: siteRoles[0].Name,
|
|
OrganizationID: siteRoles[0].OrganizationID.UUID,
|
|
},
|
|
},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return (role.Name == orgRoles[0].Name && role.OrganizationID.UUID == orgRoles[0].OrganizationID.UUID) ||
|
|
(role.Name == orgRoles[1].Name && role.OrganizationID.UUID == orgRoles[1].OrganizationID.UUID) ||
|
|
(role.Name == siteRoles[0].Name && role.OrganizationID.UUID == siteRoles[0].OrganizationID.UUID)
|
|
},
|
|
},
|
|
{
|
|
Name: "AllRolesByLookup",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: slice.List(allRoles, roleToLookup),
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return true
|
|
},
|
|
},
|
|
{
|
|
Name: "NotExists",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: "not-exists",
|
|
OrganizationID: uuid.New(),
|
|
},
|
|
{
|
|
Name: "not-exists",
|
|
OrganizationID: uuid.Nil,
|
|
},
|
|
},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return false
|
|
},
|
|
},
|
|
{
|
|
Name: "Mixed",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: "not-exists",
|
|
OrganizationID: uuid.New(),
|
|
},
|
|
{
|
|
Name: "not-exists",
|
|
OrganizationID: uuid.Nil,
|
|
},
|
|
{
|
|
Name: orgRoles[0].Name,
|
|
OrganizationID: orgRoles[0].OrganizationID.UUID,
|
|
},
|
|
{
|
|
Name: siteRoles[0].Name,
|
|
},
|
|
},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return (role.Name == orgRoles[0].Name && role.OrganizationID.UUID == orgRoles[0].OrganizationID.UUID) ||
|
|
(role.Name == siteRoles[0].Name && role.OrganizationID.UUID == siteRoles[0].OrganizationID.UUID)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.Name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
found, err := db.CustomRoles(ctx, tc.Params)
|
|
require.NoError(t, err)
|
|
filtered := make([]database.CustomRole, 0)
|
|
for _, role := range allRoles {
|
|
if tc.Match(role) {
|
|
filtered = append(filtered, role)
|
|
}
|
|
}
|
|
|
|
a := slice.List(filtered, normalizedRoleName)
|
|
b := slice.List(found, normalizedRoleName)
|
|
require.Equal(t, a, b)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDeleteCustomRoleDoesNotDeleteSystemRole(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
systemRole, err := db.InsertCustomRole(ctx, database.InsertCustomRoleParams{
|
|
Name: "test-system-role",
|
|
DisplayName: "",
|
|
OrganizationID: uuid.NullUUID{
|
|
UUID: org.ID,
|
|
Valid: true,
|
|
},
|
|
SitePermissions: database.CustomRolePermissions{},
|
|
OrgPermissions: database.CustomRolePermissions{},
|
|
UserPermissions: database.CustomRolePermissions{},
|
|
MemberPermissions: database.CustomRolePermissions{},
|
|
IsSystem: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
nonSystemRole, err := db.InsertCustomRole(ctx, database.InsertCustomRoleParams{
|
|
Name: "test-custom-role",
|
|
DisplayName: "",
|
|
OrganizationID: uuid.NullUUID{
|
|
UUID: org.ID,
|
|
Valid: true,
|
|
},
|
|
SitePermissions: database.CustomRolePermissions{},
|
|
OrgPermissions: database.CustomRolePermissions{},
|
|
UserPermissions: database.CustomRolePermissions{},
|
|
MemberPermissions: database.CustomRolePermissions{},
|
|
IsSystem: false,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = db.DeleteCustomRole(ctx, database.DeleteCustomRoleParams{
|
|
Name: systemRole.Name,
|
|
OrganizationID: uuid.NullUUID{
|
|
UUID: org.ID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = db.DeleteCustomRole(ctx, database.DeleteCustomRoleParams{
|
|
Name: nonSystemRole.Name,
|
|
OrganizationID: uuid.NullUUID{
|
|
UUID: org.ID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
roles, err := db.CustomRoles(ctx, database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: systemRole.Name,
|
|
OrganizationID: org.ID,
|
|
},
|
|
{
|
|
Name: nonSystemRole.Name,
|
|
OrganizationID: org.ID,
|
|
},
|
|
},
|
|
IncludeSystemRoles: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, roles, 1)
|
|
require.Equal(t, systemRole.Name, roles[0].Name)
|
|
require.True(t, roles[0].IsSystem)
|
|
}
|
|
|
|
func TestGetAuthorizationUserRolesImpliedOrgRole(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
regularUser := dbgen.User(t, db, database.User{})
|
|
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: regularUser.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: saUser.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
wantMember := rbac.RoleOrgMember() + ":" + org.ID.String()
|
|
wantSA := rbac.RoleOrgServiceAccount() + ":" + org.ID.String()
|
|
|
|
// Regular users get the implied organization-member role.
|
|
regularRoles, err := db.GetAuthorizationUserRoles(ctx, regularUser.ID)
|
|
require.NoError(t, err)
|
|
require.Contains(t, regularRoles.Roles, wantMember)
|
|
require.NotContains(t, regularRoles.Roles, wantSA)
|
|
|
|
// Service accounts get the implied organization-service-account role.
|
|
saRoles, err := db.GetAuthorizationUserRoles(ctx, saUser.ID)
|
|
require.NoError(t, err)
|
|
require.Contains(t, saRoles.Roles, wantSA)
|
|
require.NotContains(t, saRoles.Roles, wantMember)
|
|
}
|
|
|
|
func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
updated, err := db.UpdateOrganizationWorkspaceSharingSettings(ctx, database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
|
ID: org.ID,
|
|
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
|
|
UpdatedAt: dbtime.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ShareableWorkspaceOwnersNone, updated.ShareableWorkspaceOwners)
|
|
|
|
got, err := db.GetOrganizationByID(ctx, org.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ShareableWorkspaceOwnersNone, got.ShareableWorkspaceOwners)
|
|
}
|
|
|
|
func TestDeleteWorkspaceACLsByOrganization(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("DeletesAll", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org1 := dbgen.Organization(t, db, database.Organization{})
|
|
org2 := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
owner1 := dbgen.User(t, db, database.User{})
|
|
owner2 := dbgen.User(t, db, database.User{})
|
|
sharedUser := dbgen.User(t, db, database.User{})
|
|
sharedGroup := dbgen.Group(t, db, database.Group{
|
|
OrganizationID: org1.ID,
|
|
})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org1.ID,
|
|
UserID: owner1.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org2.ID,
|
|
UserID: owner2.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org1.ID,
|
|
UserID: sharedUser.ID,
|
|
})
|
|
|
|
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OwnerID: owner1.ID,
|
|
OrganizationID: org1.ID,
|
|
UserACL: database.WorkspaceACL{
|
|
sharedUser.ID.String(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
},
|
|
GroupACL: database.WorkspaceACL{
|
|
sharedGroup.ID.String(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
},
|
|
}).Do().Workspace
|
|
|
|
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OwnerID: owner2.ID,
|
|
OrganizationID: org2.ID,
|
|
UserACL: database.WorkspaceACL{
|
|
uuid.NewString(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
},
|
|
}).Do().Workspace
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
|
|
OrganizationID: org1.ID,
|
|
ExcludeServiceAccounts: false,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, got1.UserACL)
|
|
require.Empty(t, got1.GroupACL)
|
|
|
|
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, got2.UserACL)
|
|
})
|
|
|
|
t.Run("ExcludesServiceAccounts", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
regularUser := dbgen.User(t, db, database.User{})
|
|
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
|
|
sharedUser := dbgen.User(t, db, database.User{})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: regularUser.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: saUser.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: sharedUser.ID,
|
|
})
|
|
|
|
regularWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OwnerID: regularUser.ID,
|
|
OrganizationID: org.ID,
|
|
UserACL: database.WorkspaceACL{
|
|
sharedUser.ID.String(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
},
|
|
}).Do().Workspace
|
|
|
|
saWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OwnerID: saUser.ID,
|
|
OrganizationID: org.ID,
|
|
UserACL: database.WorkspaceACL{
|
|
sharedUser.ID.String(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
},
|
|
}).Do().Workspace
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
ExcludeServiceAccounts: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Regular user workspace ACLs should be cleared.
|
|
gotRegular, err := db.GetWorkspaceByID(ctx, regularWS.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, gotRegular.UserACL)
|
|
|
|
// Service account workspace ACLs should be preserved.
|
|
gotSA, err := db.GetWorkspaceByID(ctx, saWS.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.WorkspaceACL{
|
|
sharedUser.ID.String(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
}, gotSA.UserACL)
|
|
})
|
|
}
|
|
|
|
func TestAuthorizedAuditLogs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var allLogs []database.AuditLog
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
db = dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
|
|
siteWideIDs := []uuid.UUID{uuid.New(), uuid.New()}
|
|
for _, id := range siteWideIDs {
|
|
allLogs = append(allLogs, dbgen.AuditLog(t, db, database.AuditLog{
|
|
ID: id,
|
|
OrganizationID: uuid.Nil,
|
|
}))
|
|
}
|
|
|
|
// This map is a simple way to insert a given number of organizations
|
|
// and audit logs for each organization.
|
|
// map[orgID][]AuditLogID
|
|
orgAuditLogs := map[uuid.UUID][]uuid.UUID{
|
|
uuid.New(): {uuid.New(), uuid.New()},
|
|
uuid.New(): {uuid.New(), uuid.New()},
|
|
}
|
|
orgIDs := make([]uuid.UUID, 0, len(orgAuditLogs))
|
|
for orgID := range orgAuditLogs {
|
|
orgIDs = append(orgIDs, orgID)
|
|
}
|
|
for orgID, ids := range orgAuditLogs {
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
for _, id := range ids {
|
|
allLogs = append(allLogs, dbgen.AuditLog(t, db, database.AuditLog{
|
|
ID: id,
|
|
OrganizationID: orgID,
|
|
}))
|
|
}
|
|
}
|
|
|
|
// Now fetch all the logs
|
|
auditorRole, err := rbac.RoleByName(rbac.RoleAuditor())
|
|
require.NoError(t, err)
|
|
|
|
memberRole, err := rbac.RoleByName(rbac.RoleMember())
|
|
require.NoError(t, err)
|
|
|
|
orgAuditorRoles := func(t *testing.T, orgID uuid.UUID) rbac.Role {
|
|
t.Helper()
|
|
|
|
role, err := rbac.RoleByName(rbac.ScopedRoleOrgAuditor(orgID))
|
|
require.NoError(t, err)
|
|
return role
|
|
}
|
|
|
|
t.Run("NoAccess", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A user who is a member of 0 organizations
|
|
memberCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "member",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{memberRole},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for audit logs
|
|
count, err := db.CountAuditLogs(memberCtx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
logs, err := db.GetAuditLogsOffset(memberCtx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
|
|
// Then: No logs returned and count is 0
|
|
require.Equal(t, int64(0), count, "count should be 0")
|
|
require.Len(t, logs, 0, "no logs should be returned")
|
|
})
|
|
|
|
t.Run("SiteWideAuditor", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A site wide auditor
|
|
siteAuditorCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "owner",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{auditorRole},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: the auditor queries for audit logs
|
|
count, err := db.CountAuditLogs(siteAuditorCtx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
logs, err := db.GetAuditLogsOffset(siteAuditorCtx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
|
|
// Then: All logs are returned and count matches
|
|
require.Equal(t, int64(len(allLogs)), count, "count should match total number of logs")
|
|
require.ElementsMatch(t, auditOnlyIDs(allLogs), auditOnlyIDs(logs), "all logs should be returned")
|
|
})
|
|
|
|
t.Run("SingleOrgAuditor", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
orgID := orgIDs[0]
|
|
// Given: An organization scoped auditor
|
|
orgAuditCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, orgID)},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The auditor queries for audit logs
|
|
count, err := db.CountAuditLogs(orgAuditCtx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
logs, err := db.GetAuditLogsOffset(orgAuditCtx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
|
|
// Then: Only the logs for the organization are returned and count matches
|
|
require.Equal(t, int64(len(orgAuditLogs[orgID])), count, "count should match organization logs")
|
|
require.ElementsMatch(t, orgAuditLogs[orgID], auditOnlyIDs(logs), "only organization logs should be returned")
|
|
})
|
|
|
|
t.Run("TwoOrgAuditors", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
first := orgIDs[0]
|
|
second := orgIDs[1]
|
|
// Given: A user who is an auditor for two organizations
|
|
multiOrgAuditCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, first), orgAuditorRoles(t, second)},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for audit logs
|
|
count, err := db.CountAuditLogs(multiOrgAuditCtx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
logs, err := db.GetAuditLogsOffset(multiOrgAuditCtx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
|
|
// Then: All logs for both organizations are returned and count matches
|
|
expectedLogs := append([]uuid.UUID{}, orgAuditLogs[first]...)
|
|
expectedLogs = append(expectedLogs, orgAuditLogs[second]...)
|
|
require.Equal(t, int64(len(expectedLogs)), count, "count should match sum of both organizations")
|
|
require.ElementsMatch(t, expectedLogs, auditOnlyIDs(logs), "logs from both organizations should be returned")
|
|
})
|
|
|
|
t.Run("ErroneousOrg", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A user who is an auditor for an organization that has 0 logs
|
|
userCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, uuid.New())},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for audit logs
|
|
count, err := db.CountAuditLogs(userCtx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
logs, err := db.GetAuditLogsOffset(userCtx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
|
|
// Then: No logs are returned and count is 0
|
|
require.Equal(t, int64(0), count, "count should be 0")
|
|
require.Len(t, logs, 0, "no logs should be returned")
|
|
})
|
|
}
|
|
|
|
func auditOnlyIDs[T database.AuditLog | database.GetAuditLogsOffsetRow](logs []T) []uuid.UUID {
|
|
ids := make([]uuid.UUID, 0, len(logs))
|
|
for _, log := range logs {
|
|
switch log := any(log).(type) {
|
|
case database.AuditLog:
|
|
ids = append(ids, log.ID)
|
|
case database.GetAuditLogsOffsetRow:
|
|
ids = append(ids, log.AuditLog.ID)
|
|
default:
|
|
panic("unreachable")
|
|
}
|
|
}
|
|
return ids
|
|
}
|
|
|
|
func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var allLogs []database.ConnectionLog
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
authDb := dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
orgB := dbfake.Organization(t, db).Do()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: orgA.Org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
|
|
wsID := uuid.New()
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
WorkspaceTransition: database.WorkspaceTransitionStart,
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
CreateWorkspace: true,
|
|
WorkspaceID: wsID,
|
|
})
|
|
|
|
// This map is a simple way to insert a given number of organizations
|
|
// and audit logs for each organization.
|
|
// map[orgID][]ConnectionLogID
|
|
orgConnectionLogs := map[uuid.UUID][]uuid.UUID{
|
|
orgA.Org.ID: {uuid.New(), uuid.New()},
|
|
orgB.Org.ID: {uuid.New(), uuid.New()},
|
|
}
|
|
orgIDs := make([]uuid.UUID, 0, len(orgConnectionLogs))
|
|
for orgID := range orgConnectionLogs {
|
|
orgIDs = append(orgIDs, orgID)
|
|
}
|
|
for orgID, ids := range orgConnectionLogs {
|
|
for _, id := range ids {
|
|
allLogs = append(allLogs, dbgen.ConnectionLog(t, authDb, database.UpsertConnectionLogParams{
|
|
WorkspaceID: wsID,
|
|
WorkspaceOwnerID: user.ID,
|
|
ID: id,
|
|
OrganizationID: orgID,
|
|
}))
|
|
}
|
|
}
|
|
|
|
// Now fetch all the logs
|
|
auditorRole, err := rbac.RoleByName(rbac.RoleAuditor())
|
|
require.NoError(t, err)
|
|
|
|
memberRole, err := rbac.RoleByName(rbac.RoleMember())
|
|
require.NoError(t, err)
|
|
|
|
orgAuditorRoles := func(t *testing.T, orgID uuid.UUID) rbac.Role {
|
|
t.Helper()
|
|
|
|
role, err := rbac.RoleByName(rbac.ScopedRoleOrgAuditor(orgID))
|
|
require.NoError(t, err)
|
|
return role
|
|
}
|
|
|
|
t.Run("NoAccess", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A user who is a member of 0 organizations
|
|
memberCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "member",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{memberRole},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for connection logs
|
|
logs, err := authDb.GetConnectionLogsOffset(memberCtx, database.GetConnectionLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// Then: No logs returned
|
|
require.Len(t, logs, 0, "no logs should be returned")
|
|
// And: The count matches the number of logs returned
|
|
count, err := authDb.CountConnectionLogs(memberCtx, database.CountConnectionLogsParams{})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, len(logs), count)
|
|
})
|
|
|
|
t.Run("SiteWideAuditor", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A site wide auditor
|
|
siteAuditorCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "owner",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{auditorRole},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: the auditor queries for connection logs
|
|
logs, err := authDb.GetConnectionLogsOffset(siteAuditorCtx, database.GetConnectionLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// Then: All logs are returned
|
|
require.ElementsMatch(t, connectionOnlyIDs(allLogs), connectionOnlyIDs(logs))
|
|
// And: The count matches the number of logs returned
|
|
count, err := authDb.CountConnectionLogs(siteAuditorCtx, database.CountConnectionLogsParams{})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, len(logs), count)
|
|
})
|
|
|
|
t.Run("SingleOrgAuditor", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
orgID := orgIDs[0]
|
|
// Given: An organization scoped auditor
|
|
orgAuditCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, orgID)},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The auditor queries for connection logs
|
|
logs, err := authDb.GetConnectionLogsOffset(orgAuditCtx, database.GetConnectionLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// Then: Only the logs for the organization are returned
|
|
require.ElementsMatch(t, orgConnectionLogs[orgID], connectionOnlyIDs(logs))
|
|
// And: The count matches the number of logs returned
|
|
count, err := authDb.CountConnectionLogs(orgAuditCtx, database.CountConnectionLogsParams{})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, len(logs), count)
|
|
})
|
|
|
|
t.Run("TwoOrgAuditors", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
first := orgIDs[0]
|
|
second := orgIDs[1]
|
|
// Given: A user who is an auditor for two organizations
|
|
multiOrgAuditCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, first), orgAuditorRoles(t, second)},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for connection logs
|
|
logs, err := authDb.GetConnectionLogsOffset(multiOrgAuditCtx, database.GetConnectionLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// Then: All logs for both organizations are returned
|
|
require.ElementsMatch(t, append(orgConnectionLogs[first], orgConnectionLogs[second]...), connectionOnlyIDs(logs))
|
|
// And: The count matches the number of logs returned
|
|
count, err := authDb.CountConnectionLogs(multiOrgAuditCtx, database.CountConnectionLogsParams{})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, len(logs), count)
|
|
})
|
|
|
|
t.Run("ErroneousOrg", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A user who is an auditor for an organization that has 0 logs
|
|
userCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, uuid.New())},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for audit logs
|
|
logs, err := authDb.GetConnectionLogsOffset(userCtx, database.GetConnectionLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// Then: No logs are returned
|
|
require.Len(t, logs, 0, "no logs should be returned")
|
|
// And: The count matches the number of logs returned
|
|
count, err := authDb.CountConnectionLogs(userCtx, database.CountConnectionLogsParams{})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, len(logs), count)
|
|
})
|
|
}
|
|
|
|
func TestCountConnectionLogs(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
userA := dbgen.User(t, db, database.User{})
|
|
tplA := dbgen.Template(t, db, database.Template{OrganizationID: orgA.Org.ID, CreatedBy: userA.ID})
|
|
wsA := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userA.ID, OrganizationID: orgA.Org.ID, TemplateID: tplA.ID})
|
|
|
|
orgB := dbfake.Organization(t, db).Do()
|
|
userB := dbgen.User(t, db, database.User{})
|
|
tplB := dbgen.Template(t, db, database.Template{OrganizationID: orgB.Org.ID, CreatedBy: userB.ID})
|
|
wsB := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userB.ID, OrganizationID: orgB.Org.ID, TemplateID: tplB.ID})
|
|
|
|
// Create logs for two different orgs.
|
|
for i := 0; i < 20; i++ {
|
|
dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
OrganizationID: wsA.OrganizationID,
|
|
WorkspaceOwnerID: wsA.OwnerID,
|
|
WorkspaceID: wsA.ID,
|
|
Type: database.ConnectionTypeSsh,
|
|
})
|
|
}
|
|
for i := 0; i < 10; i++ {
|
|
dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
OrganizationID: wsB.OrganizationID,
|
|
WorkspaceOwnerID: wsB.OwnerID,
|
|
WorkspaceID: wsB.ID,
|
|
Type: database.ConnectionTypeSsh,
|
|
})
|
|
}
|
|
|
|
// Count with a filter for orgA.
|
|
countParams := database.CountConnectionLogsParams{
|
|
OrganizationID: orgA.Org.ID,
|
|
}
|
|
totalCount, err := db.CountConnectionLogs(ctx, countParams)
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(20), totalCount)
|
|
|
|
// Get a paginated result for the same filter.
|
|
getParams := database.GetConnectionLogsOffsetParams{
|
|
OrganizationID: orgA.Org.ID,
|
|
LimitOpt: 5,
|
|
OffsetOpt: 10,
|
|
}
|
|
logs, err := db.GetConnectionLogsOffset(ctx, getParams)
|
|
require.NoError(t, err)
|
|
require.Len(t, logs, 5)
|
|
|
|
// The count with the filter should remain the same, independent of pagination.
|
|
countAfterGet, err := db.CountConnectionLogs(ctx, countParams)
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(20), countAfterGet)
|
|
}
|
|
|
|
func TestConnectionLogsOffsetFilters(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
orgB := dbfake.Organization(t, db).Do()
|
|
|
|
user1 := dbgen.User(t, db, database.User{
|
|
Username: "user1",
|
|
Email: "user1@test.com",
|
|
})
|
|
user2 := dbgen.User(t, db, database.User{
|
|
Username: "user2",
|
|
Email: "user2@test.com",
|
|
})
|
|
user3 := dbgen.User(t, db, database.User{
|
|
Username: "user3",
|
|
Email: "user3@test.com",
|
|
})
|
|
|
|
ws1Tpl := dbgen.Template(t, db, database.Template{OrganizationID: orgA.Org.ID, CreatedBy: user1.ID})
|
|
ws1 := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user1.ID,
|
|
OrganizationID: orgA.Org.ID,
|
|
TemplateID: ws1Tpl.ID,
|
|
})
|
|
ws2Tpl := dbgen.Template(t, db, database.Template{OrganizationID: orgB.Org.ID, CreatedBy: user2.ID})
|
|
ws2 := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user2.ID,
|
|
OrganizationID: orgB.Org.ID,
|
|
TemplateID: ws2Tpl.ID,
|
|
})
|
|
|
|
now := dbtime.Now()
|
|
log1ConnID := uuid.New()
|
|
log1 := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-4 * time.Hour),
|
|
OrganizationID: ws1.OrganizationID,
|
|
WorkspaceOwnerID: ws1.OwnerID,
|
|
WorkspaceID: ws1.ID,
|
|
WorkspaceName: ws1.Name,
|
|
Type: database.ConnectionTypeWorkspaceApp,
|
|
ConnectionStatus: database.ConnectionStatusConnected,
|
|
UserID: uuid.NullUUID{UUID: user1.ID, Valid: true},
|
|
UserAgent: sql.NullString{String: "Mozilla/5.0", Valid: true},
|
|
SlugOrPort: sql.NullString{String: "code-server", Valid: true},
|
|
ConnectionID: uuid.NullUUID{UUID: log1ConnID, Valid: true},
|
|
})
|
|
|
|
log2ConnID := uuid.New()
|
|
log2 := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-3 * time.Hour),
|
|
OrganizationID: ws1.OrganizationID,
|
|
WorkspaceOwnerID: ws1.OwnerID,
|
|
WorkspaceID: ws1.ID,
|
|
WorkspaceName: ws1.Name,
|
|
Type: database.ConnectionTypeVscode,
|
|
ConnectionStatus: database.ConnectionStatusConnected,
|
|
ConnectionID: uuid.NullUUID{UUID: log2ConnID, Valid: true},
|
|
})
|
|
|
|
// Mark log2 as disconnected
|
|
log2 = dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-2 * time.Hour),
|
|
ConnectionID: log2.ConnectionID,
|
|
WorkspaceID: ws1.ID,
|
|
WorkspaceOwnerID: ws1.OwnerID,
|
|
AgentName: log2.AgentName,
|
|
ConnectionStatus: database.ConnectionStatusDisconnected,
|
|
|
|
OrganizationID: log2.OrganizationID,
|
|
})
|
|
|
|
log3ConnID := uuid.New()
|
|
log3 := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-2 * time.Hour),
|
|
OrganizationID: ws2.OrganizationID,
|
|
WorkspaceOwnerID: ws2.OwnerID,
|
|
WorkspaceID: ws2.ID,
|
|
WorkspaceName: ws2.Name,
|
|
Type: database.ConnectionTypeSsh,
|
|
ConnectionStatus: database.ConnectionStatusConnected,
|
|
UserID: uuid.NullUUID{UUID: user2.ID, Valid: true},
|
|
ConnectionID: uuid.NullUUID{UUID: log3ConnID, Valid: true},
|
|
})
|
|
|
|
// Mark log3 as disconnected
|
|
log3 = dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-1 * time.Hour),
|
|
ConnectionID: log3.ConnectionID,
|
|
WorkspaceOwnerID: log3.WorkspaceOwnerID,
|
|
WorkspaceID: ws2.ID,
|
|
AgentName: log3.AgentName,
|
|
ConnectionStatus: database.ConnectionStatusDisconnected,
|
|
|
|
OrganizationID: log3.OrganizationID,
|
|
})
|
|
|
|
log4 := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-1 * time.Hour),
|
|
OrganizationID: ws2.OrganizationID,
|
|
WorkspaceOwnerID: ws2.OwnerID,
|
|
WorkspaceID: ws2.ID,
|
|
WorkspaceName: ws2.Name,
|
|
Type: database.ConnectionTypeVscode,
|
|
ConnectionStatus: database.ConnectionStatusConnected,
|
|
UserID: uuid.NullUUID{UUID: user3.ID, Valid: true},
|
|
})
|
|
|
|
testCases := []struct {
|
|
name string
|
|
params database.GetConnectionLogsOffsetParams
|
|
expectedLogIDs []uuid.UUID
|
|
}{
|
|
{
|
|
name: "NoFilter",
|
|
params: database.GetConnectionLogsOffsetParams{},
|
|
expectedLogIDs: []uuid.UUID{
|
|
log1.ID, log2.ID, log3.ID, log4.ID,
|
|
},
|
|
},
|
|
{
|
|
name: "OrganizationID",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
OrganizationID: orgB.Org.ID,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log3.ID, log4.ID},
|
|
},
|
|
{
|
|
name: "WorkspaceOwner",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
WorkspaceOwner: user1.Username,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID, log2.ID},
|
|
},
|
|
{
|
|
name: "WorkspaceOwnerID",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
WorkspaceOwnerID: user1.ID,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID, log2.ID},
|
|
},
|
|
{
|
|
name: "WorkspaceOwnerEmail",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
WorkspaceOwnerEmail: user2.Email,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log3.ID, log4.ID},
|
|
},
|
|
{
|
|
name: "Type",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
Type: string(database.ConnectionTypeVscode),
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log2.ID, log4.ID},
|
|
},
|
|
{
|
|
name: "UserID",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
UserID: user1.ID,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID},
|
|
},
|
|
{
|
|
name: "Username",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
Username: user1.Username,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID},
|
|
},
|
|
{
|
|
name: "UserEmail",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
UserEmail: user3.Email,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log4.ID},
|
|
},
|
|
{
|
|
name: "ConnectedAfter",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
ConnectedAfter: now.Add(-90 * time.Minute), // 1.5 hours ago
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log4.ID},
|
|
},
|
|
{
|
|
name: "ConnectedBefore",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
ConnectedBefore: now.Add(-150 * time.Minute),
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID, log2.ID},
|
|
},
|
|
{
|
|
name: "WorkspaceID",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
WorkspaceID: ws2.ID,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log3.ID, log4.ID},
|
|
},
|
|
{
|
|
name: "ConnectionID",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
ConnectionID: log1.ConnectionID.UUID,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID},
|
|
},
|
|
{
|
|
name: "StatusOngoing",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
Status: string(codersdk.ConnectionLogStatusOngoing),
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log4.ID},
|
|
},
|
|
{
|
|
name: "StatusCompleted",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
Status: string(codersdk.ConnectionLogStatusCompleted),
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log2.ID, log3.ID},
|
|
},
|
|
{
|
|
name: "OrganizationAndTypeAndStatus",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
OrganizationID: orgA.Org.ID,
|
|
Type: string(database.ConnectionTypeVscode),
|
|
Status: string(codersdk.ConnectionLogStatusCompleted),
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log2.ID},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
logs, err := db.GetConnectionLogsOffset(ctx, tc.params)
|
|
require.NoError(t, err)
|
|
count, err := db.CountConnectionLogs(ctx, database.CountConnectionLogsParams{
|
|
OrganizationID: tc.params.OrganizationID,
|
|
WorkspaceOwner: tc.params.WorkspaceOwner,
|
|
Type: tc.params.Type,
|
|
UserID: tc.params.UserID,
|
|
Username: tc.params.Username,
|
|
UserEmail: tc.params.UserEmail,
|
|
ConnectedAfter: tc.params.ConnectedAfter,
|
|
ConnectedBefore: tc.params.ConnectedBefore,
|
|
WorkspaceID: tc.params.WorkspaceID,
|
|
ConnectionID: tc.params.ConnectionID,
|
|
Status: tc.params.Status,
|
|
WorkspaceOwnerID: tc.params.WorkspaceOwnerID,
|
|
WorkspaceOwnerEmail: tc.params.WorkspaceOwnerEmail,
|
|
})
|
|
require.NoError(t, err)
|
|
require.ElementsMatch(t, tc.expectedLogIDs, connectionOnlyIDs(logs))
|
|
require.Equal(t, len(tc.expectedLogIDs), int(count), "CountConnectionLogs should match the number of returned logs (no offset or limit)")
|
|
})
|
|
}
|
|
}
|
|
|
|
func connectionOnlyIDs[T database.ConnectionLog | database.GetConnectionLogsOffsetRow](logs []T) []uuid.UUID {
|
|
ids := make([]uuid.UUID, 0, len(logs))
|
|
for _, log := range logs {
|
|
switch log := any(log).(type) {
|
|
case database.ConnectionLog:
|
|
ids = append(ids, log.ID)
|
|
case database.GetConnectionLogsOffsetRow:
|
|
ids = append(ids, log.ConnectionLog.ID)
|
|
default:
|
|
panic("unreachable")
|
|
}
|
|
}
|
|
return ids
|
|
}
|
|
|
|
func TestBatchUpsertConnectionLogs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
createWorkspace := func(t *testing.T, db database.Store) database.WorkspaceTable {
|
|
t.Helper()
|
|
u := dbgen.User(t, db, database.User{})
|
|
o := dbgen.Organization(t, db, database.Organization{})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: o.ID,
|
|
CreatedBy: u.ID,
|
|
})
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
ID: uuid.New(),
|
|
OwnerID: u.ID,
|
|
OrganizationID: o.ID,
|
|
AutomaticUpdates: database.AutomaticUpdatesNever,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
}
|
|
|
|
// zeroTime is the sentinel value that the SQL treats as "no
|
|
// connect/disconnect time provided".
|
|
zeroTime := time.Time{}
|
|
|
|
defaultIP := pqtype.Inet{
|
|
IPNet: net.IPNet{
|
|
IP: net.IPv4(127, 0, 0, 1),
|
|
Mask: net.IPv4Mask(255, 255, 255, 255),
|
|
},
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("SingleConnect", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
connectTime := dbtime.Now()
|
|
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{connectTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{false},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 1)
|
|
require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime))
|
|
require.False(t, rows[0].ConnectionLog.DisconnectTime.Valid,
|
|
"disconnect_time should be NULL for a connect-only event")
|
|
})
|
|
|
|
t.Run("ConnectThenDisconnect", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
connectTime := dbtime.Now()
|
|
|
|
// Insert connect.
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{connectTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{false},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert disconnect for same connection.
|
|
disconnectTime := connectTime.Add(time.Second)
|
|
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{zeroTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{1},
|
|
CodeValid: []bool{true},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{"test disconnect"},
|
|
DisconnectTime: []time.Time{disconnectTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 1)
|
|
row := rows[0].ConnectionLog
|
|
require.True(t, connectTime.Equal(row.ConnectTime))
|
|
require.True(t, row.DisconnectTime.Valid)
|
|
require.True(t, disconnectTime.Equal(row.DisconnectTime.Time))
|
|
require.Equal(t, "test disconnect", row.DisconnectReason.String)
|
|
require.Equal(t, int32(1), row.Code.Int32)
|
|
})
|
|
|
|
t.Run("DuplicateConnectIsNoOp", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
connectTime := dbtime.Now()
|
|
|
|
mkParams := func(ct time.Time, ip pqtype.Inet) database.BatchUpsertConnectionLogsParams {
|
|
return database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{ct},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{false},
|
|
Ip: []pqtype.Inet{ip},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
}
|
|
}
|
|
|
|
err := db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime, defaultIP))
|
|
require.NoError(t, err)
|
|
|
|
rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows1, 1)
|
|
|
|
// Second connect with later time and different IP.
|
|
otherIP := pqtype.Inet{
|
|
IPNet: net.IPNet{
|
|
IP: net.IPv4(10, 0, 0, 1),
|
|
Mask: net.IPv4Mask(255, 255, 255, 255),
|
|
},
|
|
Valid: true,
|
|
}
|
|
err = db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime.Add(time.Second), otherIP))
|
|
require.NoError(t, err)
|
|
|
|
rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows2, 1)
|
|
|
|
// The LEAST logic should pick the earlier connect_time; IP and
|
|
// other fields are not updated on conflict.
|
|
require.True(t, connectTime.Equal(rows2[0].ConnectionLog.ConnectTime),
|
|
"connect_time should remain the original (earlier) value")
|
|
})
|
|
|
|
t.Run("OrderIndependentConnectTime", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
disconnectTime := dbtime.Now()
|
|
connectTime := disconnectTime.Add(-5 * time.Second)
|
|
|
|
// Disconnect arrives first.
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{disconnectTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{true},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{"bye"},
|
|
DisconnectTime: []time.Time{disconnectTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Connect arrives second with the real (earlier) connect_time.
|
|
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{connectTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{false},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 1)
|
|
require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime),
|
|
"LEAST should pick the earlier connect_time")
|
|
})
|
|
|
|
t.Run("DisconnectFieldsAreWriteOnce", func(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	ctx := context.Background()
	ws := createWorkspace(t, db)
	connID := uuid.New()
	disconnectTime := dbtime.Now()

	// disconnectEvent returns a one-row batch that records a disconnect of
	// connID at disconnectTime. Every call gets a fresh row ID; only the
	// disconnect reason and exit code are chosen by the caller.
	disconnectEvent := func(reason string, code int32) database.BatchUpsertConnectionLogsParams {
		return database.BatchUpsertConnectionLogsParams{
			ID:               []uuid.UUID{uuid.New()},
			ConnectTime:      []time.Time{disconnectTime},
			OrganizationID:   []uuid.UUID{ws.OrganizationID},
			WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
			WorkspaceID:      []uuid.UUID{ws.ID},
			WorkspaceName:    []string{ws.Name},
			AgentName:        []string{"agent"},
			Type:             []database.ConnectionType{database.ConnectionTypeSsh},
			Code:             []int32{code},
			CodeValid:        []bool{true},
			Ip:               []pqtype.Inet{defaultIP},
			UserAgent:        []string{""},
			UserID:           []uuid.UUID{uuid.Nil},
			SlugOrPort:       []string{""},
			ConnectionID:     []uuid.UUID{connID},
			DisconnectReason: []string{reason},
			DisconnectTime:   []time.Time{disconnectTime},
		}
	}

	// Report the same disconnect twice; the second report carries a
	// different reason and code, which must not replace the first ones.
	events := []struct {
		reason string
		code   int32
	}{
		{reason: "first reason", code: 1},
		{reason: "second reason", code: 2},
	}
	for _, ev := range events {
		err := db.BatchUpsertConnectionLogs(ctx, disconnectEvent(ev.reason, ev.code))
		require.NoError(t, err)
	}

	rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
	require.NoError(t, err)
	require.Len(t, rows, 1)

	got := rows[0].ConnectionLog
	require.Equal(t, "first reason", got.DisconnectReason.String,
		"disconnect_reason should not be overwritten")
	require.Equal(t, int32(1), got.Code.Int32,
		"code should not be overwritten")
})
|
|
|
|
t.Run("ConnectAfterDisconnectIsNoOp", func(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	ctx := context.Background()
	ws := createWorkspace(t, db)
	connID := uuid.New()
	disconnectTime := dbtime.Now()

	// event builds a one-row batch for connID. Its arguments are exactly
	// the fields that differ between the disconnect and the late connect.
	event := func(connectAt time.Time, code int32, codeValid bool, reason string, disconnectAt time.Time) database.BatchUpsertConnectionLogsParams {
		return database.BatchUpsertConnectionLogsParams{
			ID:               []uuid.UUID{uuid.New()},
			ConnectTime:      []time.Time{connectAt},
			OrganizationID:   []uuid.UUID{ws.OrganizationID},
			WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
			WorkspaceID:      []uuid.UUID{ws.ID},
			WorkspaceName:    []string{ws.Name},
			AgentName:        []string{"agent"},
			Type:             []database.ConnectionType{database.ConnectionTypeSsh},
			Code:             []int32{code},
			CodeValid:        []bool{codeValid},
			Ip:               []pqtype.Inet{defaultIP},
			UserAgent:        []string{""},
			UserID:           []uuid.UUID{uuid.Nil},
			SlugOrPort:       []string{""},
			ConnectionID:     []uuid.UUID{connID},
			DisconnectReason: []string{reason},
			DisconnectTime:   []time.Time{disconnectAt},
		}
	}

	// The disconnect arrives first.
	err := db.BatchUpsertConnectionLogs(ctx, event(disconnectTime, 42, true, "server shutdown", disconnectTime))
	require.NoError(t, err)

	before, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
	require.NoError(t, err)
	require.Len(t, before, 1)
	require.True(t, before[0].ConnectionLog.DisconnectTime.Valid)
	require.Equal(t, "server shutdown", before[0].ConnectionLog.DisconnectReason.String)
	require.Equal(t, int32(42), before[0].ConnectionLog.Code.Int32)

	// A connect for the same connection_id shows up one second later; it
	// carries no code and no disconnect data.
	err = db.BatchUpsertConnectionLogs(ctx, event(disconnectTime.Add(time.Second), 0, false, "", zeroTime))
	require.NoError(t, err)

	after, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
	require.NoError(t, err)
	require.Len(t, after, 1)

	got := after[0].ConnectionLog
	require.True(t, got.DisconnectTime.Valid,
		"disconnect_time should not be cleared by a later connect")
	require.Equal(t, "server shutdown", got.DisconnectReason.String,
		"disconnect_reason should not be cleared")
	require.Equal(t, int32(42), got.Code.Int32,
		"code should not be cleared")
})
|
|
|
|
t.Run("CodeZeroPreserved", func(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	ctx := context.Background()
	ws := createWorkspace(t, db)
	connID := uuid.New()
	now := dbtime.Now()

	// A completed connection whose exit code is a *valid* zero.
	params := database.BatchUpsertConnectionLogsParams{
		ID:               []uuid.UUID{uuid.New()},
		ConnectTime:      []time.Time{now},
		OrganizationID:   []uuid.UUID{ws.OrganizationID},
		WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
		WorkspaceID:      []uuid.UUID{ws.ID},
		WorkspaceName:    []string{ws.Name},
		AgentName:        []string{"agent"},
		Type:             []database.ConnectionType{database.ConnectionTypeSsh},
		Code:             []int32{0},
		CodeValid:        []bool{true},
		Ip:               []pqtype.Inet{defaultIP},
		UserAgent:        []string{""},
		UserID:           []uuid.UUID{uuid.Nil},
		SlugOrPort:       []string{""},
		ConnectionID:     []uuid.UUID{connID},
		DisconnectReason: []string{"normal"},
		DisconnectTime:   []time.Time{now},
	}
	require.NoError(t, db.BatchUpsertConnectionLogs(ctx, params))

	rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
	require.NoError(t, err)
	require.Len(t, rows, 1)

	code := rows[0].ConnectionLog.Code
	require.True(t, code.Valid, "code should be non-NULL")
	require.Equal(t, int32(0), code.Int32,
		"code=0 should be preserved, not treated as NULL")
})
|
|
|
|
t.Run("CodeNullWhenInvalid", func(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	ctx := context.Background()
	ws := createWorkspace(t, db)
	connID := uuid.New()
	now := dbtime.Now()

	// A non-zero code paired with CodeValid=false: the code value itself
	// must be discarded.
	params := database.BatchUpsertConnectionLogsParams{
		ID:               []uuid.UUID{uuid.New()},
		ConnectTime:      []time.Time{now},
		OrganizationID:   []uuid.UUID{ws.OrganizationID},
		WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
		WorkspaceID:      []uuid.UUID{ws.ID},
		WorkspaceName:    []string{ws.Name},
		AgentName:        []string{"agent"},
		Type:             []database.ConnectionType{database.ConnectionTypeSsh},
		Code:             []int32{99},
		CodeValid:        []bool{false},
		Ip:               []pqtype.Inet{defaultIP},
		UserAgent:        []string{""},
		UserID:           []uuid.UUID{uuid.Nil},
		SlugOrPort:       []string{""},
		ConnectionID:     []uuid.UUID{connID},
		DisconnectReason: []string{""},
		DisconnectTime:   []time.Time{zeroTime},
	}
	require.NoError(t, db.BatchUpsertConnectionLogs(ctx, params))

	rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
	require.NoError(t, err)
	require.Len(t, rows, 1)
	require.False(t, rows[0].ConnectionLog.Code.Valid,
		"code should be NULL when code_valid is false")
})
|
|
|
|
t.Run("NullConnectionIDEvents", func(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	ctx := context.Background()
	ws := createWorkspace(t, db)
	now := dbtime.Now()

	// Two web events for the same workspace/agent, one second apart. Both
	// use uuid.Nil as the connection ID, which is stored as NULL (via
	// NULLIF), so they must not conflict on connection_id.
	for _, offset := range []time.Duration{0, time.Second} {
		err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
			ID:               []uuid.UUID{uuid.New()},
			ConnectTime:      []time.Time{now.Add(offset)},
			OrganizationID:   []uuid.UUID{ws.OrganizationID},
			WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
			WorkspaceID:      []uuid.UUID{ws.ID},
			WorkspaceName:    []string{ws.Name},
			AgentName:        []string{"agent"},
			Type:             []database.ConnectionType{database.ConnectionTypeSsh},
			Code:             []int32{200},
			CodeValid:        []bool{true},
			Ip:               []pqtype.Inet{defaultIP},
			UserAgent:        []string{"Mozilla/5.0"},
			UserID:           []uuid.UUID{uuid.Nil},
			SlugOrPort:       []string{"web-terminal"},
			ConnectionID:     []uuid.UUID{uuid.Nil},
			DisconnectReason: []string{""},
			DisconnectTime:   []time.Time{zeroTime},
		})
		require.NoError(t, err)
	}

	rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
	require.NoError(t, err)
	require.Len(t, rows, 2,
		"NULL connection_id rows should not conflict with each other")
})
|
|
|
|
t.Run("MultipleIndependentConnections", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
now := dbtime.Now()
|
|
|
|
n := 5
|
|
ids := make([]uuid.UUID, n)
|
|
connectTimes := make([]time.Time, n)
|
|
orgIDs := make([]uuid.UUID, n)
|
|
ownerIDs := make([]uuid.UUID, n)
|
|
wsIDs := make([]uuid.UUID, n)
|
|
wsNames := make([]string, n)
|
|
agentNames := make([]string, n)
|
|
types := make([]database.ConnectionType, n)
|
|
codes := make([]int32, n)
|
|
codeValids := make([]bool, n)
|
|
ips := make([]pqtype.Inet, n)
|
|
userAgents := make([]string, n)
|
|
userIDs := make([]uuid.UUID, n)
|
|
slugOrPorts := make([]string, n)
|
|
connIDs := make([]uuid.UUID, n)
|
|
disconnectReasons := make([]string, n)
|
|
disconnectTimes := make([]time.Time, n)
|
|
|
|
for i := range n {
|
|
ids[i] = uuid.New()
|
|
connectTimes[i] = now.Add(time.Duration(i) * time.Second)
|
|
orgIDs[i] = ws.OrganizationID
|
|
ownerIDs[i] = ws.OwnerID
|
|
wsIDs[i] = ws.ID
|
|
wsNames[i] = ws.Name
|
|
agentNames[i] = "agent"
|
|
types[i] = database.ConnectionTypeSsh
|
|
codes[i] = 0
|
|
codeValids[i] = false
|
|
ips[i] = defaultIP
|
|
userAgents[i] = ""
|
|
userIDs[i] = uuid.Nil
|
|
slugOrPorts[i] = ""
|
|
connIDs[i] = uuid.New()
|
|
disconnectReasons[i] = ""
|
|
disconnectTimes[i] = zeroTime
|
|
}
|
|
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: ids,
|
|
ConnectTime: connectTimes,
|
|
OrganizationID: orgIDs,
|
|
WorkspaceOwnerID: ownerIDs,
|
|
WorkspaceID: wsIDs,
|
|
WorkspaceName: wsNames,
|
|
AgentName: agentNames,
|
|
Type: types,
|
|
Code: codes,
|
|
CodeValid: codeValids,
|
|
Ip: ips,
|
|
UserAgent: userAgents,
|
|
UserID: userIDs,
|
|
SlugOrPort: slugOrPorts,
|
|
ConnectionID: connIDs,
|
|
DisconnectReason: disconnectReasons,
|
|
DisconnectTime: disconnectTimes,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, n, "each unique connection_id should produce its own row")
|
|
})
|
|
}
|
|
|
|
type tvArgs struct {
|
|
Status database.ProvisionerJobStatus
|
|
// CreateWorkspace is true if we should create a workspace for the template version
|
|
CreateWorkspace bool
|
|
WorkspaceID uuid.UUID
|
|
CreateAgent bool
|
|
WorkspaceTransition database.WorkspaceTransition
|
|
ExtraAgents int
|
|
ExtraBuilds int
|
|
}
|
|
|
|
// createTemplateVersion is a helper function to create a version with its dependencies.
|
|
func createTemplateVersion(t testing.TB, db database.Store, tpl database.Template, args tvArgs) database.TemplateVersion {
|
|
t.Helper()
|
|
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{
|
|
UUID: tpl.ID,
|
|
Valid: true,
|
|
},
|
|
OrganizationID: tpl.OrganizationID,
|
|
CreatedAt: dbtime.Now(),
|
|
UpdatedAt: dbtime.Now(),
|
|
CreatedBy: tpl.CreatedBy,
|
|
})
|
|
|
|
latestJob := database.ProvisionerJob{
|
|
ID: version.JobID,
|
|
Error: sql.NullString{},
|
|
OrganizationID: tpl.OrganizationID,
|
|
InitiatorID: tpl.CreatedBy,
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
}
|
|
setJobStatus(t, args.Status, &latestJob)
|
|
dbgen.ProvisionerJob(t, db, nil, latestJob)
|
|
if args.CreateWorkspace {
|
|
wrk := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
ID: args.WorkspaceID,
|
|
CreatedAt: time.Time{},
|
|
UpdatedAt: time.Time{},
|
|
OwnerID: tpl.CreatedBy,
|
|
OrganizationID: tpl.OrganizationID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
trans := database.WorkspaceTransitionStart
|
|
if args.WorkspaceTransition != "" {
|
|
trans = args.WorkspaceTransition
|
|
}
|
|
latestJob = database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: tpl.CreatedBy,
|
|
OrganizationID: tpl.OrganizationID,
|
|
}
|
|
setJobStatus(t, args.Status, &latestJob)
|
|
latestJob = dbgen.ProvisionerJob(t, db, nil, latestJob)
|
|
latestResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: latestJob.ID,
|
|
})
|
|
dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: wrk.ID,
|
|
TemplateVersionID: version.ID,
|
|
BuildNumber: 1,
|
|
Transition: trans,
|
|
InitiatorID: tpl.CreatedBy,
|
|
JobID: latestJob.ID,
|
|
})
|
|
for i := 0; i < args.ExtraBuilds; i++ {
|
|
latestJob = database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: tpl.CreatedBy,
|
|
OrganizationID: tpl.OrganizationID,
|
|
}
|
|
setJobStatus(t, args.Status, &latestJob)
|
|
latestJob = dbgen.ProvisionerJob(t, db, nil, latestJob)
|
|
latestResource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: latestJob.ID,
|
|
})
|
|
dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: wrk.ID,
|
|
TemplateVersionID: version.ID,
|
|
// #nosec G115 - Safe conversion as build number is expected to be within int32 range
|
|
BuildNumber: int32(i) + 2,
|
|
Transition: trans,
|
|
InitiatorID: tpl.CreatedBy,
|
|
JobID: latestJob.ID,
|
|
})
|
|
}
|
|
|
|
if args.CreateAgent {
|
|
dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: latestResource.ID,
|
|
})
|
|
}
|
|
for i := 0; i < args.ExtraAgents; i++ {
|
|
dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: latestResource.ID,
|
|
})
|
|
}
|
|
}
|
|
return version
|
|
}
|
|
|
|
func setJobStatus(t testing.TB, status database.ProvisionerJobStatus, j *database.ProvisionerJob) {
|
|
t.Helper()
|
|
|
|
earlier := sql.NullTime{
|
|
Time: dbtime.Now().Add(time.Second * -30),
|
|
Valid: true,
|
|
}
|
|
now := sql.NullTime{
|
|
Time: dbtime.Now(),
|
|
Valid: true,
|
|
}
|
|
switch status {
|
|
case database.ProvisionerJobStatusRunning:
|
|
j.StartedAt = earlier
|
|
case database.ProvisionerJobStatusPending:
|
|
case database.ProvisionerJobStatusFailed:
|
|
j.StartedAt = earlier
|
|
j.CompletedAt = now
|
|
j.Error = sql.NullString{
|
|
String: "failed",
|
|
Valid: true,
|
|
}
|
|
j.ErrorCode = sql.NullString{
|
|
String: "failed",
|
|
Valid: true,
|
|
}
|
|
case database.ProvisionerJobStatusSucceeded:
|
|
j.StartedAt = earlier
|
|
j.CompletedAt = now
|
|
default:
|
|
t.Fatalf("invalid status: %s", status)
|
|
}
|
|
}
|
|
|
|
func TestArchiveVersions(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
t.Run("ArchiveFailedVersions", func(t *testing.T) {
|
|
t.Parallel()
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
// Create some versions
|
|
failed := createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusFailed,
|
|
CreateWorkspace: false,
|
|
})
|
|
unused := createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
CreateWorkspace: false,
|
|
})
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
CreateWorkspace: true,
|
|
})
|
|
deleted := createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
CreateWorkspace: true,
|
|
WorkspaceTransition: database.WorkspaceTransitionDelete,
|
|
})
|
|
|
|
// Now archive failed versions
|
|
archived, err := db.ArchiveUnusedTemplateVersions(ctx, database.ArchiveUnusedTemplateVersionsParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
TemplateID: tpl.ID,
|
|
// All versions
|
|
TemplateVersionID: uuid.Nil,
|
|
JobStatus: database.NullProvisionerJobStatus{
|
|
ProvisionerJobStatus: database.ProvisionerJobStatusFailed,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err, "archive failed versions")
|
|
require.Len(t, archived, 1, "should only archive one version")
|
|
require.Equal(t, failed.ID, archived[0], "should archive failed version")
|
|
|
|
// Archive all unused versions
|
|
archived, err = db.ArchiveUnusedTemplateVersions(ctx, database.ArchiveUnusedTemplateVersionsParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
TemplateID: tpl.ID,
|
|
// All versions
|
|
TemplateVersionID: uuid.Nil,
|
|
})
|
|
require.NoError(t, err, "archive failed versions")
|
|
require.Len(t, archived, 2)
|
|
require.ElementsMatch(t, []uuid.UUID{deleted.ID, unused.ID}, archived, "should archive unused versions")
|
|
})
|
|
}
|
|
|
|
func TestExpectOne(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
t.Run("ErrNoRows", func(t *testing.T) {
|
|
t.Parallel()
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
|
|
_, err = database.ExpectOne(db.GetUsers(ctx, database.GetUsersParams{}))
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
})
|
|
|
|
t.Run("TooMany", func(t *testing.T) {
|
|
t.Parallel()
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
|
|
// Create 2 organizations so the query returns >1
|
|
dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.Organization(t, db, database.Organization{})
|
|
|
|
// Organizations is an easy table without foreign key dependencies
|
|
_, err = database.ExpectOne(db.GetOrganizations(ctx, database.GetOrganizationsParams{}))
|
|
require.ErrorContains(t, err, "too many rows returned")
|
|
})
|
|
}
|
|
|
|
func TestGetProvisionerJobsByIDsWithQueuePosition(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
jobTags []database.StringMap
|
|
daemonTags []database.StringMap
|
|
queueSizes []int64
|
|
queuePositions []int64
|
|
// GetProvisionerJobsByIDsWithQueuePosition takes jobIDs as a parameter.
|
|
// If skipJobIDs is empty, all jobs are passed to the function; otherwise, the specified jobs are skipped.
|
|
// NOTE: Skipping job IDs means they will be excluded from the result,
|
|
// but this should not affect the queue position or queue size of other jobs.
|
|
skipJobIDs map[int]struct{}
|
|
}{
|
|
// Baseline test case
|
|
{
|
|
name: "test-case-1",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
},
|
|
queueSizes: []int64{2, 2, 0},
|
|
queuePositions: []int64{1, 1, 0},
|
|
},
|
|
// Includes an additional provisioner
|
|
{
|
|
name: "test-case-2",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{3, 3, 3},
|
|
queuePositions: []int64{1, 1, 3},
|
|
},
|
|
// Skips job at index 0
|
|
{
|
|
name: "test-case-3",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{3, 3},
|
|
queuePositions: []int64{1, 3},
|
|
skipJobIDs: map[int]struct{}{
|
|
0: {},
|
|
},
|
|
},
|
|
// Skips job at index 1
|
|
{
|
|
name: "test-case-4",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{3, 3},
|
|
queuePositions: []int64{1, 3},
|
|
skipJobIDs: map[int]struct{}{
|
|
1: {},
|
|
},
|
|
},
|
|
// Skips job at index 2
|
|
{
|
|
name: "test-case-5",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{3, 3},
|
|
queuePositions: []int64{1, 1},
|
|
skipJobIDs: map[int]struct{}{
|
|
2: {},
|
|
},
|
|
},
|
|
// Skips jobs at indexes 0 and 2
|
|
{
|
|
name: "test-case-6",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{3},
|
|
queuePositions: []int64{1},
|
|
skipJobIDs: map[int]struct{}{
|
|
0: {},
|
|
2: {},
|
|
},
|
|
},
|
|
// Includes two additional jobs that any provisioner can execute.
|
|
{
|
|
name: "test-case-7",
|
|
jobTags: []database.StringMap{
|
|
{},
|
|
{},
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{5, 5, 5, 5, 5},
|
|
queuePositions: []int64{1, 2, 3, 3, 5},
|
|
},
|
|
// Includes two additional jobs that any provisioner can execute, but they are intentionally skipped.
|
|
{
|
|
name: "test-case-8",
|
|
jobTags: []database.StringMap{
|
|
{},
|
|
{},
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{5, 5, 5},
|
|
queuePositions: []int64{3, 3, 5},
|
|
skipJobIDs: map[int]struct{}{
|
|
0: {},
|
|
1: {},
|
|
},
|
|
},
|
|
// N jobs (1 job with 0 tags) & 0 provisioners exist
|
|
{
|
|
name: "test-case-9",
|
|
jobTags: []database.StringMap{
|
|
{},
|
|
{"a": "1"},
|
|
{"b": "2"},
|
|
},
|
|
daemonTags: []database.StringMap{},
|
|
queueSizes: []int64{0, 0, 0},
|
|
queuePositions: []int64{0, 0, 0},
|
|
},
|
|
// N jobs (1 job with 0 tags) & N provisioners
|
|
{
|
|
name: "test-case-10",
|
|
jobTags: []database.StringMap{
|
|
{},
|
|
{"a": "1"},
|
|
{"b": "2"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{},
|
|
{"a": "1"},
|
|
{"b": "2"},
|
|
},
|
|
queueSizes: []int64{2, 2, 2},
|
|
queuePositions: []int64{1, 2, 2},
|
|
},
|
|
// (N + 1) jobs (1 job with 0 tags) & N provisioners
|
|
// 1 job not matching any provisioner (first in the list)
|
|
{
|
|
name: "test-case-11",
|
|
jobTags: []database.StringMap{
|
|
{"c": "3"},
|
|
{},
|
|
{"a": "1"},
|
|
{"b": "2"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{},
|
|
{"a": "1"},
|
|
{"b": "2"},
|
|
},
|
|
queueSizes: []int64{0, 2, 2, 2},
|
|
queuePositions: []int64{0, 1, 2, 2},
|
|
},
|
|
// 0 jobs & 0 provisioners
|
|
{
|
|
name: "test-case-12",
|
|
jobTags: []database.StringMap{},
|
|
daemonTags: []database.StringMap{},
|
|
queueSizes: nil, // TODO(yevhenii): should it be empty array instead?
|
|
queuePositions: nil,
|
|
},
|
|
// Many daemons with identical tags should produce same results as one.
|
|
{
|
|
name: "duplicate-daemons-same-tags",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1", "b": "2"},
|
|
},
|
|
queueSizes: []int64{2, 2},
|
|
queuePositions: []int64{1, 2},
|
|
},
|
|
// Jobs that don't match any queried job's daemon should still
|
|
// have correct queue positions.
|
|
{
|
|
name: "irrelevant-daemons-filtered",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1"},
|
|
{"x": "9"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1"},
|
|
{"x": "9"},
|
|
},
|
|
queueSizes: []int64{1},
|
|
queuePositions: []int64{1},
|
|
skipJobIDs: map[int]struct{}{1: {}},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
now := dbtime.Now()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Create provisioner jobs based on provided tags:
|
|
allJobs := make([]database.ProvisionerJob, len(tc.jobTags))
|
|
for idx, tags := range tc.jobTags {
|
|
// Make sure jobs are stored in correct order, first job should have the earliest createdAt timestamp.
|
|
// Example for 3 jobs:
|
|
// job_1 createdAt: now - 3 minutes
|
|
// job_2 createdAt: now - 2 minutes
|
|
// job_3 createdAt: now - 1 minute
|
|
timeOffsetInMinutes := len(tc.jobTags) - idx
|
|
timeOffset := time.Duration(timeOffsetInMinutes) * time.Minute
|
|
createdAt := now.Add(-timeOffset)
|
|
|
|
allJobs[idx] = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: createdAt,
|
|
Tags: tags,
|
|
})
|
|
}
|
|
|
|
// Create provisioner daemons based on provided tags:
|
|
for idx, tags := range tc.daemonTags {
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: fmt.Sprintf("prov_%v", idx),
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: tags,
|
|
})
|
|
}
|
|
|
|
// Assert invariant: the jobs are in pending status
|
|
for idx, job := range allJobs {
|
|
require.Equal(t, database.ProvisionerJobStatusPending, job.JobStatus, "expected job %d to have status %s", idx, database.ProvisionerJobStatusPending)
|
|
}
|
|
|
|
filteredJobs := make([]database.ProvisionerJob, 0)
|
|
filteredJobIDs := make([]uuid.UUID, 0)
|
|
for idx, job := range allJobs {
|
|
if _, skip := tc.skipJobIDs[idx]; skip {
|
|
continue
|
|
}
|
|
|
|
filteredJobs = append(filteredJobs, job)
|
|
filteredJobIDs = append(filteredJobIDs, job.ID)
|
|
}
|
|
|
|
// When: we fetch the jobs by their IDs
|
|
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: filteredJobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, actualJobs, len(filteredJobs), "should return all unskipped jobs")
|
|
|
|
// Then: the jobs should be returned in the correct order (sorted by createdAt)
|
|
sort.Slice(filteredJobs, func(i, j int) bool {
|
|
return filteredJobs[i].CreatedAt.Before(filteredJobs[j].CreatedAt)
|
|
})
|
|
for idx, job := range actualJobs {
|
|
assert.EqualValues(t, filteredJobs[idx], job.ProvisionerJob)
|
|
}
|
|
|
|
// Then: the queue size should be set correctly
|
|
var queueSizes []int64
|
|
for _, job := range actualJobs {
|
|
queueSizes = append(queueSizes, job.QueueSize)
|
|
}
|
|
assert.EqualValues(t, tc.queueSizes, queueSizes, "expected queue positions to be set correctly")
|
|
|
|
// Then: the queue position should be set correctly:
|
|
var queuePositions []int64
|
|
for _, job := range actualJobs {
|
|
queuePositions = append(queuePositions, job.QueuePosition)
|
|
}
|
|
assert.EqualValues(t, tc.queuePositions, queuePositions, "expected queue positions to be set correctly")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetProvisionerJobsByIDsWithQueuePosition_MixedStatuses(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
now := dbtime.Now()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Create the following provisioner jobs:
|
|
allJobs := []database.ProvisionerJob{
|
|
// Pending. This will be the last in the queue because
|
|
// it was created most recently.
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-time.Minute),
|
|
StartedAt: sql.NullTime{},
|
|
CanceledAt: sql.NullTime{},
|
|
CompletedAt: sql.NullTime{},
|
|
Error: sql.NullString{},
|
|
// Ensure the `tags` field is NOT NULL for both provisioner jobs and provisioner daemons;
|
|
// otherwise, provisioner daemons won't be able to pick up any jobs.
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Another pending. This will come first in the queue
|
|
// because it was created before the previous job.
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-2 * time.Minute),
|
|
StartedAt: sql.NullTime{},
|
|
CanceledAt: sql.NullTime{},
|
|
CompletedAt: sql.NullTime{},
|
|
Error: sql.NullString{},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Running
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-3 * time.Minute),
|
|
StartedAt: sql.NullTime{Valid: true, Time: now},
|
|
CanceledAt: sql.NullTime{},
|
|
CompletedAt: sql.NullTime{},
|
|
Error: sql.NullString{},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Succeeded
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-4 * time.Minute),
|
|
StartedAt: sql.NullTime{Valid: true, Time: now},
|
|
CanceledAt: sql.NullTime{},
|
|
CompletedAt: sql.NullTime{Valid: true, Time: now},
|
|
Error: sql.NullString{},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Canceling
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-5 * time.Minute),
|
|
StartedAt: sql.NullTime{},
|
|
CanceledAt: sql.NullTime{Valid: true, Time: now},
|
|
CompletedAt: sql.NullTime{},
|
|
Error: sql.NullString{},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Canceled
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-6 * time.Minute),
|
|
StartedAt: sql.NullTime{},
|
|
CanceledAt: sql.NullTime{Valid: true, Time: now},
|
|
CompletedAt: sql.NullTime{Valid: true, Time: now},
|
|
Error: sql.NullString{},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Failed
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-7 * time.Minute),
|
|
StartedAt: sql.NullTime{},
|
|
CanceledAt: sql.NullTime{},
|
|
CompletedAt: sql.NullTime{},
|
|
Error: sql.NullString{String: "failed", Valid: true},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
}
|
|
|
|
// Create default provisioner daemon:
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "default_provisioner",
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{},
|
|
})
|
|
|
|
// Assert invariant: the jobs are in the expected order
|
|
require.Len(t, allJobs, 7, "expected 7 jobs")
|
|
for idx, status := range []database.ProvisionerJobStatus{
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusRunning,
|
|
database.ProvisionerJobStatusSucceeded,
|
|
database.ProvisionerJobStatusCanceling,
|
|
database.ProvisionerJobStatusCanceled,
|
|
database.ProvisionerJobStatusFailed,
|
|
} {
|
|
require.Equal(t, status, allJobs[idx].JobStatus, "expected job %d to have status %s", idx, status)
|
|
}
|
|
|
|
var jobIDs []uuid.UUID
|
|
for _, job := range allJobs {
|
|
jobIDs = append(jobIDs, job.ID)
|
|
}
|
|
|
|
// When: we fetch the jobs by their IDs
|
|
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: jobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, actualJobs, len(allJobs), "should return all jobs")
|
|
|
|
// Then: the jobs should be returned in the correct order (sorted by createdAt)
|
|
sort.Slice(allJobs, func(i, j int) bool {
|
|
return allJobs[i].CreatedAt.Before(allJobs[j].CreatedAt)
|
|
})
|
|
for idx, job := range actualJobs {
|
|
assert.EqualValues(t, allJobs[idx], job.ProvisionerJob)
|
|
}
|
|
|
|
// Then: the queue size should be set correctly
|
|
var queueSizes []int64
|
|
for _, job := range actualJobs {
|
|
queueSizes = append(queueSizes, job.QueueSize)
|
|
}
|
|
assert.EqualValues(t, []int64{0, 0, 0, 0, 0, 2, 2}, queueSizes, "expected queue positions to be set correctly")
|
|
|
|
// Then: the queue position should be set correctly:
|
|
var queuePositions []int64
|
|
for _, job := range actualJobs {
|
|
queuePositions = append(queuePositions, job.QueuePosition)
|
|
}
|
|
assert.EqualValues(t, []int64{0, 0, 0, 0, 0, 1, 2}, queuePositions, "expected queue positions to be set correctly")
|
|
}
|
|
|
|
func TestGetProvisionerJobsByIDsWithQueuePosition_OrderValidation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
now := dbtime.Now()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Create the following provisioner jobs:
|
|
allJobs := []database.ProvisionerJob{
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-4 * time.Minute),
|
|
// Ensure the `tags` field is NOT NULL for both provisioner jobs and provisioner daemons;
|
|
// otherwise, provisioner daemons won't be able to pick up any jobs.
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-5 * time.Minute),
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-6 * time.Minute),
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-3 * time.Minute),
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-2 * time.Minute),
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-1 * time.Minute),
|
|
Tags: database.StringMap{},
|
|
}),
|
|
}
|
|
|
|
// Create default provisioner daemon:
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "default_provisioner",
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{},
|
|
})
|
|
|
|
// Assert invariant: the jobs are in the expected order
|
|
require.Len(t, allJobs, 6, "expected 7 jobs")
|
|
for idx, status := range []database.ProvisionerJobStatus{
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
} {
|
|
require.Equal(t, status, allJobs[idx].JobStatus, "expected job %d to have status %s", idx, status)
|
|
}
|
|
|
|
var jobIDs []uuid.UUID
|
|
for _, job := range allJobs {
|
|
jobIDs = append(jobIDs, job.ID)
|
|
}
|
|
|
|
// When: we fetch the jobs by their IDs
|
|
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: jobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, actualJobs, len(allJobs), "should return all jobs")
|
|
|
|
// Then: the jobs should be returned in the correct order (sorted by createdAt)
|
|
sort.Slice(allJobs, func(i, j int) bool {
|
|
return allJobs[i].CreatedAt.Before(allJobs[j].CreatedAt)
|
|
})
|
|
for idx, job := range actualJobs {
|
|
assert.EqualValues(t, allJobs[idx], job.ProvisionerJob)
|
|
assert.EqualValues(t, allJobs[idx].CreatedAt, job.ProvisionerJob.CreatedAt)
|
|
}
|
|
|
|
// Then: the queue size should be set correctly
|
|
var queueSizes []int64
|
|
for _, job := range actualJobs {
|
|
queueSizes = append(queueSizes, job.QueueSize)
|
|
}
|
|
assert.EqualValues(t, []int64{6, 6, 6, 6, 6, 6}, queueSizes, "expected queue positions to be set correctly")
|
|
|
|
// Then: the queue position should be set correctly:
|
|
var queuePositions []int64
|
|
for _, job := range actualJobs {
|
|
queuePositions = append(queuePositions, job.QueuePosition)
|
|
}
|
|
assert.EqualValues(t, []int64{1, 2, 3, 4, 5, 6}, queuePositions, "expected queue positions to be set correctly")
|
|
}
|
|
|
|
func TestGetProvisionerJobsByIDsWithQueuePosition_DuplicateDaemons(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
now := dbtime.Now()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Create 3 pending jobs with the same tags.
|
|
jobs := make([]database.ProvisionerJob, 3)
|
|
for i := range jobs {
|
|
jobs[i] = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-time.Duration(3-i) * time.Minute),
|
|
Tags: database.StringMap{"scope": "organization", "owner": ""},
|
|
})
|
|
}
|
|
|
|
// Create 50 daemons with identical tags (simulates scale).
|
|
for i := range 50 {
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: fmt.Sprintf("daemon_%d", i),
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{"scope": "organization", "owner": ""},
|
|
})
|
|
}
|
|
|
|
jobIDs := make([]uuid.UUID, len(jobs))
|
|
for i, j := range jobs {
|
|
jobIDs[i] = j.ID
|
|
}
|
|
|
|
results, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx,
|
|
database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: jobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, results, 3)
|
|
|
|
// All daemons have identical tags, so queue should be same as
|
|
// if there were just one daemon.
|
|
for i, r := range results {
|
|
assert.Equal(t, int64(3), r.QueueSize, "job %d queue size", i)
|
|
assert.Equal(t, int64(i+1), r.QueuePosition, "job %d queue position", i)
|
|
}
|
|
}
|
|
|
|
func TestGroupRemovalTrigger(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbgen.Organization(t, db, database.Organization{})
|
|
_, err := db.InsertAllUsersGroup(context.Background(), orgA.ID)
|
|
require.NoError(t, err)
|
|
|
|
orgB := dbgen.Organization(t, db, database.Organization{})
|
|
_, err = db.InsertAllUsersGroup(context.Background(), orgB.ID)
|
|
require.NoError(t, err)
|
|
|
|
orgs := []database.Organization{orgA, orgB}
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
extra := dbgen.User(t, db, database.User{})
|
|
users := []database.User{user, extra}
|
|
|
|
groupA1 := dbgen.Group(t, db, database.Group{
|
|
OrganizationID: orgA.ID,
|
|
})
|
|
groupA2 := dbgen.Group(t, db, database.Group{
|
|
OrganizationID: orgA.ID,
|
|
})
|
|
|
|
groupB1 := dbgen.Group(t, db, database.Group{
|
|
OrganizationID: orgB.ID,
|
|
})
|
|
groupB2 := dbgen.Group(t, db, database.Group{
|
|
OrganizationID: orgB.ID,
|
|
})
|
|
|
|
groups := []database.Group{groupA1, groupA2, groupB1, groupB2}
|
|
|
|
// Add users to all organizations
|
|
for _, u := range users {
|
|
for _, o := range orgs {
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: o.ID,
|
|
UserID: u.ID,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Add users to all groups
|
|
for _, u := range users {
|
|
for _, g := range groups {
|
|
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
|
GroupID: g.ID,
|
|
UserID: u.ID,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Verify user is in all groups
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
onlyGroupIDs := func(row database.GetGroupsRow) uuid.UUID {
|
|
return row.Group.ID
|
|
}
|
|
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
|
|
HasMemberID: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.ElementsMatch(t, []uuid.UUID{
|
|
orgA.ID, orgB.ID, // Everyone groups
|
|
groupA1.ID, groupA2.ID, groupB1.ID, groupB2.ID, // Org groups
|
|
}, slice.List(userGroups, onlyGroupIDs))
|
|
|
|
// Remove the user from org A
|
|
err = db.DeleteOrganizationMember(ctx, database.DeleteOrganizationMemberParams{
|
|
OrganizationID: orgA.ID,
|
|
UserID: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Verify user is no longer in org A groups
|
|
userGroups, err = db.GetGroups(ctx, database.GetGroupsParams{
|
|
HasMemberID: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.ElementsMatch(t, []uuid.UUID{
|
|
orgB.ID, // Everyone group
|
|
groupB1.ID, groupB2.ID, // Org groups
|
|
}, slice.List(userGroups, onlyGroupIDs))
|
|
|
|
// Verify extra user is unchanged
|
|
extraUserGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
|
|
HasMemberID: extra.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.ElementsMatch(t, []uuid.UUID{
|
|
orgA.ID, orgB.ID, // Everyone groups
|
|
groupA1.ID, groupA2.ID, groupB1.ID, groupB2.ID, // Org groups
|
|
}, slice.List(extraUserGroups, onlyGroupIDs))
|
|
}
|
|
|
|
func TestGetUserStatusCounts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
type testCase struct {
|
|
timezone string
|
|
location *time.Location
|
|
reportFrom time.Time
|
|
reportUntil time.Time
|
|
}
|
|
testCases := []testCase{}
|
|
|
|
// GetUserStatusCounts is sensitive to DST transitions, because it generates timestamps exactly
|
|
// one day apart from one another, and specific days can have varying lengths depending on the timezone.
|
|
// Therefore, we test with a variety of timezones.
|
|
timezones := []string{
|
|
"America/St_Johns",
|
|
"Africa/Johannesburg",
|
|
"America/New_York",
|
|
"Europe/London",
|
|
"Asia/Tokyo",
|
|
"Australia/Sydney",
|
|
}
|
|
|
|
// assemble test cases
|
|
for _, tz := range timezones {
|
|
location, err := time.LoadLocation(tz)
|
|
if err != nil {
|
|
t.Fatalf("failed to load location: %v", err)
|
|
}
|
|
|
|
// Testing based on the current system date will flake due to DST transitions.
|
|
// Instead, we test with a fixed range of dates that is large enough to span multiple DST transitions.
|
|
startOfTestDateRange := time.Date(2025, 1, 1, 0, 0, 0, 0, location)
|
|
endOfTestDateRange := time.Date(2026, 1, 1, 0, 0, 0, 0, location)
|
|
// To keep the number of test cases manageable given the large date range,
|
|
// we test with a suitable large interval. This interval is also the length of each report.
|
|
// this ensures we have full coverage of the date range.
|
|
testDateRangeInterval := 60
|
|
|
|
for reportFrom := startOfTestDateRange; !reportFrom.After(endOfTestDateRange); reportFrom = reportFrom.AddDate(0, 0, testDateRangeInterval) {
|
|
testCases = append(testCases, testCase{
|
|
timezone: tz,
|
|
location: location,
|
|
reportFrom: dbtime.Time(reportFrom),
|
|
reportUntil: dbtime.Time(reportFrom.AddDate(0, 0, testDateRangeInterval)),
|
|
})
|
|
}
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(fmt.Sprintf("%s/%s", tc.timezone, tc.reportUntil.Format("2006-01-02T15:04:05Z")), func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
userCreatedAt := tc.reportUntil.AddDate(0, 0, -60)
|
|
firstStatusChange := userCreatedAt.AddDate(0, 0, 29)
|
|
secondStatusChange := firstStatusChange.AddDate(0, 0, 29)
|
|
|
|
t.Run("No Users", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
counts, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: tc.reportFrom,
|
|
EndTime: tc.reportUntil,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, counts, "should return no results when there are no users")
|
|
})
|
|
|
|
t.Run("One User/Creation Only", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
subTestCases := []struct {
|
|
name string
|
|
status database.UserStatus
|
|
}{
|
|
{
|
|
name: "Active Only",
|
|
status: database.UserStatusActive,
|
|
},
|
|
{
|
|
name: "Dormant Only",
|
|
status: database.UserStatusDormant,
|
|
},
|
|
{
|
|
name: "Suspended Only",
|
|
status: database.UserStatusSuspended,
|
|
},
|
|
}
|
|
|
|
for _, stc := range subTestCases {
|
|
t.Run(stc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
dbgen.User(t, db, database.User{
|
|
Status: stc.status,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
startTime := dbtime.StartOfDay(userCreatedAt)
|
|
endTime := dbtime.StartOfDay(tc.reportUntil)
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: startTime,
|
|
EndTime: endTime,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
numDays := 0
|
|
for d := startTime; !d.After(endTime); d = d.AddDate(0, 0, 1) {
|
|
numDays++
|
|
}
|
|
assert.Len(
|
|
t,
|
|
userStatusChanges,
|
|
numDays,
|
|
"should have 1 entry per day between the start and end time, including the end time",
|
|
)
|
|
|
|
for i, row := range userStatusChanges {
|
|
require.Equal(t, stc.status, row.Status, "should have the correct status")
|
|
|
|
rowDate := row.Date.In(tc.location)
|
|
expectedDate := dbtime.StartOfDay(userCreatedAt).AddDate(0, 0, i)
|
|
assert.True(
|
|
t,
|
|
rowDate.Equal(expectedDate),
|
|
"expected date %s, but got %s for row %n",
|
|
expectedDate.String(),
|
|
rowDate.String(),
|
|
i,
|
|
)
|
|
|
|
if row.Date.Before(userCreatedAt) {
|
|
assert.Equal(t, int64(0), row.Count, "should have 0 users before creation")
|
|
} else {
|
|
assert.Equal(t, int64(1), row.Count, "should have 1 user after creation")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("One User/One Transition", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
subTestCases := []struct {
|
|
name string
|
|
initialStatus database.UserStatus
|
|
targetStatus database.UserStatus
|
|
expectedCounts map[time.Time]map[database.UserStatus]int64
|
|
}{
|
|
{
|
|
name: "Active to Dormant",
|
|
initialStatus: database.UserStatusActive,
|
|
targetStatus: database.UserStatusDormant,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "Active to Suspended",
|
|
initialStatus: database.UserStatusActive,
|
|
targetStatus: database.UserStatusSuspended,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "Dormant to Active",
|
|
initialStatus: database.UserStatusDormant,
|
|
targetStatus: database.UserStatusActive,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "Dormant to Suspended",
|
|
initialStatus: database.UserStatusDormant,
|
|
targetStatus: database.UserStatusSuspended,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "Suspended to Active",
|
|
initialStatus: database.UserStatusSuspended,
|
|
targetStatus: database.UserStatusActive,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "Suspended to Dormant",
|
|
initialStatus: database.UserStatusSuspended,
|
|
targetStatus: database.UserStatusDormant,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, stc := range subTestCases {
|
|
t.Run(stc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
user := dbgen.User(t, db, database.User{
|
|
Status: stc.initialStatus,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
user, err := db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{
|
|
ID: user.ID,
|
|
Status: stc.targetStatus,
|
|
UpdatedAt: firstStatusChange,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: dbtime.StartOfDay(userCreatedAt),
|
|
EndTime: dbtime.StartOfDay(tc.reportUntil),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
for i, row := range userStatusChanges {
|
|
rowDate := row.Date.In(tc.location)
|
|
expectedDate := dbtime.StartOfDay(userCreatedAt).AddDate(0, 0, i/2)
|
|
require.True(
|
|
t,
|
|
rowDate.Equal(expectedDate),
|
|
"expected date %s, but got %s for row %n",
|
|
expectedDate.String(),
|
|
rowDate.String(),
|
|
i,
|
|
)
|
|
switch {
|
|
case row.Date.Before(userCreatedAt):
|
|
require.Equal(t, int64(0), row.Count)
|
|
case row.Date.Before(firstStatusChange):
|
|
if row.Status == stc.initialStatus {
|
|
require.Equal(t, int64(1), row.Count)
|
|
} else if row.Status == stc.targetStatus {
|
|
require.Equal(t, int64(0), row.Count)
|
|
}
|
|
case !row.Date.After(tc.reportUntil):
|
|
if row.Status == stc.initialStatus {
|
|
require.Equal(t, int64(0), row.Count)
|
|
} else if row.Status == stc.targetStatus {
|
|
require.Equal(t, int64(1), row.Count)
|
|
}
|
|
default:
|
|
t.Errorf("date %q beyond expected range end %q", row.Date, tc.reportUntil)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("Two Users/One Transition", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
type transition struct {
|
|
from database.UserStatus
|
|
to database.UserStatus
|
|
}
|
|
|
|
type testCase struct {
|
|
name string
|
|
user1Transition transition
|
|
user2Transition transition
|
|
}
|
|
|
|
subTestCases := []testCase{
|
|
{
|
|
name: "Active->Dormant and Dormant->Suspended",
|
|
user1Transition: transition{
|
|
from: database.UserStatusActive,
|
|
to: database.UserStatusDormant,
|
|
},
|
|
user2Transition: transition{
|
|
from: database.UserStatusDormant,
|
|
to: database.UserStatusSuspended,
|
|
},
|
|
},
|
|
{
|
|
name: "Suspended->Active and Active->Dormant",
|
|
user1Transition: transition{
|
|
from: database.UserStatusSuspended,
|
|
to: database.UserStatusActive,
|
|
},
|
|
user2Transition: transition{
|
|
from: database.UserStatusActive,
|
|
to: database.UserStatusDormant,
|
|
},
|
|
},
|
|
{
|
|
name: "Dormant->Active and Suspended->Dormant",
|
|
user1Transition: transition{
|
|
from: database.UserStatusDormant,
|
|
to: database.UserStatusActive,
|
|
},
|
|
user2Transition: transition{
|
|
from: database.UserStatusSuspended,
|
|
to: database.UserStatusDormant,
|
|
},
|
|
},
|
|
{
|
|
name: "Active->Suspended and Suspended->Active",
|
|
user1Transition: transition{
|
|
from: database.UserStatusActive,
|
|
to: database.UserStatusSuspended,
|
|
},
|
|
user2Transition: transition{
|
|
from: database.UserStatusSuspended,
|
|
to: database.UserStatusActive,
|
|
},
|
|
},
|
|
{
|
|
name: "Dormant->Suspended and Dormant->Active",
|
|
user1Transition: transition{
|
|
from: database.UserStatusDormant,
|
|
to: database.UserStatusSuspended,
|
|
},
|
|
user2Transition: transition{
|
|
from: database.UserStatusDormant,
|
|
to: database.UserStatusActive,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, stc := range subTestCases {
|
|
t.Run(stc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
user1 := dbgen.User(t, db, database.User{
|
|
Status: stc.user1Transition.from,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
user2 := dbgen.User(t, db, database.User{
|
|
Status: stc.user2Transition.from,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
user1, err := db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{
|
|
ID: user1.ID,
|
|
Status: stc.user1Transition.to,
|
|
UpdatedAt: firstStatusChange,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
user2, err = db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{
|
|
ID: user2.ID,
|
|
Status: stc.user2Transition.to,
|
|
UpdatedAt: secondStatusChange,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: dbtime.StartOfDay(userCreatedAt),
|
|
EndTime: dbtime.StartOfDay(tc.reportUntil),
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, userStatusChanges)
|
|
gotCounts := map[time.Time]map[database.UserStatus]int64{}
|
|
for _, row := range userStatusChanges {
|
|
dateInLocation := row.Date.In(tc.location)
|
|
if gotCounts[dateInLocation] == nil {
|
|
gotCounts[dateInLocation] = map[database.UserStatus]int64{}
|
|
}
|
|
gotCounts[dateInLocation][row.Status] = row.Count
|
|
}
|
|
|
|
expectedCounts := map[time.Time]map[database.UserStatus]int64{}
|
|
for d := dbtime.StartOfDay(userCreatedAt); !d.After(dbtime.StartOfDay(tc.reportUntil)); d = d.AddDate(0, 0, 1) {
|
|
expectedCounts[d] = map[database.UserStatus]int64{}
|
|
|
|
// Default values
|
|
expectedCounts[d][stc.user1Transition.from] = 0
|
|
expectedCounts[d][stc.user1Transition.to] = 0
|
|
expectedCounts[d][stc.user2Transition.from] = 0
|
|
expectedCounts[d][stc.user2Transition.to] = 0
|
|
|
|
// Counted Values
|
|
switch {
|
|
case d.Before(userCreatedAt):
|
|
continue
|
|
case d.Before(firstStatusChange):
|
|
expectedCounts[d][stc.user1Transition.from]++
|
|
expectedCounts[d][stc.user2Transition.from]++
|
|
case d.Before(secondStatusChange):
|
|
expectedCounts[d][stc.user1Transition.to]++
|
|
expectedCounts[d][stc.user2Transition.from]++
|
|
case !d.After(tc.reportUntil):
|
|
expectedCounts[d][stc.user1Transition.to]++
|
|
expectedCounts[d][stc.user2Transition.to]++
|
|
default:
|
|
t.Fatalf("date %q beyond expected range end %q", d, tc.reportUntil)
|
|
}
|
|
}
|
|
|
|
require.Equal(t, expectedCounts, gotCounts)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("User precedes and survives query range", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
_ = dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusActive,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: dbtime.StartOfDay(userCreatedAt.Add(time.Hour * 24)),
|
|
EndTime: dbtime.StartOfDay(tc.reportUntil),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
for i, row := range userStatusChanges {
|
|
require.True(
|
|
t,
|
|
row.Date.In(tc.location).Equal(dbtime.StartOfDay(userCreatedAt).AddDate(0, 0, 1+i)),
|
|
"expected date %s, but got %s for row %n",
|
|
dbtime.StartOfDay(userCreatedAt).AddDate(0, 0, 1+i),
|
|
row.Date.In(tc.location).String(),
|
|
i,
|
|
)
|
|
require.Equal(t, database.UserStatusActive, row.Status)
|
|
require.Equal(t, int64(1), row.Count)
|
|
}
|
|
})
|
|
|
|
t.Run("User deleted before query range", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
user := dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusActive,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
err := db.UpdateUserDeletedByID(ctx, user.ID)
|
|
require.NoError(t, err)
|
|
|
|
_, err = sqlDB.ExecContext(ctx, "UPDATE user_deleted SET deleted_at = $1 WHERE user_id = $2", tc.reportUntil, user.ID)
|
|
require.NoError(t, err)
|
|
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: tc.reportUntil.Add(time.Hour * 24),
|
|
EndTime: tc.reportUntil.Add(time.Hour * 48),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, userStatusChanges)
|
|
})
|
|
|
|
t.Run("User deleted during query range", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
user := dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusActive,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
err := db.UpdateUserDeletedByID(ctx, user.ID)
|
|
require.NoError(t, err)
|
|
|
|
_, err = sqlDB.ExecContext(ctx, "UPDATE user_deleted SET deleted_at = $1 WHERE user_id = $2", tc.reportUntil, user.ID)
|
|
require.NoError(t, err)
|
|
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: dbtime.StartOfDay(userCreatedAt),
|
|
EndTime: dbtime.StartOfDay(tc.reportUntil.Add(time.Hour * 24)),
|
|
})
|
|
require.NoError(t, err)
|
|
for i, row := range userStatusChanges {
|
|
row.Date = row.Date.In(tc.location)
|
|
userStatusChanges[i] = row
|
|
target := dbtime.StartOfDay(userCreatedAt).AddDate(0, 0, i)
|
|
assert.True(
|
|
t,
|
|
row.Date.Equal(target),
|
|
"expected date %s, but got %s for row %n",
|
|
target.String(),
|
|
row.Date.String(),
|
|
i,
|
|
)
|
|
require.Equal(t, database.UserStatusActive, row.Status)
|
|
switch {
|
|
case row.Date.Before(userCreatedAt):
|
|
require.Equal(t, int64(0), row.Count)
|
|
case !row.Date.Before(tc.reportUntil):
|
|
// On or after the deletion date, the user should not be counted.
|
|
require.Equal(t, int64(0), row.Count)
|
|
default:
|
|
require.Equal(t, int64(1), row.Count)
|
|
}
|
|
}
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOrganizationDeleteTrigger(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("WorkspaceExists", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
|
|
dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OrganizationID: orgA.Org.ID,
|
|
OwnerID: user.ID,
|
|
}).Do()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 1 workspaces and 1 templates that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "has 1 workspaces")
|
|
require.ErrorContains(t, err, "1 templates")
|
|
})
|
|
|
|
t.Run("TemplateExists", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
|
|
dbgen.Template(t, db, database.Template{
|
|
OrganizationID: orgA.Org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 0 workspaces and 1 templates that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "1 templates")
|
|
})
|
|
|
|
t.Run("ProvisionerKeyExists", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
dbgen.ProvisionerKey(t, db, database.ProvisionerKey{
|
|
OrganizationID: orgA.Org.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 1 provisioner keys that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "1 provisioner keys")
|
|
})
|
|
|
|
t.Run("GroupExists", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
dbgen.Group(t, db, database.Group{
|
|
OrganizationID: orgA.Org.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 1 groups that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "has 1 groups")
|
|
})
|
|
|
|
t.Run("MemberExists", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
userA := dbgen.User(t, db, database.User{})
|
|
userB := dbgen.User(t, db, database.User{})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: orgA.Org.ID,
|
|
UserID: userA.ID,
|
|
})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: orgA.Org.ID,
|
|
UserID: userB.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 1 members that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "has 1 members")
|
|
})
|
|
|
|
t.Run("UserDeletedButNotRemovedFromOrg", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
userA := dbgen.User(t, db, database.User{})
|
|
userB := dbgen.User(t, db, database.User{})
|
|
userC := dbgen.User(t, db, database.User{})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: orgA.Org.ID,
|
|
UserID: userA.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: orgA.Org.ID,
|
|
UserID: userB.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: orgA.Org.ID,
|
|
UserID: userC.ID,
|
|
})
|
|
|
|
// Delete one of the users but don't remove them from the org
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
db.UpdateUserDeletedByID(ctx, userB.ID)
|
|
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 1 members that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "has 1 members")
|
|
})
|
|
}
|
|
|
|
type templateVersionWithPreset struct {
|
|
database.TemplateVersion
|
|
preset database.TemplateVersionPreset
|
|
}
|
|
|
|
func createTemplate(t *testing.T, db database.Store, orgID uuid.UUID, userID uuid.UUID) database.Template {
|
|
// create template
|
|
tmpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: orgID,
|
|
CreatedBy: userID,
|
|
ActiveVersionID: uuid.New(),
|
|
})
|
|
|
|
return tmpl
|
|
}
|
|
|
|
// tmplVersionOpts customizes createTmplVersionAndPreset.
type tmplVersionOpts struct {
	// DesiredInstances is written to the preset's desired-instances column.
	// When no options are passed (nil), createTmplVersionAndPreset uses 1.
	DesiredInstances int32
}
|
|
|
|
func createTmplVersionAndPreset(
|
|
t *testing.T,
|
|
db database.Store,
|
|
tmpl database.Template,
|
|
versionID uuid.UUID,
|
|
now time.Time,
|
|
opts *tmplVersionOpts,
|
|
) templateVersionWithPreset {
|
|
// Create template version with corresponding preset and preset prebuild
|
|
tmplVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
ID: versionID,
|
|
TemplateID: uuid.NullUUID{
|
|
UUID: tmpl.ID,
|
|
Valid: true,
|
|
},
|
|
OrganizationID: tmpl.OrganizationID,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
CreatedBy: tmpl.CreatedBy,
|
|
})
|
|
desiredInstances := int32(1)
|
|
if opts != nil {
|
|
desiredInstances = opts.DesiredInstances
|
|
}
|
|
preset := dbgen.Preset(t, db, database.InsertPresetParams{
|
|
TemplateVersionID: tmplVersion.ID,
|
|
Name: "preset",
|
|
DesiredInstances: sql.NullInt32{
|
|
Int32: desiredInstances,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
return templateVersionWithPreset{
|
|
TemplateVersion: tmplVersion,
|
|
preset: preset,
|
|
}
|
|
}
|
|
|
|
// createPrebuiltWorkspaceOpts customizes createPrebuiltWorkspace.
//
// failedJob marks the workspace's provisioner job with an error.
// createdAt is the creation time requested for the workspace build; how a zero
// value is treated is decided by createPrebuiltWorkspace.
// readyAgents and notReadyAgents are the numbers of agents created in the
// "ready" and "created" lifecycle states. Once an opts value is supplied they
// are used exactly as given; createPrebuiltWorkspace only defaults
// notReadyAgents to 1 when no opts are passed at all.
type createPrebuiltWorkspaceOpts struct {
	failedJob                   bool
	createdAt                   time.Time
	readyAgents, notReadyAgents int
}
|
|
|
|
func createPrebuiltWorkspace(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
tmpl database.Template,
|
|
extTmplVersion templateVersionWithPreset,
|
|
orgID uuid.UUID,
|
|
now time.Time,
|
|
opts *createPrebuiltWorkspaceOpts,
|
|
) {
|
|
// Create job with corresponding resource and agent
|
|
jobError := sql.NullString{}
|
|
if opts != nil && opts.failedJob {
|
|
jobError = sql.NullString{String: "failed", Valid: true}
|
|
}
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
OrganizationID: orgID,
|
|
|
|
CreatedAt: now.Add(-1 * time.Minute),
|
|
Error: jobError,
|
|
})
|
|
|
|
// create ready agents
|
|
readyAgents := 0
|
|
if opts != nil {
|
|
readyAgents = opts.readyAgents
|
|
}
|
|
for i := 0; i < readyAgents; i++ {
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
|
ID: agent.ID,
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// create not ready agents
|
|
notReadyAgents := 1
|
|
if opts != nil {
|
|
notReadyAgents = opts.notReadyAgents
|
|
}
|
|
for i := 0; i < notReadyAgents; i++ {
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
|
ID: agent.ID,
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Create corresponding workspace and workspace build
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: uuid.MustParse("c42fdf75-3097-471c-8c33-fb52454d81c0"),
|
|
OrganizationID: tmpl.OrganizationID,
|
|
TemplateID: tmpl.ID,
|
|
})
|
|
createdAt := now
|
|
if opts != nil {
|
|
createdAt = opts.createdAt
|
|
}
|
|
dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
CreatedAt: createdAt,
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: extTmplVersion.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: tmpl.CreatedBy,
|
|
JobID: job.ID,
|
|
TemplateVersionPresetID: uuid.NullUUID{
|
|
UUID: extTmplVersion.preset.ID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
}
|
|
|
|
func TestWorkspacePrebuildsView(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
now := dbtime.Now()
|
|
orgID := uuid.New()
|
|
userID := uuid.New()
|
|
|
|
type workspacePrebuild struct {
|
|
ID uuid.UUID
|
|
Name string
|
|
CreatedAt time.Time
|
|
Ready bool
|
|
CurrentPresetID uuid.UUID
|
|
}
|
|
getWorkspacePrebuilds := func(sqlDB *sql.DB) []*workspacePrebuild {
|
|
rows, err := sqlDB.Query("SELECT id, name, created_at, ready, current_preset_id FROM workspace_prebuilds")
|
|
require.NoError(t, err)
|
|
defer rows.Close()
|
|
|
|
workspacePrebuilds := make([]*workspacePrebuild, 0)
|
|
for rows.Next() {
|
|
var wp workspacePrebuild
|
|
err := rows.Scan(&wp.ID, &wp.Name, &wp.CreatedAt, &wp.Ready, &wp.CurrentPresetID)
|
|
require.NoError(t, err)
|
|
|
|
workspacePrebuilds = append(workspacePrebuilds, &wp)
|
|
}
|
|
|
|
return workspacePrebuilds
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
readyAgents int
|
|
notReadyAgents int
|
|
expectReady bool
|
|
}{
|
|
{
|
|
name: "one ready agent",
|
|
readyAgents: 1,
|
|
notReadyAgents: 0,
|
|
expectReady: true,
|
|
},
|
|
{
|
|
name: "one not ready agent",
|
|
readyAgents: 0,
|
|
notReadyAgents: 1,
|
|
expectReady: false,
|
|
},
|
|
{
|
|
name: "one ready, one not ready",
|
|
readyAgents: 1,
|
|
notReadyAgents: 1,
|
|
expectReady: false,
|
|
},
|
|
{
|
|
name: "both ready",
|
|
readyAgents: 2,
|
|
notReadyAgents: 0,
|
|
expectReady: true,
|
|
},
|
|
{
|
|
name: "five ready, one not ready",
|
|
readyAgents: 5,
|
|
notReadyAgents: 1,
|
|
expectReady: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
readyAgents: tc.readyAgents,
|
|
notReadyAgents: tc.notReadyAgents,
|
|
})
|
|
|
|
workspacePrebuilds := getWorkspacePrebuilds(sqlDB)
|
|
require.Len(t, workspacePrebuilds, 1)
|
|
require.Equal(t, tc.expectReady, workspacePrebuilds[0].Ready)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetPresetsBackoff(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
now := dbtime.Now()
|
|
orgID := uuid.New()
|
|
userID := uuid.New()
|
|
|
|
findBackoffByTmplVersionID := func(backoffs []database.GetPresetsBackoffRow, tmplVersionID uuid.UUID) *database.GetPresetsBackoffRow {
|
|
for _, backoff := range backoffs {
|
|
if backoff.TemplateVersionID == tmplVersionID {
|
|
return &backoff
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
t.Run("Single Workspace Build", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmplV1.preset.ID)
|
|
require.Equal(t, int32(1), backoff.NumFailed)
|
|
})
|
|
|
|
t.Run("Multiple Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmplV1.preset.ID)
|
|
require.Equal(t, int32(3), backoff.NumFailed)
|
|
})
|
|
|
|
t.Run("Ignore Inactive Version", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, uuid.New(), now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
// Active Version
|
|
tmplV2 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmplV2.preset.ID)
|
|
require.Equal(t, int32(2), backoff.NumFailed)
|
|
})
|
|
|
|
t.Run("Multiple Templates", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl2 := createTemplate(t, db, orgID, userID)
|
|
tmpl2V1 := createTmplVersionAndPreset(t, db, tmpl2, tmpl2.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 2)
|
|
{
|
|
backoff := findBackoffByTmplVersionID(backoffs, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl1V1.preset.ID)
|
|
require.Equal(t, int32(1), backoff.NumFailed)
|
|
}
|
|
{
|
|
backoff := findBackoffByTmplVersionID(backoffs, tmpl2.ActiveVersionID)
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl2.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl2V1.preset.ID)
|
|
require.Equal(t, int32(1), backoff.NumFailed)
|
|
}
|
|
})
|
|
|
|
t.Run("Multiple Templates, Versions and Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl2 := createTemplate(t, db, orgID, userID)
|
|
tmpl2V1 := createTmplVersionAndPreset(t, db, tmpl2, tmpl2.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl3 := createTemplate(t, db, orgID, userID)
|
|
tmpl3V1 := createTmplVersionAndPreset(t, db, tmpl3, uuid.New(), now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl3V2 := createTmplVersionAndPreset(t, db, tmpl3, tmpl3.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 3)
|
|
{
|
|
backoff := findBackoffByTmplVersionID(backoffs, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl1V1.preset.ID)
|
|
require.Equal(t, int32(1), backoff.NumFailed)
|
|
}
|
|
{
|
|
backoff := findBackoffByTmplVersionID(backoffs, tmpl2.ActiveVersionID)
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl2.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl2V1.preset.ID)
|
|
require.Equal(t, int32(2), backoff.NumFailed)
|
|
}
|
|
{
|
|
backoff := findBackoffByTmplVersionID(backoffs, tmpl3.ActiveVersionID)
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl3.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl3V2.preset.ID)
|
|
require.Equal(t, int32(3), backoff.NumFailed)
|
|
}
|
|
})
|
|
|
|
t.Run("No Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
require.Nil(t, backoffs)
|
|
})
|
|
|
|
t.Run("No Failed Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
successfulJobOpts := createPrebuiltWorkspaceOpts{}
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
require.Nil(t, backoffs)
|
|
})
|
|
|
|
t.Run("Last job is successful - no backoff", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 1,
|
|
})
|
|
failedJobOpts := createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-2 * time.Minute),
|
|
}
|
|
successfulJobOpts := createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-1 * time.Minute),
|
|
}
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &failedJobOpts)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
require.Nil(t, backoffs)
|
|
})
|
|
|
|
t.Run("Last 3 jobs are successful - no backoff", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 3,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-4 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-3 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-2 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-1 * time.Minute),
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
require.Nil(t, backoffs)
|
|
})
|
|
|
|
t.Run("1 job failed out of 3 - backoff", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 3,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-3 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-2 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-1 * time.Minute),
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
{
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl1V1.preset.ID)
|
|
require.Equal(t, int32(1), backoff.NumFailed)
|
|
}
|
|
})
|
|
|
|
t.Run("3 job failed out of 5 - backoff", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
lookbackPeriod := time.Hour
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 3,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-lookbackPeriod - time.Minute), // earlier than lookback period - skipped
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-4 * time.Minute), // within lookback period - counted as failed job
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-3 * time.Minute), // within lookback period - counted as failed job
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-2 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-1 * time.Minute),
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-lookbackPeriod))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
{
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl1V1.preset.ID)
|
|
require.Equal(t, int32(2), backoff.NumFailed)
|
|
}
|
|
})
|
|
|
|
t.Run("check LastBuildAt timestamp", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
lookbackPeriod := time.Hour
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 6,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-lookbackPeriod - time.Minute), // earlier than lookback period - skipped
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-4 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-0 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-3 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-1 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-2 * time.Minute),
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-lookbackPeriod))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
{
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl1V1.preset.ID)
|
|
require.Equal(t, int32(5), backoff.NumFailed)
|
|
// make sure LastBuildAt is equal to latest failed build timestamp
|
|
require.Equal(t, 0, now.Compare(backoff.LastBuildAt))
|
|
}
|
|
})
|
|
|
|
t.Run("failed job outside lookback period", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
lookbackPeriod := time.Hour
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 1,
|
|
})
|
|
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-lookbackPeriod - time.Minute), // earlier than lookback period - skipped
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-lookbackPeriod))
|
|
require.NoError(t, err)
|
|
require.Len(t, backoffs, 0)
|
|
})
|
|
}
|
|
|
|
func TestGetPresetsAtFailureLimit(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
now := dbtime.Now()
|
|
hourBefore := now.Add(-time.Hour)
|
|
orgID := uuid.New()
|
|
userID := uuid.New()
|
|
|
|
findPresetByTmplVersionID := func(hardLimitedPresets []database.GetPresetsAtFailureLimitRow, tmplVersionID uuid.UUID) *database.GetPresetsAtFailureLimitRow {
|
|
for _, preset := range hardLimitedPresets {
|
|
if preset.TemplateVersionID == tmplVersionID {
|
|
return &preset
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
// true - build is successful
|
|
// false - build is unsuccessful
|
|
buildSuccesses []bool
|
|
hardLimit int64
|
|
expHitHardLimit bool
|
|
}{
|
|
{
|
|
name: "failed build",
|
|
buildSuccesses: []bool{false},
|
|
hardLimit: 1,
|
|
expHitHardLimit: true,
|
|
},
|
|
{
|
|
name: "2 failed builds",
|
|
buildSuccesses: []bool{false, false},
|
|
hardLimit: 1,
|
|
expHitHardLimit: true,
|
|
},
|
|
{
|
|
name: "successful build",
|
|
buildSuccesses: []bool{true},
|
|
hardLimit: 1,
|
|
expHitHardLimit: false,
|
|
},
|
|
{
|
|
name: "last build is failed",
|
|
buildSuccesses: []bool{true, true, false},
|
|
hardLimit: 1,
|
|
expHitHardLimit: true,
|
|
},
|
|
{
|
|
name: "last build is successful",
|
|
buildSuccesses: []bool{false, false, true},
|
|
hardLimit: 1,
|
|
expHitHardLimit: false,
|
|
},
|
|
{
|
|
name: "last 3 builds are failed - hard limit is reached",
|
|
buildSuccesses: []bool{true, true, false, false, false},
|
|
hardLimit: 3,
|
|
expHitHardLimit: true,
|
|
},
|
|
{
|
|
name: "1 out of 3 last build is successful - hard limit is NOT reached",
|
|
buildSuccesses: []bool{false, false, true, false, false},
|
|
hardLimit: 3,
|
|
expHitHardLimit: false,
|
|
},
|
|
// hardLimit set to zero, implicitly disables the hard limit.
|
|
{
|
|
name: "despite 5 failed builds, the hard limit is not reached because it's disabled.",
|
|
buildSuccesses: []bool{false, false, false, false, false},
|
|
hardLimit: 0,
|
|
expHitHardLimit: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
for idx, buildSuccess := range tc.buildSuccesses {
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: !buildSuccess,
|
|
createdAt: hourBefore.Add(time.Duration(idx) * time.Second),
|
|
})
|
|
}
|
|
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, tc.hardLimit)
|
|
require.NoError(t, err)
|
|
|
|
if !tc.expHitHardLimit {
|
|
require.Len(t, hardLimitedPresets, 0)
|
|
return
|
|
}
|
|
|
|
require.Len(t, hardLimitedPresets, 1)
|
|
hardLimitedPreset := hardLimitedPresets[0]
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmplV1.preset.ID)
|
|
})
|
|
}
|
|
|
|
t.Run("Ignore Inactive Version", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, uuid.New(), now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
// Active Version
|
|
tmplV2 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, 1)
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, hardLimitedPresets, 1)
|
|
hardLimitedPreset := hardLimitedPresets[0]
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmplV2.preset.ID)
|
|
})
|
|
|
|
t.Run("Multiple Templates", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl2 := createTemplate(t, db, orgID, userID)
|
|
tmpl2V1 := createTmplVersionAndPreset(t, db, tmpl2, tmpl2.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, 1)
|
|
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, hardLimitedPresets, 2)
|
|
{
|
|
hardLimitedPreset := findPresetByTmplVersionID(hardLimitedPresets, tmpl1.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmpl1V1.preset.ID)
|
|
}
|
|
{
|
|
hardLimitedPreset := findPresetByTmplVersionID(hardLimitedPresets, tmpl2.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl2.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmpl2V1.preset.ID)
|
|
}
|
|
})
|
|
|
|
t.Run("Multiple Templates, Versions and Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl2 := createTemplate(t, db, orgID, userID)
|
|
tmpl2V1 := createTmplVersionAndPreset(t, db, tmpl2, tmpl2.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl3 := createTemplate(t, db, orgID, userID)
|
|
tmpl3V1 := createTmplVersionAndPreset(t, db, tmpl3, uuid.New(), now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl3V2 := createTmplVersionAndPreset(t, db, tmpl3, tmpl3.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
hardLimit := int64(2)
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, hardLimit)
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, hardLimitedPresets, 3)
|
|
{
|
|
hardLimitedPreset := findPresetByTmplVersionID(hardLimitedPresets, tmpl1.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmpl1V1.preset.ID)
|
|
}
|
|
{
|
|
hardLimitedPreset := findPresetByTmplVersionID(hardLimitedPresets, tmpl2.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl2.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmpl2V1.preset.ID)
|
|
}
|
|
{
|
|
hardLimitedPreset := findPresetByTmplVersionID(hardLimitedPresets, tmpl3.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl3.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmpl3V2.preset.ID)
|
|
}
|
|
})
|
|
|
|
t.Run("No Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, 1)
|
|
require.NoError(t, err)
|
|
require.Nil(t, hardLimitedPresets)
|
|
})
|
|
|
|
t.Run("No Failed Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
successfulJobOpts := createPrebuiltWorkspaceOpts{}
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, 1)
|
|
require.NoError(t, err)
|
|
require.Nil(t, hardLimitedPresets)
|
|
})
|
|
}
|
|
|
|
func TestWorkspaceAgentNameUniqueTrigger(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
createWorkspaceWithAgent := func(t *testing.T, db database.Store, org database.Organization, agentName string) (database.WorkspaceBuild, database.WorkspaceResource, database.WorkspaceAgent) {
|
|
t.Helper()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{Valid: true, UUID: template.ID},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
TemplateID: template.ID,
|
|
OwnerID: user.ID,
|
|
})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
OrganizationID: org.ID,
|
|
})
|
|
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
BuildNumber: 1,
|
|
JobID: job.ID,
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: build.JobID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
Name: agentName,
|
|
})
|
|
|
|
return build, resource, agent
|
|
}
|
|
|
|
t.Run("DuplicateNamesInSameWorkspaceResource", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A workspace with an agent
|
|
_, resource, _ := createWorkspaceWithAgent(t, db, org, "duplicate-agent")
|
|
|
|
// When: Another agent is created for that workspace with the same name.
|
|
_, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
|
ID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
Name: "duplicate-agent", // Same name as agent1
|
|
ResourceID: resource.ID,
|
|
AuthToken: uuid.New(),
|
|
Architecture: "amd64",
|
|
OperatingSystem: "linux",
|
|
APIKeyScope: database.AgentKeyScopeEnumAll,
|
|
})
|
|
|
|
// Then: We expect it to fail.
|
|
require.Error(t, err)
|
|
var pqErr *pq.Error
|
|
require.True(t, errors.As(err, &pqErr))
|
|
require.Equal(t, pq.ErrorCode("23505"), pqErr.Code) // unique_violation
|
|
require.Contains(t, pqErr.Message, `workspace agent name "duplicate-agent" already exists in this workspace build`)
|
|
})
|
|
|
|
t.Run("DuplicateNamesInSameProvisionerJob", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A workspace with an agent
|
|
_, resource, agent := createWorkspaceWithAgent(t, db, org, "duplicate-agent")
|
|
|
|
// When: A child agent is created for that workspace with the same name.
|
|
_, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
|
ID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
Name: agent.Name,
|
|
ResourceID: resource.ID,
|
|
AuthToken: uuid.New(),
|
|
Architecture: "amd64",
|
|
OperatingSystem: "linux",
|
|
APIKeyScope: database.AgentKeyScopeEnumAll,
|
|
})
|
|
|
|
// Then: We expect it to fail.
|
|
require.Error(t, err)
|
|
var pqErr *pq.Error
|
|
require.True(t, errors.As(err, &pqErr))
|
|
require.Equal(t, pq.ErrorCode("23505"), pqErr.Code) // unique_violation
|
|
require.Contains(t, pqErr.Message, `workspace agent name "duplicate-agent" already exists in this workspace build`)
|
|
})
|
|
|
|
t.Run("DuplicateChildNamesOverMultipleResources", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A workspace with two agents
|
|
_, resource1, agent1 := createWorkspaceWithAgent(t, db, org, "parent-agent-1")
|
|
|
|
resource2 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{JobID: resource1.JobID})
|
|
agent2 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource2.ID,
|
|
Name: "parent-agent-2",
|
|
})
|
|
|
|
// Given: One agent has a child agent
|
|
agent1Child := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ParentID: uuid.NullUUID{Valid: true, UUID: agent1.ID},
|
|
Name: "child-agent",
|
|
ResourceID: resource1.ID,
|
|
})
|
|
|
|
// When: A child agent is inserted for the other parent.
|
|
_, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
|
ID: uuid.New(),
|
|
ParentID: uuid.NullUUID{Valid: true, UUID: agent2.ID},
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
Name: agent1Child.Name,
|
|
ResourceID: resource2.ID,
|
|
AuthToken: uuid.New(),
|
|
Architecture: "amd64",
|
|
OperatingSystem: "linux",
|
|
APIKeyScope: database.AgentKeyScopeEnumAll,
|
|
})
|
|
|
|
// Then: We expect it to fail.
|
|
require.Error(t, err)
|
|
var pqErr *pq.Error
|
|
require.True(t, errors.As(err, &pqErr))
|
|
require.Equal(t, pq.ErrorCode("23505"), pqErr.Code) // unique_violation
|
|
require.Contains(t, pqErr.Message, `workspace agent name "child-agent" already exists in this workspace build`)
|
|
})
|
|
|
|
t.Run("SameNamesInDifferentWorkspaces", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
agentName := "same-name-different-workspace"
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
// Given: A workspace with an agent
|
|
_, _, agent1 := createWorkspaceWithAgent(t, db, org, agentName)
|
|
require.Equal(t, agentName, agent1.Name)
|
|
|
|
// When: A second workspace is created with an agent having the same name
|
|
_, _, agent2 := createWorkspaceWithAgent(t, db, org, agentName)
|
|
require.Equal(t, agentName, agent2.Name)
|
|
|
|
// Then: We expect there to be different agents with the same name.
|
|
require.NotEqual(t, agent1.ID, agent2.ID)
|
|
require.Equal(t, agent1.Name, agent2.Name)
|
|
})
|
|
|
|
t.Run("NullWorkspaceID", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A resource that does not belong to a workspace build (simulating template import)
|
|
orphanJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
OrganizationID: org.ID,
|
|
})
|
|
orphanResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: orphanJob.ID,
|
|
})
|
|
|
|
// And this resource has a workspace agent.
|
|
agent1, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
|
ID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
Name: "orphan-agent",
|
|
ResourceID: orphanResource.ID,
|
|
AuthToken: uuid.New(),
|
|
Architecture: "amd64",
|
|
OperatingSystem: "linux",
|
|
APIKeyScope: database.AgentKeyScopeEnumAll,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "orphan-agent", agent1.Name)
|
|
|
|
// When: We created another resource that does not belong to a workspace build.
|
|
orphanJob2 := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
OrganizationID: org.ID,
|
|
})
|
|
orphanResource2 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: orphanJob2.ID,
|
|
})
|
|
|
|
// Then: We expect to be able to create an agent in this new resource that has the same name.
|
|
agent2, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
|
ID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
Name: "orphan-agent", // Same name as agent1
|
|
ResourceID: orphanResource2.ID,
|
|
AuthToken: uuid.New(),
|
|
Architecture: "amd64",
|
|
OperatingSystem: "linux",
|
|
APIKeyScope: database.AgentKeyScopeEnumAll,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "orphan-agent", agent2.Name)
|
|
require.NotEqual(t, agent1.ID, agent2.ID)
|
|
})
|
|
}
|
|
|
|
func TestGetWorkspaceAgentsByParentID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("NilParentDoesNotReturnAllParentAgents", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Given: A workspace agent
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// When: We attempt to select agents with a null parent id
|
|
agents, err := db.GetWorkspaceAgentsByParentID(ctx, uuid.Nil)
|
|
require.NoError(t, err)
|
|
|
|
// Then: We expect to see no agents.
|
|
require.Len(t, agents, 0)
|
|
})
|
|
}
|
|
|
|
func setupWorkspaceAgentQueryResources(t *testing.T, db database.Store, count int) []database.WorkspaceResource {
|
|
t.Helper()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
OrganizationID: org.ID,
|
|
})
|
|
|
|
resources := make([]database.WorkspaceResource, 0, count)
|
|
for i := 0; i < count; i++ {
|
|
resources = append(resources, dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
}))
|
|
}
|
|
|
|
return resources
|
|
}
|
|
|
|
func markWorkspaceAgentDeleted(ctx context.Context, t *testing.T, sqlDB *sql.DB, agentID uuid.UUID) {
|
|
t.Helper()
|
|
|
|
_, err := sqlDB.ExecContext(ctx, "UPDATE workspace_agents SET deleted = TRUE WHERE id = $1", agentID)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
type workspaceBuildAgentQueryFixture struct {
|
|
Workspace database.WorkspaceTable
|
|
Build database.WorkspaceBuild
|
|
Agent database.WorkspaceAgent
|
|
}
|
|
|
|
func setupWorkspaceBuildAgentQueryWorkspace(t testing.TB, db database.Store, deleted bool) database.WorkspaceTable {
|
|
t.Helper()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: template.ID,
|
|
Deleted: deleted,
|
|
})
|
|
}
|
|
|
|
func setupWorkspaceBuildAgentQueryFixture(
|
|
t testing.TB,
|
|
db database.Store,
|
|
authInstanceID string,
|
|
name string,
|
|
createdAt time.Time,
|
|
workspace database.WorkspaceTable,
|
|
) workspaceBuildAgentQueryFixture {
|
|
t.Helper()
|
|
|
|
if workspace.ID == uuid.Nil {
|
|
workspace = setupWorkspaceBuildAgentQueryWorkspace(t, db, false)
|
|
}
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: workspace.TemplateID, Valid: true},
|
|
OrganizationID: workspace.OrganizationID,
|
|
CreatedBy: workspace.OwnerID,
|
|
})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: workspace.OrganizationID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
})
|
|
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
JobID: job.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
Name: name,
|
|
ResourceID: resource.ID,
|
|
CreatedAt: createdAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
return workspaceBuildAgentQueryFixture{
|
|
Workspace: workspace,
|
|
Build: build,
|
|
Agent: agent,
|
|
}
|
|
}
|
|
|
|
func setupProvisionerJobAgentQueryFixture(
|
|
t testing.TB,
|
|
db database.Store,
|
|
authInstanceID string,
|
|
name string,
|
|
createdAt time.Time,
|
|
jobType database.ProvisionerJobType,
|
|
) database.WorkspaceAgent {
|
|
t.Helper()
|
|
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: jobType,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
return dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
Name: name,
|
|
ResourceID: resource.ID,
|
|
CreatedAt: createdAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
}
|
|
|
|
func TestGetWorkspaceAgentsByInstanceID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("ReturnsAllMatchingRootAgents", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
resources := setupWorkspaceAgentQueryResources(t, db, 2)
|
|
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
|
|
olderCreatedAt := dbtime.Now().Add(-time.Hour)
|
|
newerCreatedAt := dbtime.Now()
|
|
|
|
olderAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[0].ID,
|
|
CreatedAt: olderCreatedAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
newerAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[1].ID,
|
|
CreatedAt: newerCreatedAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 2)
|
|
assert.Equal(t, []uuid.UUID{newerAgent.ID, olderAgent.ID}, []uuid.UUID{agents[0].ID, agents[1].ID})
|
|
})
|
|
|
|
t.Run("ExcludesDeletedAndSubAgents", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
resources := setupWorkspaceAgentQueryResources(t, db, 2)
|
|
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
|
|
baseCreatedAt := dbtime.Now()
|
|
|
|
rootAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[0].ID,
|
|
CreatedAt: baseCreatedAt.Add(-time.Hour),
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ParentID: uuid.NullUUID{UUID: rootAgent.ID, Valid: true},
|
|
ResourceID: resources[0].ID,
|
|
CreatedAt: baseCreatedAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
deletedRootAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[1].ID,
|
|
CreatedAt: baseCreatedAt.Add(time.Minute),
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
markWorkspaceAgentDeleted(ctx, t, sqlDB, deletedRootAgent.ID)
|
|
|
|
agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 1)
|
|
assert.Equal(t, rootAgent.ID, agents[0].ID)
|
|
assert.False(t, agents[0].ParentID.Valid)
|
|
})
|
|
|
|
t.Run("OrdersNewestFirst", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
resources := setupWorkspaceAgentQueryResources(t, db, 2)
|
|
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
|
|
olderCreatedAt := dbtime.Now().Add(-time.Hour)
|
|
newerCreatedAt := dbtime.Now()
|
|
|
|
olderAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[0].ID,
|
|
CreatedAt: olderCreatedAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
newerAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[1].ID,
|
|
CreatedAt: newerCreatedAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 2)
|
|
assert.Equal(t, newerAgent.ID, agents[0].ID)
|
|
assert.Equal(t, olderAgent.ID, agents[1].ID)
|
|
})
|
|
}
|
|
|
|
func TestGetWorkspaceBuildAgentsByInstanceID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("ReturnsWorkspaceBuildRootAgentsNewestFirst", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
|
|
olderCreatedAt := dbtime.Now().Add(-time.Hour)
|
|
newerCreatedAt := dbtime.Now()
|
|
|
|
older := setupWorkspaceBuildAgentQueryFixture(t, db, authInstanceID, "older", olderCreatedAt, database.WorkspaceTable{})
|
|
newer := setupWorkspaceBuildAgentQueryFixture(t, db, authInstanceID, "newer", newerCreatedAt, database.WorkspaceTable{})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
agents, err := db.GetWorkspaceBuildAgentsByInstanceID(ctx, authInstanceID)
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 2)
|
|
assert.Equal(t, []uuid.UUID{newer.Agent.ID, older.Agent.ID}, []uuid.UUID{agents[0].WorkspaceAgent.ID, agents[1].WorkspaceAgent.ID})
|
|
assert.Equal(t, []uuid.UUID{newer.Build.ID, older.Build.ID}, []uuid.UUID{agents[0].WorkspaceBuildID, agents[1].WorkspaceBuildID})
|
|
assert.Equal(t, newer.Workspace.ID, agents[0].WorkspaceTable.ID)
|
|
assert.Equal(t, older.Workspace.ID, agents[1].WorkspaceTable.ID)
|
|
assert.Equal(t, newer.Workspace.OwnerID, agents[0].WorkspaceTable.OwnerID)
|
|
assert.Equal(t, older.Workspace.OwnerID, agents[1].WorkspaceTable.OwnerID)
|
|
assert.Equal(t, newer.Workspace.OrganizationID, agents[0].WorkspaceTable.OrganizationID)
|
|
assert.Equal(t, older.Workspace.OrganizationID, agents[1].WorkspaceTable.OrganizationID)
|
|
assert.False(t, agents[0].WorkspaceTable.Deleted)
|
|
assert.False(t, agents[1].WorkspaceTable.Deleted)
|
|
})
|
|
|
|
t.Run("ExcludesDeletedAgentsSubAgentsAndNonWorkspaceBuildJobs", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
|
|
baseCreatedAt := dbtime.Now()
|
|
|
|
root := setupWorkspaceBuildAgentQueryFixture(t, db, authInstanceID, "root", baseCreatedAt.Add(-time.Hour), database.WorkspaceTable{})
|
|
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ParentID: uuid.NullUUID{UUID: root.Agent.ID, Valid: true},
|
|
Name: "sub",
|
|
ResourceID: root.Agent.ResourceID,
|
|
CreatedAt: baseCreatedAt.Add(time.Minute),
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
deletedAgent := setupWorkspaceBuildAgentQueryFixture(t, db, authInstanceID, "deleted", baseCreatedAt.Add(2*time.Minute), database.WorkspaceTable{})
|
|
_ = setupProvisionerJobAgentQueryFixture(t, db, authInstanceID, "template-import", baseCreatedAt.Add(3*time.Minute), database.ProvisionerJobTypeTemplateVersionImport)
|
|
_ = setupProvisionerJobAgentQueryFixture(t, db, authInstanceID, "dry-run", baseCreatedAt.Add(4*time.Minute), database.ProvisionerJobTypeTemplateVersionDryRun)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
markWorkspaceAgentDeleted(ctx, t, sqlDB, deletedAgent.Agent.ID)
|
|
|
|
agents, err := db.GetWorkspaceBuildAgentsByInstanceID(ctx, authInstanceID)
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 1)
|
|
assert.Equal(t, root.Agent.ID, agents[0].WorkspaceAgent.ID)
|
|
assert.False(t, agents[0].WorkspaceAgent.ParentID.Valid)
|
|
assert.Equal(t, root.Build.ID, agents[0].WorkspaceBuildID)
|
|
})
|
|
|
|
t.Run("ExcludesDeletedWorkspaces", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
|
|
baseCreatedAt := dbtime.Now()
|
|
active := setupWorkspaceBuildAgentQueryFixture(t, db, authInstanceID, "active", baseCreatedAt, database.WorkspaceTable{})
|
|
deletedWorkspace := setupWorkspaceBuildAgentQueryWorkspace(t, db, true)
|
|
_ = setupWorkspaceBuildAgentQueryFixture(t, db, authInstanceID, "deleted-workspace", baseCreatedAt.Add(time.Minute), deletedWorkspace)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
agents, err := db.GetWorkspaceBuildAgentsByInstanceID(ctx, authInstanceID)
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 1)
|
|
assert.Equal(t, active.Agent.ID, agents[0].WorkspaceAgent.ID)
|
|
assert.Equal(t, active.Workspace.ID, agents[0].WorkspaceTable.ID)
|
|
})
|
|
}
|
|
|
|
func requireUsersMatch(t testing.TB, expected []database.User, found []database.GetUsersRow, msg string) {
|
|
t.Helper()
|
|
require.ElementsMatch(t, expected, database.ConvertUserRows(found), msg)
|
|
}
|
|
|
|
// TestGetRunningPrebuiltWorkspaces ensures the correct behavior of the
|
|
// GetRunningPrebuiltWorkspaces query.
|
|
func TestGetRunningPrebuiltWorkspaces(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, _ := dbtestutil.NewDB(t)
|
|
now := dbtime.Now()
|
|
|
|
// Given: a prebuilt workspace with a successful start build and a stop build.
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
preset := dbgen.Preset(t, db, database.InsertPresetParams{
|
|
TemplateVersionID: templateVersion.ID,
|
|
DesiredInstances: sql.NullInt32{Int32: 1, Valid: true},
|
|
})
|
|
|
|
setupFixture := func(t *testing.T, db database.Store, name string, deleted bool, transition database.WorkspaceTransition, jobStatus database.ProvisionerJobStatus) database.WorkspaceTable {
|
|
t.Helper()
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: database.PrebuildsSystemUserID,
|
|
TemplateID: template.ID,
|
|
Name: name,
|
|
Deleted: deleted,
|
|
})
|
|
var canceledAt sql.NullTime
|
|
var jobError sql.NullString
|
|
switch jobStatus {
|
|
case database.ProvisionerJobStatusFailed:
|
|
jobError = sql.NullString{String: assert.AnError.Error(), Valid: true}
|
|
case database.ProvisionerJobStatusCanceled:
|
|
canceledAt = sql.NullTime{Time: now, Valid: true}
|
|
}
|
|
pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
InitiatorID: database.PrebuildsSystemUserID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
StartedAt: sql.NullTime{Time: now.Add(-time.Minute), Valid: true},
|
|
CanceledAt: canceledAt,
|
|
CompletedAt: sql.NullTime{Time: now, Valid: true},
|
|
Error: jobError,
|
|
})
|
|
wb := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: ws.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
TemplateVersionPresetID: uuid.NullUUID{UUID: preset.ID, Valid: true},
|
|
JobID: pj.ID,
|
|
BuildNumber: 1,
|
|
Transition: transition,
|
|
InitiatorID: database.PrebuildsSystemUserID,
|
|
Reason: database.BuildReasonInitiator,
|
|
})
|
|
// Ensure things are set up as expectd
|
|
require.Equal(t, transition, wb.Transition)
|
|
require.Equal(t, int32(1), wb.BuildNumber)
|
|
require.Equal(t, jobStatus, pj.JobStatus)
|
|
require.Equal(t, deleted, ws.Deleted)
|
|
|
|
return ws
|
|
}
|
|
|
|
// Given: a number of prebuild workspaces with different states exist.
|
|
runningPrebuild := setupFixture(t, db, "running-prebuild", false, database.WorkspaceTransitionStart, database.ProvisionerJobStatusSucceeded)
|
|
_ = setupFixture(t, db, "stopped-prebuild", false, database.WorkspaceTransitionStop, database.ProvisionerJobStatusSucceeded)
|
|
_ = setupFixture(t, db, "failed-prebuild", false, database.WorkspaceTransitionStart, database.ProvisionerJobStatusFailed)
|
|
_ = setupFixture(t, db, "canceled-prebuild", false, database.WorkspaceTransitionStart, database.ProvisionerJobStatusCanceled)
|
|
_ = setupFixture(t, db, "deleted-prebuild", true, database.WorkspaceTransitionStart, database.ProvisionerJobStatusSucceeded)
|
|
|
|
// Given: a regular workspace also exists.
|
|
_ = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
Name: "test-running-regular-workspace",
|
|
Deleted: false,
|
|
})
|
|
|
|
// When: we query for running prebuild workspaces
|
|
runningPrebuilds, err := db.GetRunningPrebuiltWorkspaces(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// Then: only the running prebuild workspace should be returned.
|
|
require.Len(t, runningPrebuilds, 1, "expected only one running prebuilt workspace")
|
|
require.Equal(t, runningPrebuild.ID, runningPrebuilds[0].ID, "expected the running prebuilt workspace to be returned")
|
|
}
|
|
|
|
func TestUserSecretsCRUDOperations(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Use raw database without dbauthz wrapper for this test
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
t.Run("FullCRUDWorkflow", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// Create a new user for this test
|
|
testUser := dbgen.User(t, db, database.User{})
|
|
|
|
// 1. CREATE
|
|
secretID := uuid.New()
|
|
createParams := database.CreateUserSecretParams{
|
|
ID: secretID,
|
|
UserID: testUser.ID,
|
|
Name: "workflow-secret",
|
|
Description: "Secret for full CRUD workflow",
|
|
Value: "workflow-value",
|
|
EnvName: "WORKFLOW_ENV",
|
|
FilePath: "/workflow/path",
|
|
}
|
|
|
|
createdSecret, err := db.CreateUserSecret(ctx, createParams)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, secretID, createdSecret.ID)
|
|
|
|
// 2. READ by UserID and Name
|
|
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
|
|
UserID: testUser.ID,
|
|
Name: "workflow-secret",
|
|
}
|
|
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
|
|
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
|
|
|
|
// 3. LIST (metadata only)
|
|
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, secrets, 1)
|
|
assert.Equal(t, createdSecret.ID, secrets[0].ID)
|
|
|
|
// 4. LIST with values
|
|
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, secretsWithValues, 1)
|
|
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
|
|
|
|
// 5. UPDATE (partial - only description)
|
|
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
|
|
UserID: testUser.ID,
|
|
Name: "workflow-secret",
|
|
UpdateDescription: true,
|
|
Description: "Updated workflow description",
|
|
}
|
|
|
|
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
|
|
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
|
|
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
|
|
|
// 6. DELETE
|
|
_, err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
|
UserID: testUser.ID,
|
|
Name: "workflow-secret",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Verify deletion
|
|
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "no rows in result set")
|
|
|
|
// Verify list is empty
|
|
secrets, err = db.ListUserSecrets(ctx, testUser.ID)
|
|
require.NoError(t, err)
|
|
assert.Len(t, secrets, 0)
|
|
})
|
|
|
|
t.Run("UniqueConstraints", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// Create a new user for this test
|
|
testUser := dbgen.User(t, db, database.User{})
|
|
|
|
// Create first secret
|
|
secret1 := dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: testUser.ID,
|
|
Name: "unique-test",
|
|
Description: "First secret",
|
|
Value: "value1",
|
|
EnvName: "UNIQUE_ENV",
|
|
FilePath: "/unique/path",
|
|
})
|
|
|
|
// Try to create another secret with the same name (should fail)
|
|
_, err := db.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
|
UserID: testUser.ID,
|
|
Name: "unique-test", // Same name
|
|
Description: "Second secret",
|
|
Value: "value2",
|
|
})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "duplicate key value")
|
|
|
|
// Try to create another secret with the same env_name (should fail)
|
|
_, err = db.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
|
UserID: testUser.ID,
|
|
Name: "unique-test-2",
|
|
Description: "Second secret",
|
|
Value: "value2",
|
|
EnvName: "UNIQUE_ENV", // Same env_name
|
|
})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "duplicate key value")
|
|
|
|
// Try to create another secret with the same file_path (should fail)
|
|
_, err = db.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
|
UserID: testUser.ID,
|
|
Name: "unique-test-3",
|
|
Description: "Second secret",
|
|
Value: "value2",
|
|
FilePath: "/unique/path", // Same file_path
|
|
})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "duplicate key value")
|
|
|
|
// Create secret with empty env_name and file_path (should succeed)
|
|
secret2 := dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: testUser.ID,
|
|
Name: "unique-test-4",
|
|
Description: "Second secret",
|
|
Value: "value2",
|
|
EnvName: "", // Empty env_name
|
|
FilePath: "", // Empty file_path
|
|
})
|
|
|
|
// Verify both secrets exist
|
|
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
|
UserID: testUser.ID, Name: secret1.Name,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
|
UserID: testUser.ID, Name: secret2.Name,
|
|
})
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
// TestUserSecretsSoftDeleteTrigger verifies that a user's secrets
|
|
// are deleted when the user is soft-deleted.
|
|
func TestUserSecretsSoftDeleteTrigger(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// userA will be soft-deleted.
|
|
userA := dbgen.User(t, db, database.User{})
|
|
secretA1 := dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: userA.ID,
|
|
Name: "secret-a-1",
|
|
Value: "value-a-1",
|
|
EnvName: "SECRET_A_1",
|
|
FilePath: "/secrets/a/1",
|
|
})
|
|
secretA2 := dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: userA.ID,
|
|
Name: "secret-a-2",
|
|
Value: "value-a-2",
|
|
EnvName: "SECRET_A_2",
|
|
FilePath: "/secrets/a/2",
|
|
})
|
|
|
|
// Sanity-check the existing trigger behavior. An API key for
|
|
// userA should also be wiped on soft-delete.
|
|
_, _ = dbgen.APIKey(t, db, database.APIKey{UserID: userA.ID})
|
|
|
|
userB := dbgen.User(t, db, database.User{})
|
|
secretB := dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: userB.ID,
|
|
Name: "secret-b",
|
|
Value: "value-b",
|
|
EnvName: "SECRET_B",
|
|
FilePath: "/secrets/b",
|
|
})
|
|
|
|
require.NoError(t, db.UpdateUserDeletedByID(ctx, userA.ID))
|
|
|
|
// userA's secrets are removed after soft-deletion.
|
|
_, err := db.GetUserSecretByID(ctx, secretA1.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
_, err = db.GetUserSecretByID(ctx, secretA2.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
// userA's API key is also removed.
|
|
apiKeysA, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
|
UserID: userA.ID,
|
|
LoginType: userA.LoginType,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, apiKeysA)
|
|
|
|
// userB's secret is unaffected.
|
|
got, err := db.GetUserSecretByID(ctx, secretB.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, secretB.ID, got.ID)
|
|
|
|
// Trying to insert a new secret for the soft-deleted userA must fail.
|
|
_, err = db.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
|
ID: uuid.New(),
|
|
UserID: userA.ID,
|
|
Name: "post-delete",
|
|
Value: "value",
|
|
EnvName: "POST_DELETE_ENV",
|
|
FilePath: "/secrets/post-delete",
|
|
})
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "Cannot create user_secret for deleted user")
|
|
}
|
|
|
|
// TestOrgMembersSoftDeleteTrigger verifies that a user's organization
|
|
// memberships (and transitively their group memberships) are deleted
|
|
// when the user is soft-deleted.
|
|
func TestOrgMembersSoftDeleteTrigger(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// SingleOrg verifies the basic case: one org, one group, and a
|
|
// control user whose membership must survive.
|
|
t.Run("SingleOrg", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
// userA will be soft-deleted.
|
|
userA := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: userA.ID,
|
|
})
|
|
|
|
// Add userA to a group in the org (should be cleaned up transitively).
|
|
group := dbgen.Group(t, db, database.Group{OrganizationID: org.ID})
|
|
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
|
UserID: userA.ID,
|
|
GroupID: group.ID,
|
|
})
|
|
|
|
// userB is a control; their membership must not be touched.
|
|
userB := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: userB.ID,
|
|
})
|
|
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
|
UserID: userB.ID,
|
|
GroupID: group.ID,
|
|
})
|
|
|
|
// Soft-delete userA.
|
|
require.NoError(t, db.UpdateUserDeletedByID(ctx, userA.ID))
|
|
|
|
// userA should no longer appear in the organization.
|
|
orgMembers, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
|
OrganizationID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
var memberIDs []uuid.UUID
|
|
for _, m := range orgMembers {
|
|
memberIDs = append(memberIDs, m.OrganizationMember.UserID)
|
|
}
|
|
require.NotContains(t, memberIDs, userA.ID)
|
|
require.Contains(t, memberIDs, userB.ID)
|
|
|
|
// The raw org membership rows should also be gone (not just hidden).
|
|
rawOrgs, err := db.GetOrganizationIDsByMemberIDs(ctx, []uuid.UUID{userA.ID})
|
|
require.NoError(t, err)
|
|
require.Empty(t, rawOrgs, "zombie org membership rows should not exist after soft-delete")
|
|
|
|
// userA's group membership should also be removed by the cascading trigger.
|
|
groupMembers, err := db.GetGroupMembersByGroupID(ctx, database.GetGroupMembersByGroupIDParams{
|
|
GroupID: group.ID,
|
|
IncludeSystem: true,
|
|
})
|
|
require.NoError(t, err)
|
|
var groupMemberIDs []uuid.UUID
|
|
for _, gm := range groupMembers {
|
|
groupMemberIDs = append(groupMemberIDs, gm.UserID)
|
|
}
|
|
require.NotContains(t, groupMemberIDs, userA.ID)
|
|
require.Contains(t, groupMemberIDs, userB.ID)
|
|
})
|
|
|
|
// MultipleOrgs verifies that memberships are cleaned up across
|
|
// every organization the deleted user belonged to.
|
|
t.Run("MultipleOrgs", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org1 := dbgen.Organization(t, db, database.Organization{})
|
|
org2 := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
// userA will be soft-deleted. They belong to both orgs.
|
|
userA := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org1.ID,
|
|
UserID: userA.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org2.ID,
|
|
UserID: userA.ID,
|
|
})
|
|
|
|
// Add userA to a group in each org.
|
|
group1 := dbgen.Group(t, db, database.Group{OrganizationID: org1.ID})
|
|
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
|
UserID: userA.ID,
|
|
GroupID: group1.ID,
|
|
})
|
|
group2 := dbgen.Group(t, db, database.Group{OrganizationID: org2.ID})
|
|
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
|
UserID: userA.ID,
|
|
GroupID: group2.ID,
|
|
})
|
|
|
|
// userB stays in org1 as a control.
|
|
userB := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org1.ID,
|
|
UserID: userB.ID,
|
|
})
|
|
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
|
UserID: userB.ID,
|
|
GroupID: group1.ID,
|
|
})
|
|
|
|
// Soft-delete userA.
|
|
require.NoError(t, db.UpdateUserDeletedByID(ctx, userA.ID))
|
|
|
|
// userA should be gone from both orgs.
|
|
for _, org := range []database.Organization{org1, org2} {
|
|
members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
|
OrganizationID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
for _, m := range members {
|
|
require.NotEqual(t, userA.ID, m.OrganizationMember.UserID,
|
|
"userA should not appear in org %s", org.ID)
|
|
}
|
|
}
|
|
|
|
// No raw org membership rows should remain.
|
|
rawOrgs, err := db.GetOrganizationIDsByMemberIDs(ctx, []uuid.UUID{userA.ID})
|
|
require.NoError(t, err)
|
|
require.Empty(t, rawOrgs, "zombie org membership rows should not exist after soft-delete")
|
|
|
|
// Group memberships in both orgs should be cleaned up.
|
|
for _, g := range []struct {
|
|
name string
|
|
groupID uuid.UUID
|
|
}{
|
|
{"org1-group", group1.ID},
|
|
{"org2-group", group2.ID},
|
|
} {
|
|
groupMembers, err := db.GetGroupMembersByGroupID(ctx, database.GetGroupMembersByGroupIDParams{
|
|
GroupID: g.groupID,
|
|
IncludeSystem: true,
|
|
})
|
|
require.NoError(t, err, g.name)
|
|
for _, gm := range groupMembers {
|
|
require.NotEqual(t, userA.ID, gm.UserID, g.name)
|
|
}
|
|
}
|
|
|
|
// userB's memberships are unaffected.
|
|
org1Members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
|
OrganizationID: org1.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
var org1MemberIDs []uuid.UUID
|
|
for _, m := range org1Members {
|
|
org1MemberIDs = append(org1MemberIDs, m.OrganizationMember.UserID)
|
|
}
|
|
require.Contains(t, org1MemberIDs, userB.ID)
|
|
})
|
|
}
|
|
|
|
func TestUserSecretsAuthorization(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Use raw database and wrap with dbauthz for authorization testing
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
|
authDB := dbauthz.New(db, authorizer, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
|
|
// Create test users
|
|
user1 := dbgen.User(t, db, database.User{})
|
|
user2 := dbgen.User(t, db, database.User{})
|
|
owner := dbgen.User(t, db, database.User{})
|
|
orgAdmin := dbgen.User(t, db, database.User{})
|
|
|
|
// Create organization for org-scoped roles
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
// Create secrets for users
|
|
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: user1.ID,
|
|
Name: "user1-secret",
|
|
Description: "User 1's secret",
|
|
Value: "user1-value",
|
|
})
|
|
|
|
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: user2.ID,
|
|
Name: "user2-secret",
|
|
Description: "User 2's secret",
|
|
Value: "user2-value",
|
|
})
|
|
|
|
testCases := []struct {
|
|
name string
|
|
subject rbac.Subject
|
|
lookupUserID uuid.UUID
|
|
lookupName string
|
|
expectedAccess bool
|
|
}{
|
|
{
|
|
name: "UserCanAccessOwnSecrets",
|
|
subject: rbac.Subject{
|
|
ID: user1.ID.String(),
|
|
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
|
Scope: rbac.ScopeAll,
|
|
},
|
|
lookupUserID: user1.ID,
|
|
lookupName: "user1-secret",
|
|
expectedAccess: true,
|
|
},
|
|
{
|
|
name: "UserCannotAccessOtherUserSecrets",
|
|
subject: rbac.Subject{
|
|
ID: user1.ID.String(),
|
|
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
|
Scope: rbac.ScopeAll,
|
|
},
|
|
lookupUserID: user2.ID,
|
|
lookupName: "user2-secret",
|
|
expectedAccess: false,
|
|
},
|
|
{
|
|
name: "OwnerCannotAccessUserSecrets",
|
|
subject: rbac.Subject{
|
|
ID: owner.ID.String(),
|
|
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
|
|
Scope: rbac.ScopeAll,
|
|
},
|
|
lookupUserID: user1.ID,
|
|
lookupName: "user1-secret",
|
|
expectedAccess: false,
|
|
},
|
|
{
|
|
name: "OrgAdminCannotAccessUserSecrets",
|
|
subject: rbac.Subject{
|
|
ID: orgAdmin.ID.String(),
|
|
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
|
|
Scope: rbac.ScopeAll,
|
|
},
|
|
lookupUserID: user1.ID,
|
|
lookupName: "user1-secret",
|
|
expectedAccess: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
authCtx := dbauthz.As(ctx, tc.subject)
|
|
|
|
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
|
|
UserID: tc.lookupUserID,
|
|
Name: tc.lookupName,
|
|
})
|
|
|
|
if tc.expectedAccess {
|
|
require.NoError(t, err, "expected access to be granted")
|
|
} else {
|
|
require.Error(t, err, "expected access to be denied")
|
|
assert.True(t, dbauthz.IsNotAuthorizedError(err), "expected authorization error")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWorkspaceBuildDeadlineConstraint(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
Name: "test-workspace",
|
|
Deleted: false,
|
|
})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
InitiatorID: database.PrebuildsSystemUserID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Minute), Valid: true},
|
|
CompletedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
workspaceBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
JobID: job.ID,
|
|
BuildNumber: 1,
|
|
})
|
|
|
|
cases := []struct {
|
|
name string
|
|
deadline time.Time
|
|
maxDeadline time.Time
|
|
expectOK bool
|
|
}{
|
|
{
|
|
name: "no deadline or max_deadline",
|
|
deadline: time.Time{},
|
|
maxDeadline: time.Time{},
|
|
expectOK: true,
|
|
},
|
|
{
|
|
name: "deadline set when max_deadline is not set",
|
|
deadline: time.Now().Add(time.Hour),
|
|
maxDeadline: time.Time{},
|
|
expectOK: true,
|
|
},
|
|
{
|
|
name: "deadline before max_deadline",
|
|
deadline: time.Now().Add(-time.Hour),
|
|
maxDeadline: time.Now().Add(time.Hour),
|
|
expectOK: true,
|
|
},
|
|
{
|
|
name: "deadline is max_deadline",
|
|
deadline: time.Now().Add(time.Hour),
|
|
maxDeadline: time.Now().Add(time.Hour),
|
|
expectOK: true,
|
|
},
|
|
|
|
{
|
|
name: "deadline after max_deadline",
|
|
deadline: time.Now().Add(time.Hour),
|
|
maxDeadline: time.Now().Add(-time.Hour),
|
|
expectOK: false,
|
|
},
|
|
{
|
|
name: "deadline is not set when max_deadline is set",
|
|
deadline: time.Time{},
|
|
maxDeadline: time.Now().Add(time.Hour),
|
|
expectOK: false,
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
err := db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{
|
|
ID: workspaceBuild.ID,
|
|
Deadline: c.deadline,
|
|
MaxDeadline: c.maxDeadline,
|
|
UpdatedAt: time.Now(),
|
|
})
|
|
if c.expectOK {
|
|
require.NoError(t, err)
|
|
} else {
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckWorkspaceBuildsDeadlineBelowMaxDeadline))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWorkspaceACLObjectConstraint(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
Deleted: false,
|
|
})
|
|
|
|
t.Run("GroupACLNull", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var nilACL database.WorkspaceACL
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
err := db.UpdateWorkspaceACLByID(ctx, database.UpdateWorkspaceACLByIDParams{
|
|
ID: workspace.ID,
|
|
GroupACL: nilACL,
|
|
UserACL: database.WorkspaceACL{},
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckGroupAclIsObject))
|
|
})
|
|
|
|
t.Run("UserACLNull", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var nilACL database.WorkspaceACL
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
err := db.UpdateWorkspaceACLByID(ctx, database.UpdateWorkspaceACLByIDParams{
|
|
ID: workspace.ID,
|
|
GroupACL: database.WorkspaceACL{},
|
|
UserACL: nilACL,
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckUserAclIsObject))
|
|
})
|
|
|
|
t.Run("ValidEmptyObjects", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
err := db.UpdateWorkspaceACLByID(ctx, database.UpdateWorkspaceACLByIDParams{
|
|
ID: workspace.ID,
|
|
GroupACL: database.WorkspaceACL{},
|
|
UserACL: database.WorkspaceACL{},
|
|
})
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
// TestGetLatestWorkspaceBuildsByWorkspaceIDs populates the database with
|
|
// workspaces and builds. It then tests that
|
|
// GetLatestWorkspaceBuildsByWorkspaceIDs returns the latest build for some
|
|
// subset of the workspaces.
|
|
func TestGetLatestWorkspaceBuildsByWorkspaceIDs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
admin := dbgen.User(t, db, database.User{})
|
|
|
|
tv := dbfake.TemplateVersion(t, db).
|
|
Seed(database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: admin.ID,
|
|
}).
|
|
Do()
|
|
|
|
users := make([]database.User, 5)
|
|
wrks := make([][]database.WorkspaceTable, len(users))
|
|
exp := make(map[uuid.UUID]database.WorkspaceBuild)
|
|
for i := range users {
|
|
users[i] = dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: users[i].ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
|
|
// Each user gets 2 workspaces.
|
|
wrks[i] = make([]database.WorkspaceTable, 2)
|
|
for wi := range wrks[i] {
|
|
wrks[i][wi] = dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
TemplateID: tv.Template.ID,
|
|
OwnerID: users[i].ID,
|
|
})
|
|
|
|
// Choose a deterministic number of builds per workspace
|
|
// No more than 5 builds though, that would be excessive.
|
|
for j := int32(1); int(j) <= (i+wi)%5; j++ {
|
|
wb := dbfake.WorkspaceBuild(t, db, wrks[i][wi]).
|
|
Seed(database.WorkspaceBuild{
|
|
WorkspaceID: wrks[i][wi].ID,
|
|
BuildNumber: j + 1,
|
|
}).
|
|
Do()
|
|
|
|
exp[wrks[i][wi].ID] = wb.Build // Save the final workspace build
|
|
}
|
|
}
|
|
}
|
|
|
|
// Only take half the users. And only take 1 workspace per user for the test.
|
|
// The others are just noice. This just queries a subset of workspaces and builds
|
|
// to make sure the noise doesn't interfere with the results.
|
|
assertWrks := wrks[:len(users)/2]
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ids := slice.Convert[[]database.WorkspaceTable, uuid.UUID](assertWrks, func(pair []database.WorkspaceTable) uuid.UUID {
|
|
return pair[0].ID
|
|
})
|
|
|
|
require.Greater(t, len(ids), 0, "expected some workspace ids for test")
|
|
builds, err := db.GetLatestWorkspaceBuildsByWorkspaceIDs(ctx, ids)
|
|
require.NoError(t, err)
|
|
for _, b := range builds {
|
|
expB, ok := exp[b.WorkspaceID]
|
|
require.Truef(t, ok, "unexpected workspace build for workspace id %s", b.WorkspaceID)
|
|
require.Equalf(t, expB.ID, b.ID, "unexpected workspace build id for workspace id %s", b.WorkspaceID)
|
|
require.Equal(t, expB.BuildNumber, b.BuildNumber, "unexpected build number")
|
|
}
|
|
}
|
|
|
|
func TestTasksWithStatusView(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
createProvisionerJob := func(t *testing.T, db database.Store, org database.Organization, user database.User, buildStatus database.ProvisionerJobStatus) database.ProvisionerJob {
|
|
t.Helper()
|
|
|
|
var jobParams database.ProvisionerJob
|
|
|
|
switch buildStatus {
|
|
case database.ProvisionerJobStatusPending:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
}
|
|
case database.ProvisionerJobStatusRunning:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
}
|
|
case database.ProvisionerJobStatusFailed:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
CompletedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
Error: sql.NullString{Valid: true, String: "job failed"},
|
|
}
|
|
case database.ProvisionerJobStatusSucceeded:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
CompletedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
}
|
|
case database.ProvisionerJobStatusCanceling:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
CanceledAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
}
|
|
case database.ProvisionerJobStatusCanceled:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
CompletedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
CanceledAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
}
|
|
default:
|
|
t.Errorf("invalid build status: %v", buildStatus)
|
|
}
|
|
|
|
return dbgen.ProvisionerJob(t, db, nil, jobParams)
|
|
}
|
|
|
|
createTask := func(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
org database.Organization,
|
|
user database.User,
|
|
buildStatus database.ProvisionerJobStatus,
|
|
buildTransition database.WorkspaceTransition,
|
|
agentState database.WorkspaceAgentLifecycleState,
|
|
appHealths []database.WorkspaceAppHealth,
|
|
) database.Task {
|
|
t.Helper()
|
|
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
|
|
if buildStatus == "" {
|
|
return dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: "test-task",
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
}
|
|
|
|
job := createProvisionerJob(t, db, org, user, buildStatus)
|
|
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
TemplateID: template.ID,
|
|
OwnerID: user.ID,
|
|
})
|
|
workspaceID := uuid.NullUUID{Valid: true, UUID: workspace.ID}
|
|
|
|
task := dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: "test-task",
|
|
WorkspaceID: workspaceID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
|
|
workspaceBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
BuildNumber: 1,
|
|
Transition: buildTransition,
|
|
InitiatorID: user.ID,
|
|
JobID: job.ID,
|
|
})
|
|
workspaceBuildNumber := workspaceBuild.BuildNumber
|
|
|
|
_, err := db.UpsertTaskWorkspaceApp(ctx, database.UpsertTaskWorkspaceAppParams{
|
|
TaskID: task.ID,
|
|
WorkspaceBuildNumber: workspaceBuildNumber,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
|
|
if agentState != "" {
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
workspaceAgentID := agent.ID
|
|
|
|
_, err := db.UpsertTaskWorkspaceApp(ctx, database.UpsertTaskWorkspaceAppParams{
|
|
TaskID: task.ID,
|
|
WorkspaceBuildNumber: workspaceBuildNumber,
|
|
WorkspaceAgentID: uuid.NullUUID{UUID: workspaceAgentID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
|
ID: agent.ID,
|
|
LifecycleState: agentState,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
for i, health := range appHealths {
|
|
app := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{
|
|
AgentID: workspaceAgentID,
|
|
Slug: fmt.Sprintf("test-app-%d", i),
|
|
DisplayName: fmt.Sprintf("Test App %d", i+1),
|
|
Health: health,
|
|
})
|
|
if i == 0 {
|
|
// Assume the first app is the tasks app.
|
|
_, err := db.UpsertTaskWorkspaceApp(ctx, database.UpsertTaskWorkspaceAppParams{
|
|
TaskID: task.ID,
|
|
WorkspaceBuildNumber: workspaceBuildNumber,
|
|
WorkspaceAgentID: uuid.NullUUID{UUID: workspaceAgentID, Valid: true},
|
|
WorkspaceAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return task
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
buildStatus database.ProvisionerJobStatus
|
|
buildTransition database.WorkspaceTransition
|
|
agentState database.WorkspaceAgentLifecycleState
|
|
appHealths []database.WorkspaceAppHealth
|
|
expectedStatus database.TaskStatus
|
|
description string
|
|
expectBuildNumberValid bool
|
|
expectBuildNumber int32
|
|
expectWorkspaceAgentValid bool
|
|
expectWorkspaceAppValid bool
|
|
}{
|
|
{
|
|
name: "NoWorkspace",
|
|
expectedStatus: "pending",
|
|
description: "Task with no workspace assigned",
|
|
expectBuildNumberValid: false,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "FailedBuild",
|
|
buildStatus: database.ProvisionerJobStatusFailed,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
expectedStatus: database.TaskStatusError,
|
|
description: "Latest workspace build failed",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "CancelingBuild",
|
|
buildStatus: database.ProvisionerJobStatusCanceling,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
expectedStatus: database.TaskStatusError,
|
|
description: "Latest workspace build is canceling",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "CanceledBuild",
|
|
buildStatus: database.ProvisionerJobStatusCanceled,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
expectedStatus: database.TaskStatusError,
|
|
description: "Latest workspace build was canceled",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "StoppedWorkspace",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStop,
|
|
expectedStatus: database.TaskStatusPaused,
|
|
description: "Workspace is stopped",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "DeletedWorkspace",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionDelete,
|
|
expectedStatus: database.TaskStatusPaused,
|
|
description: "Workspace is deleted",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "PendingStart",
|
|
buildStatus: database.ProvisionerJobStatusPending,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
expectedStatus: database.TaskStatusPending,
|
|
description: "Workspace build pending (not yet picked up by provisioner)",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "RunningStart",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Workspace build is starting (running)",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "StartingAgent",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateStarting,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthInitializing},
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Workspace is running but agent is starting",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "CreatedAgent",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateCreated,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthInitializing},
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Workspace is running but agent is created",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "ReadyAgentInitializingApp",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthInitializing},
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Agent is ready but app is initializing",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "ReadyAgentHealthyApp",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthHealthy},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Agent is ready and app is healthy",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "ReadyAgentDisabledApp",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthDisabled},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Agent is ready and app health checking is disabled",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "ReadyAgentUnhealthyApp",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthUnhealthy},
|
|
expectedStatus: database.TaskStatusError,
|
|
description: "Agent is ready but app is unhealthy",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "AgentStartTimeout",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateStartTimeout,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthHealthy},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Agent start timed out but app is healthy, defer to app",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "AgentStartError",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateStartError,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthHealthy},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Agent start failed but app is healthy, defer to app",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "AgentShuttingDown",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateShuttingDown,
|
|
expectedStatus: database.TaskStatusUnknown,
|
|
description: "Agent is shutting down",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "AgentOff",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateOff,
|
|
expectedStatus: database.TaskStatusUnknown,
|
|
description: "Agent is off",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "RunningJobReadyAgentHealthyApp",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthHealthy},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Running job with ready agent and healthy app should be active",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "RunningJobReadyAgentInitializingApp",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthInitializing},
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Running job with ready agent but initializing app should be initializing",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "RunningJobReadyAgentUnhealthyApp",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthUnhealthy},
|
|
expectedStatus: database.TaskStatusError,
|
|
description: "Running job with ready agent but unhealthy app should be error",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "RunningJobConnectingAgent",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateStarting,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthInitializing},
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Running job with connecting agent should be initializing",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "RunningJobReadyAgentDisabledApp",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthDisabled},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Running job with ready agent and disabled app health checking should be active",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "RunningJobReadyAgentHealthyTaskAppUnhealthyOtherAppIsOK",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthHealthy, database.WorkspaceAppHealthUnhealthy},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Running job with ready agent and multiple healthy apps should be active",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
|
|
task := createTask(ctx, t, db, org, user, tt.buildStatus, tt.buildTransition, tt.agentState, tt.appHealths)
|
|
|
|
got, err := db.GetTaskByID(ctx, task.ID)
|
|
require.NoError(t, err)
|
|
|
|
t.Logf("Task status debug: %s", got.StatusDebug)
|
|
|
|
require.Equal(t, tt.expectedStatus, got.Status)
|
|
|
|
require.Equal(t, tt.expectBuildNumberValid, got.WorkspaceBuildNumber.Valid)
|
|
if tt.expectBuildNumberValid {
|
|
require.Equal(t, tt.expectBuildNumber, got.WorkspaceBuildNumber.Int32)
|
|
}
|
|
|
|
require.Equal(t, tt.expectWorkspaceAgentValid, got.WorkspaceAgentID.Valid)
|
|
if tt.expectWorkspaceAgentValid {
|
|
require.NotEqual(t, uuid.Nil, got.WorkspaceAgentID.UUID)
|
|
}
|
|
|
|
require.Equal(t, tt.expectWorkspaceAppValid, got.WorkspaceAppID.Valid)
|
|
if tt.expectWorkspaceAppValid {
|
|
require.NotEqual(t, uuid.Nil, got.WorkspaceAppID.UUID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetTaskByWorkspaceID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
setupTask func(t *testing.T, db database.Store, org database.Organization, user database.User, templateVersion database.TemplateVersion, workspace database.WorkspaceTable)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "task doesn't exist",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "task with no workspace id",
|
|
setupTask: func(t *testing.T, db database.Store, org database.Organization, user database.User, templateVersion database.TemplateVersion, workspace database.WorkspaceTable) {
|
|
dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: "test-task",
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "task with workspace id",
|
|
setupTask: func(t *testing.T, db database.Store, org database.Organization, user database.User, templateVersion database.TemplateVersion, workspace database.WorkspaceTable) {
|
|
workspaceID := uuid.NullUUID{Valid: true, UUID: workspace.ID}
|
|
dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: "test-task",
|
|
WorkspaceID: workspaceID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
TemplateID: uuid.NullUUID{Valid: true, UUID: template.ID},
|
|
CreatedBy: user.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
|
|
if tt.setupTask != nil {
|
|
tt.setupTask(t, db, org, user, templateVersion, workspace)
|
|
}
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
task, err := db.GetTaskByWorkspaceID(ctx, workspace.ID)
|
|
if tt.wantErr {
|
|
require.Error(t, err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.False(t, task.WorkspaceBuildNumber.Valid)
|
|
require.False(t, task.WorkspaceAgentID.Valid)
|
|
require.False(t, task.WorkspaceAppID.Valid)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDeleteTaskDeletesTaskSnapshot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
task := dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
|
|
err := db.UpsertTaskSnapshot(ctx, database.UpsertTaskSnapshotParams{
|
|
TaskID: task.ID,
|
|
LogSnapshot: json.RawMessage(`{"messages":[]}`),
|
|
LogSnapshotCreatedAt: dbtime.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.DeleteTask(ctx, database.DeleteTaskParams{
|
|
ID: task.ID,
|
|
DeletedAt: dbtime.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.GetTaskSnapshot(ctx, task.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
}
|
|
|
|
func TestTaskNameUniqueness(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user1 := dbgen.User(t, db, database.User{})
|
|
user2 := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user1.ID,
|
|
})
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user1.ID,
|
|
})
|
|
|
|
taskName := "my-task"
|
|
|
|
// Create initial task for user1.
|
|
task1 := dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user1.ID,
|
|
Name: taskName,
|
|
TemplateVersionID: tv.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
require.NotEqual(t, uuid.Nil, task1.ID)
|
|
|
|
tests := []struct {
|
|
name string
|
|
ownerID uuid.UUID
|
|
taskName string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "duplicate task name same user",
|
|
ownerID: user1.ID,
|
|
taskName: taskName,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "duplicate task name different case same user",
|
|
ownerID: user1.ID,
|
|
taskName: "MY-TASK",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "same task name different user",
|
|
ownerID: user2.ID,
|
|
taskName: taskName,
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
taskID := uuid.New()
|
|
task, err := db.InsertTask(ctx, database.InsertTaskParams{
|
|
ID: taskID,
|
|
OrganizationID: org.ID,
|
|
OwnerID: tt.ownerID,
|
|
Name: tt.taskName,
|
|
TemplateVersionID: tv.ID,
|
|
TemplateParameters: json.RawMessage("{}"),
|
|
Prompt: "Test prompt",
|
|
CreatedAt: dbtime.Now(),
|
|
})
|
|
if tt.wantErr {
|
|
require.Error(t, err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.NotEqual(t, uuid.Nil, task.ID)
|
|
require.NotEqual(t, task1.ID, task.ID)
|
|
require.Equal(t, taskID, task.ID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUsageEventsTrigger(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// This is not exposed in the querier interface intentionally.
|
|
getDailyRows := func(ctx context.Context, sqlDB *sql.DB) []database.UsageEventsDaily {
|
|
t.Helper()
|
|
rows, err := sqlDB.QueryContext(ctx, "SELECT day, event_type, usage_data FROM usage_events_daily ORDER BY day ASC")
|
|
require.NoError(t, err, "perform query")
|
|
defer rows.Close()
|
|
|
|
var out []database.UsageEventsDaily
|
|
for rows.Next() {
|
|
var row database.UsageEventsDaily
|
|
err := rows.Scan(&row.Day, &row.EventType, &row.UsageData)
|
|
require.NoError(t, err, "scan row")
|
|
out = append(out, row)
|
|
}
|
|
return out
|
|
}
|
|
|
|
t.Run("OK", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
|
|
// Assert there are no daily rows.
|
|
rows := getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 0)
|
|
|
|
// Insert a usage event.
|
|
err := db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "1",
|
|
EventType: "dc_managed_agents_v1",
|
|
EventData: []byte(`{"count": 41}`),
|
|
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Assert there is one daily row that contains the correct data.
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 1)
|
|
require.Equal(t, "dc_managed_agents_v1", rows[0].EventType)
|
|
require.JSONEq(t, `{"count": 41}`, string(rows[0].UsageData))
|
|
// The read row might be `+0000` rather than `UTC` specifically, so just
|
|
// ensure it's within 1 second of the expected time.
|
|
require.WithinDuration(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), rows[0].Day, time.Second)
|
|
|
|
// Insert a new usage event on the same UTC day, should increment the count.
|
|
locSydney, err := time.LoadLocation("Australia/Sydney")
|
|
require.NoError(t, err)
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "2",
|
|
EventType: "dc_managed_agents_v1",
|
|
EventData: []byte(`{"count": 1}`),
|
|
// Insert it at a random point during the same day. Sydney is +1000 or
|
|
// +1100, so 8am in Sydney is the previous day in UTC.
|
|
CreatedAt: time.Date(2025, 1, 2, 8, 38, 57, 0, locSydney),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// There should still be only one daily row with the incremented count.
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 1)
|
|
require.Equal(t, "dc_managed_agents_v1", rows[0].EventType)
|
|
require.JSONEq(t, `{"count": 42}`, string(rows[0].UsageData))
|
|
require.WithinDuration(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), rows[0].Day, time.Second)
|
|
|
|
// TODO: when we have a new event type, we should test that adding an
|
|
// event with a different event type on the same day creates a new daily
|
|
// row.
|
|
|
|
// Insert a new usage event on a different day, should create a new daily
|
|
// row.
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "3",
|
|
EventType: "dc_managed_agents_v1",
|
|
EventData: []byte(`{"count": 1}`),
|
|
CreatedAt: time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// There should now be two daily rows.
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 2)
|
|
// Output is sorted by day ascending, so the first row should be the
|
|
// previous day's row.
|
|
require.Equal(t, "dc_managed_agents_v1", rows[0].EventType)
|
|
require.JSONEq(t, `{"count": 42}`, string(rows[0].UsageData))
|
|
require.WithinDuration(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), rows[0].Day, time.Second)
|
|
require.Equal(t, "dc_managed_agents_v1", rows[1].EventType)
|
|
require.JSONEq(t, `{"count": 1}`, string(rows[1].UsageData))
|
|
require.WithinDuration(t, time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), rows[1].Day, time.Second)
|
|
})
|
|
|
|
t.Run("HeartbeatAISeats", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
|
|
// Insert a heartbeat event.
|
|
err := db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "hb-1",
|
|
EventType: "hb_ai_seats_v1",
|
|
EventData: []byte(`{"count": 10}`),
|
|
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows := getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 1)
|
|
require.Equal(t, "hb_ai_seats_v1", rows[0].EventType)
|
|
require.JSONEq(t, `{"count": 10}`, string(rows[0].UsageData))
|
|
|
|
// Insert a higher count on the same day — should take the max.
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "hb-2",
|
|
EventType: "hb_ai_seats_v1",
|
|
EventData: []byte(`{"count": 50}`),
|
|
CreatedAt: time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 1)
|
|
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
|
|
|
|
// Insert a lower count on the same day — should keep the max (50).
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "hb-3",
|
|
EventType: "hb_ai_seats_v1",
|
|
EventData: []byte(`{"count": 25}`),
|
|
CreatedAt: time.Date(2025, 1, 1, 18, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 1)
|
|
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
|
|
|
|
// Insert on a different day.
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "hb-4",
|
|
EventType: "hb_ai_seats_v1",
|
|
EventData: []byte(`{"count": 5}`),
|
|
CreatedAt: time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 2)
|
|
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
|
|
require.JSONEq(t, `{"count": 5}`, string(rows[1].UsageData))
|
|
|
|
// Also insert a dc_managed_agents_v1 on the same first day to
|
|
// verify different event types get separate daily rows.
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "dc-1",
|
|
EventType: "dc_managed_agents_v1",
|
|
EventData: []byte(`{"count": 7}`),
|
|
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 3)
|
|
})
|
|
|
|
t.Run("UnknownEventType", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
|
|
// Relax the usage_events.event_type check constraint to see what
|
|
// happens when we insert a usage event that the trigger doesn't know
|
|
// about.
|
|
_, err := sqlDB.ExecContext(ctx, "ALTER TABLE usage_events DROP CONSTRAINT usage_event_type_check")
|
|
require.NoError(t, err)
|
|
|
|
// Insert a usage event with an unknown event type.
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "broken",
|
|
EventType: "dean's cool event",
|
|
EventData: []byte(`{"my": "cool json"}`),
|
|
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.ErrorContains(t, err, "Unhandled usage event type in aggregate_usage_event")
|
|
|
|
// The event should've been blocked.
|
|
var count int
|
|
err = sqlDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_events WHERE id = 'broken'").Scan(&count)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count)
|
|
|
|
// We should not have any daily rows.
|
|
rows := getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 0)
|
|
})
|
|
}
|
|
|
|
func TestListTasks(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
// Given: two organizations and two users, one of which is a member of both
|
|
org1 := dbgen.Organization(t, db, database.Organization{})
|
|
org2 := dbgen.Organization(t, db, database.Organization{})
|
|
user1 := dbgen.User(t, db, database.User{})
|
|
user2 := dbgen.User(t, db, database.User{})
|
|
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org1.ID,
|
|
UserID: user1.ID,
|
|
})
|
|
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org2.ID,
|
|
UserID: user2.ID,
|
|
})
|
|
|
|
// Given: a template with an active version
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
CreatedBy: user1.ID,
|
|
OrganizationID: org1.ID,
|
|
})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user1.ID,
|
|
OrganizationID: org1.ID,
|
|
ActiveVersionID: tv.ID,
|
|
})
|
|
|
|
// Helper function to create a task
|
|
createTask := func(orgID, ownerID uuid.UUID) database.Task {
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: orgID,
|
|
OwnerID: ownerID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
pj := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{})
|
|
sidebarAppID := uuid.New()
|
|
wb := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
JobID: pj.ID,
|
|
TemplateVersionID: tv.ID,
|
|
WorkspaceID: ws.ID,
|
|
})
|
|
wr := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: pj.ID,
|
|
})
|
|
agt := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: wr.ID,
|
|
})
|
|
wa := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{
|
|
ID: sidebarAppID,
|
|
AgentID: agt.ID,
|
|
})
|
|
tsk := dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: orgID,
|
|
OwnerID: ownerID,
|
|
Prompt: testutil.GetRandomName(t),
|
|
TemplateVersionID: tv.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
})
|
|
_ = dbgen.TaskWorkspaceApp(t, db, database.TaskWorkspaceApp{
|
|
TaskID: tsk.ID,
|
|
WorkspaceBuildNumber: wb.BuildNumber,
|
|
WorkspaceAgentID: uuid.NullUUID{Valid: true, UUID: agt.ID},
|
|
WorkspaceAppID: uuid.NullUUID{Valid: true, UUID: wa.ID},
|
|
})
|
|
t.Logf("task_id:%s owner_id:%s org_id:%s", tsk.ID, ownerID, orgID)
|
|
return tsk
|
|
}
|
|
|
|
// Given: user1 has one task, user2 has one task, user3 has two tasks (one in each org)
|
|
task1 := createTask(org1.ID, user1.ID)
|
|
task2 := createTask(org1.ID, user2.ID)
|
|
task3 := createTask(org2.ID, user2.ID)
|
|
|
|
// Then: run various filters and assert expected results
|
|
for _, tc := range []struct {
|
|
name string
|
|
filter database.ListTasksParams
|
|
expectIDs []uuid.UUID
|
|
}{
|
|
{
|
|
name: "no filter",
|
|
filter: database.ListTasksParams{
|
|
OwnerID: uuid.Nil,
|
|
OrganizationID: uuid.Nil,
|
|
},
|
|
expectIDs: []uuid.UUID{task3.ID, task2.ID, task1.ID},
|
|
},
|
|
{
|
|
name: "filter by user ID",
|
|
filter: database.ListTasksParams{
|
|
OwnerID: user1.ID,
|
|
OrganizationID: uuid.Nil,
|
|
},
|
|
expectIDs: []uuid.UUID{task1.ID},
|
|
},
|
|
{
|
|
name: "filter by organization ID",
|
|
filter: database.ListTasksParams{
|
|
OwnerID: uuid.Nil,
|
|
OrganizationID: org1.ID,
|
|
},
|
|
expectIDs: []uuid.UUID{task2.ID, task1.ID},
|
|
},
|
|
{
|
|
name: "filter by user and organization ID",
|
|
filter: database.ListTasksParams{
|
|
OwnerID: user2.ID,
|
|
OrganizationID: org2.ID,
|
|
},
|
|
expectIDs: []uuid.UUID{task3.ID},
|
|
},
|
|
{
|
|
name: "no results",
|
|
filter: database.ListTasksParams{
|
|
OwnerID: user1.ID,
|
|
OrganizationID: org2.ID,
|
|
},
|
|
expectIDs: nil,
|
|
},
|
|
} {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
tasks, err := db.ListTasks(ctx, tc.filter)
|
|
require.NoError(t, err)
|
|
require.Len(t, tasks, len(tc.expectIDs))
|
|
|
|
for idx, eid := range tc.expectIDs {
|
|
task := tasks[idx]
|
|
assert.Equal(t, eid, task.ID, "task ID mismatch at index %d", idx)
|
|
|
|
require.True(t, task.WorkspaceBuildNumber.Valid)
|
|
require.Greater(t, task.WorkspaceBuildNumber.Int32, int32(0))
|
|
require.True(t, task.WorkspaceAgentID.Valid)
|
|
require.NotEqual(t, uuid.Nil, task.WorkspaceAgentID.UUID)
|
|
require.True(t, task.WorkspaceAppID.Valid)
|
|
require.NotEqual(t, uuid.Nil, task.WorkspaceAppID.UUID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdateTaskWorkspaceID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
// Create organization, users, template, and template version.
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
TemplateID: uuid.NullUUID{Valid: true, UUID: template.ID},
|
|
CreatedBy: user.ID,
|
|
})
|
|
|
|
// Create another template for mismatch test.
|
|
template2 := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
setupTask func(t *testing.T) database.Task
|
|
setupWS func(t *testing.T) database.WorkspaceTable
|
|
wantErr bool
|
|
wantNoRow bool
|
|
}{
|
|
{
|
|
name: "successful update with matching template",
|
|
setupTask: func(t *testing.T) database.Task {
|
|
return dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: testutil.GetRandomName(t),
|
|
WorkspaceID: uuid.NullUUID{},
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
setupWS: func(t *testing.T) database.WorkspaceTable {
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
},
|
|
wantErr: false,
|
|
wantNoRow: false,
|
|
},
|
|
{
|
|
name: "task already has workspace_id",
|
|
setupTask: func(t *testing.T) database.Task {
|
|
existingWS := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
return dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: testutil.GetRandomName(t),
|
|
WorkspaceID: uuid.NullUUID{Valid: true, UUID: existingWS.ID},
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
setupWS: func(t *testing.T) database.WorkspaceTable {
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
},
|
|
wantErr: false,
|
|
wantNoRow: true, // No row should be returned because WHERE condition fails.
|
|
},
|
|
{
|
|
name: "template mismatch between task and workspace",
|
|
setupTask: func(t *testing.T) database.Task {
|
|
return dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: testutil.GetRandomName(t),
|
|
WorkspaceID: uuid.NullUUID{}, // NULL workspace_id
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
setupWS: func(t *testing.T) database.WorkspaceTable {
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template2.ID, // Different template, JOIN will fail.
|
|
})
|
|
},
|
|
wantErr: false,
|
|
wantNoRow: true, // No row should be returned because JOIN condition fails.
|
|
},
|
|
{
|
|
name: "task does not exist",
|
|
setupTask: func(t *testing.T) database.Task {
|
|
return database.Task{
|
|
ID: uuid.New(), // Non-existent task ID.
|
|
}
|
|
},
|
|
setupWS: func(t *testing.T) database.WorkspaceTable {
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
},
|
|
wantErr: false,
|
|
wantNoRow: true,
|
|
},
|
|
{
|
|
name: "workspace does not exist",
|
|
setupTask: func(t *testing.T) database.Task {
|
|
return dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: testutil.GetRandomName(t),
|
|
WorkspaceID: uuid.NullUUID{},
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
setupWS: func(t *testing.T) database.WorkspaceTable {
|
|
return database.WorkspaceTable{
|
|
ID: uuid.New(), // Non-existent workspace ID.
|
|
}
|
|
},
|
|
wantErr: false,
|
|
wantNoRow: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
task := tt.setupTask(t)
|
|
workspace := tt.setupWS(t)
|
|
|
|
updatedTask, err := db.UpdateTaskWorkspaceID(ctx, database.UpdateTaskWorkspaceIDParams{
|
|
ID: task.ID,
|
|
WorkspaceID: uuid.NullUUID{Valid: true, UUID: workspace.ID},
|
|
})
|
|
|
|
if tt.wantErr {
|
|
require.Error(t, err)
|
|
return
|
|
}
|
|
|
|
if tt.wantNoRow {
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
return
|
|
}
|
|
|
|
require.NoError(t, err)
|
|
require.Equal(t, task.ID, updatedTask.ID)
|
|
require.True(t, updatedTask.WorkspaceID.Valid)
|
|
require.Equal(t, workspace.ID, updatedTask.WorkspaceID.UUID)
|
|
require.Equal(t, task.OrganizationID, updatedTask.OrganizationID)
|
|
require.Equal(t, task.OwnerID, updatedTask.OwnerID)
|
|
require.Equal(t, task.Name, updatedTask.Name)
|
|
require.Equal(t, task.TemplateVersionID, updatedTask.TemplateVersionID)
|
|
|
|
// Verify the update persisted by fetching the task again.
|
|
fetchedTask, err := db.GetTaskByID(ctx, task.ID)
|
|
require.NoError(t, err)
|
|
require.True(t, fetchedTask.WorkspaceID.Valid)
|
|
require.Equal(t, workspace.ID, fetchedTask.WorkspaceID.UUID)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
t.Run("NonExistingInterception", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
got, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
|
ID: uuid.New(),
|
|
EndedAt: time.Now(),
|
|
})
|
|
require.ErrorContains(t, err, "no rows in result set")
|
|
require.EqualValues(t, database.AIBridgeInterception{}, got)
|
|
})
|
|
|
|
t.Run("OK", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
interceptions := []database.AIBridgeInterception{}
|
|
|
|
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
|
|
insertParams := database.InsertAIBridgeInterceptionParams{
|
|
ID: uid,
|
|
InitiatorID: user.ID,
|
|
Metadata: json.RawMessage("{}"),
|
|
Client: sql.NullString{String: "client", Valid: true},
|
|
CredentialKind: database.CredentialKindCentralized,
|
|
}
|
|
|
|
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
|
|
require.NoError(t, err)
|
|
require.Equal(t, uid, intc.ID)
|
|
require.False(t, intc.EndedAt.Valid)
|
|
require.True(t, intc.Client.Valid)
|
|
require.Equal(t, "client", intc.Client.String)
|
|
interceptions = append(interceptions, intc)
|
|
}
|
|
|
|
intc0 := interceptions[0]
|
|
endedAt := time.Now()
|
|
// Mark first interception as done
|
|
updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
|
ID: intc0.ID,
|
|
EndedAt: endedAt,
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, updated.ID, intc0.ID)
|
|
require.True(t, updated.EndedAt.Valid)
|
|
require.WithinDuration(t, endedAt, updated.EndedAt.Time, 5*time.Second)
|
|
|
|
// Updating first interception again should fail
|
|
updated, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
|
ID: intc0.ID,
|
|
EndedAt: endedAt.Add(time.Hour),
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
// Other interceptions should not have ended_at set
|
|
for _, intc := range interceptions[1:] {
|
|
got, err := db.GetAIBridgeInterceptionByID(ctx, intc.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, got.EndedAt.Valid)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDeleteExpiredAPIKeys(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
// Constant time for testing
|
|
now := time.Date(2025, 11, 20, 12, 0, 0, 0, time.UTC)
|
|
expiredBefore := now.Add(-time.Hour) // Anything before this is expired
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
|
|
expiredTimes := []time.Time{
|
|
expiredBefore.Add(-time.Hour * 24 * 365),
|
|
expiredBefore.Add(-time.Hour * 24),
|
|
expiredBefore.Add(-time.Hour),
|
|
expiredBefore.Add(-time.Minute),
|
|
expiredBefore.Add(-time.Second),
|
|
}
|
|
for _, exp := range expiredTimes {
|
|
// Expired api keys
|
|
dbgen.APIKey(t, db, database.APIKey{UserID: user.ID, ExpiresAt: exp})
|
|
}
|
|
|
|
unexpiredTimes := []time.Time{
|
|
expiredBefore.Add(time.Hour * 24 * 365),
|
|
expiredBefore.Add(time.Hour * 24),
|
|
expiredBefore.Add(time.Hour),
|
|
expiredBefore.Add(time.Minute),
|
|
expiredBefore.Add(time.Second),
|
|
}
|
|
for _, unexp := range unexpiredTimes {
|
|
// Unexpired api keys
|
|
dbgen.APIKey(t, db, database.APIKey{UserID: user.ID, ExpiresAt: unexp})
|
|
}
|
|
|
|
// All keys are present before deletion
|
|
keys, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
|
LoginType: user.LoginType,
|
|
UserID: user.ID,
|
|
IncludeExpired: true,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, keys, len(expiredTimes)+len(unexpiredTimes))
|
|
|
|
// Delete expired keys
|
|
// First verify the limit works by deleting one at a time
|
|
deletedCount, err := db.DeleteExpiredAPIKeys(ctx, database.DeleteExpiredAPIKeysParams{
|
|
Before: expiredBefore,
|
|
LimitCount: 1,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), deletedCount)
|
|
|
|
// Ensure it was deleted
|
|
remaining, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
|
LoginType: user.LoginType,
|
|
UserID: user.ID,
|
|
IncludeExpired: true,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, remaining, len(expiredTimes)+len(unexpiredTimes)-1)
|
|
|
|
// Delete the rest of the expired keys
|
|
deletedCount, err = db.DeleteExpiredAPIKeys(ctx, database.DeleteExpiredAPIKeysParams{
|
|
Before: expiredBefore,
|
|
LimitCount: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(len(expiredTimes)-1), deletedCount)
|
|
|
|
// Ensure only unexpired keys remain
|
|
remaining, err = db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
|
LoginType: user.LoginType,
|
|
UserID: user.ID,
|
|
IncludeExpired: true,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, remaining, len(unexpiredTimes))
|
|
}
|
|
|
|
func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
owner := dbgen.User(t, db, database.User{})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: owner.ID,
|
|
})
|
|
ver := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{
|
|
UUID: tpl.ID,
|
|
Valid: true,
|
|
},
|
|
OrganizationID: tpl.OrganizationID,
|
|
CreatedBy: owner.ID,
|
|
})
|
|
|
|
t.Run("DuringStopBuild", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Create start build with succeeded job (already completed).
|
|
startJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob)
|
|
startJob = dbgen.ProvisionerJob(t, db, nil, startJob)
|
|
startResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
startBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource.ID,
|
|
})
|
|
|
|
// Create stop build (becomes latest).
|
|
stopJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
JobStatus: database.ProvisionerJobStatusRunning,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob.ID,
|
|
})
|
|
|
|
// Agent should still authenticate during stop build execution.
|
|
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
|
require.NoError(t, err, "agent should authenticate during stop build execution")
|
|
require.Equal(t, agent.ID, row.WorkspaceAgent.ID)
|
|
require.Equal(t, startBuild.ID, row.WorkspaceBuild.ID, "should return start build, not stop build")
|
|
})
|
|
|
|
t.Run("AfterStopJobCompletes", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Create start build with completed job.
|
|
startJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob)
|
|
startJob = dbgen.ProvisionerJob(t, db, nil, startJob)
|
|
|
|
startResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource.ID,
|
|
})
|
|
|
|
// Create stop build (becomes latest) with completed job.
|
|
stopJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &stopJob)
|
|
stopJob = dbgen.ProvisionerJob(t, db, nil, stopJob)
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob.ID,
|
|
})
|
|
|
|
// Agent should NOT authenticate after stop job completes.
|
|
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
|
require.ErrorIs(t, err, sql.ErrNoRows, "agent should not authenticate after stop job completes")
|
|
})
|
|
|
|
t.Run("FailedStartBuild", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Create START build with FAILED job.
|
|
startJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusFailed, &startJob)
|
|
startJob = dbgen.ProvisionerJob(t, db, nil, startJob)
|
|
startResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource.ID,
|
|
})
|
|
|
|
// Create STOP build with running job.
|
|
stopJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
JobStatus: database.ProvisionerJobStatusRunning,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob.ID,
|
|
})
|
|
|
|
// Agent should NOT authenticate (start build failed).
|
|
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
|
require.ErrorIs(t, err, sql.ErrNoRows, "agent from failed start build should not authenticate")
|
|
})
|
|
|
|
t.Run("PendingStopBuild", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Create start build with succeeded job.
|
|
startJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob)
|
|
startJob = dbgen.ProvisionerJob(t, db, nil, startJob)
|
|
startResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
startBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource.ID,
|
|
})
|
|
|
|
// Create stop build with pending job (not started yet).
|
|
stopJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusPending, &stopJob)
|
|
stopJob = dbgen.ProvisionerJob(t, db, nil, stopJob)
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob.ID,
|
|
})
|
|
|
|
// Agent should authenticate during pending stop build.
|
|
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
|
require.NoError(t, err, "agent should authenticate during pending stop build")
|
|
require.Equal(t, agent.ID, row.WorkspaceAgent.ID)
|
|
require.Equal(t, startBuild.ID, row.WorkspaceBuild.ID, "should return start build")
|
|
})
|
|
|
|
t.Run("MultipleStartStopCycles", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Build 1: START (succeeded).
|
|
startJob1 := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob1)
|
|
startJob1 = dbgen.ProvisionerJob(t, db, nil, startJob1)
|
|
startResource1 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob1.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob1.ID,
|
|
})
|
|
agent1 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource1.ID,
|
|
})
|
|
|
|
// Build 2: STOP (succeeded).
|
|
stopJob1 := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &stopJob1)
|
|
stopJob1 = dbgen.ProvisionerJob(t, db, nil, stopJob1)
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob1.ID,
|
|
})
|
|
|
|
// Build 3: START (succeeded).
|
|
startJob2 := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob2)
|
|
startJob2 = dbgen.ProvisionerJob(t, db, nil, startJob2)
|
|
startResource2 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob2.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
startBuild2 := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 3,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob2.ID,
|
|
})
|
|
agent2 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource2.ID,
|
|
})
|
|
|
|
// Build 4: STOP (running).
|
|
stopJob2 := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
JobStatus: database.ProvisionerJobStatusRunning,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 4,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob2.ID,
|
|
})
|
|
|
|
// Agent from build 3 should authenticate.
|
|
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent2.AuthToken)
|
|
require.NoError(t, err, "agent from most recent start should authenticate during stop")
|
|
require.Equal(t, agent2.ID, row.WorkspaceAgent.ID)
|
|
require.Equal(t, startBuild2.ID, row.WorkspaceBuild.ID)
|
|
|
|
// Agent from build 1 should NOT authenticate.
|
|
_, err = db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent1.AuthToken)
|
|
require.ErrorIs(t, err, sql.ErrNoRows, "agent from old cycle should not authenticate")
|
|
})
|
|
|
|
t.Run("WrongTransitionType", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Create first start build.
|
|
startJob1 := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob1)
|
|
startJob1 = dbgen.ProvisionerJob(t, db, nil, startJob1)
|
|
startResource1 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob1.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob1.ID,
|
|
})
|
|
agent1 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource1.ID,
|
|
})
|
|
|
|
// Create another START build as latest (not STOP).
|
|
startJob2 := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
JobStatus: database.ProvisionerJobStatusRunning,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob2.ID,
|
|
})
|
|
|
|
// Agent from build 1 should NOT authenticate (latest is not STOP).
|
|
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent1.AuthToken)
|
|
require.ErrorIs(t, err, sql.ErrNoRows, "agent should not authenticate when latest build is not STOP")
|
|
})
|
|
}
|
|
|
|
// Our `InsertWorkspaceAgentDevcontainers` query should ideally be `[]uuid.NullUUID` but unfortunately
|
|
// sqlc infers it as `[]uuid.UUID`. To ensure we don't insert a `uuid.Nil`, the query inserts NULL when
|
|
// passed with `uuid.Nil`. This test ensures we keep this behavior without regression.
|
|
func TestInsertWorkspaceAgentDevcontainers(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
validSubagent []bool
|
|
}{
|
|
{"BothValid", []bool{true, true}},
|
|
{"FirstValidSecondInvalid", []bool{true, false}},
|
|
{"FirstInvalidSecondValid", []bool{false, true}},
|
|
{"BothInvalid", []bool{false, false}},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
db, _ = dbtestutil.NewDB(t)
|
|
org = dbgen.Organization(t, db, database.Organization{})
|
|
job = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{JobID: job.ID})
|
|
agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ResourceID: resource.ID})
|
|
)
|
|
|
|
ids := make([]uuid.UUID, len(tc.validSubagent))
|
|
names := make([]string, len(tc.validSubagent))
|
|
workspaceFolders := make([]string, len(tc.validSubagent))
|
|
configPaths := make([]string, len(tc.validSubagent))
|
|
subagentIDs := make([]uuid.UUID, len(tc.validSubagent))
|
|
|
|
for i, valid := range tc.validSubagent {
|
|
ids[i] = uuid.New()
|
|
names[i] = fmt.Sprintf("test-devcontainer-%d", i)
|
|
workspaceFolders[i] = fmt.Sprintf("/workspace%d", i)
|
|
configPaths[i] = fmt.Sprintf("/workspace%d/.devcontainer/devcontainer.json", i)
|
|
|
|
if valid {
|
|
subagentIDs[i] = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
ParentID: uuid.NullUUID{UUID: agent.ID, Valid: true},
|
|
}).ID
|
|
} else {
|
|
subagentIDs[i] = uuid.Nil
|
|
}
|
|
}
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: We insert multiple devcontainer records.
|
|
devcontainers, err := db.InsertWorkspaceAgentDevcontainers(ctx, database.InsertWorkspaceAgentDevcontainersParams{
|
|
WorkspaceAgentID: agent.ID,
|
|
CreatedAt: dbtime.Now(),
|
|
ID: ids,
|
|
Name: names,
|
|
WorkspaceFolder: workspaceFolders,
|
|
ConfigPath: configPaths,
|
|
SubagentID: subagentIDs,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, devcontainers, len(tc.validSubagent))
|
|
|
|
// Then: Verify each devcontainer has the correct SubagentID validity.
|
|
// - When we pass `uuid.Nil`, we get a `uuid.NullUUID{Valid: false}`
|
|
// - When we pass a valid UUID, we get a `uuid.NullUUID{Valid: true}`
|
|
for i, valid := range tc.validSubagent {
|
|
require.Equal(t, valid, devcontainers[i].SubagentID.Valid, "devcontainer %d: subagent_id validity mismatch", i)
|
|
if valid {
|
|
require.Equal(t, subagentIDs[i], devcontainers[i].SubagentID.UUID, "devcontainer %d: subagent_id UUID mismatch", i)
|
|
}
|
|
}
|
|
|
|
// Perform the same check on data returned by
|
|
// `GetWorkspaceAgentDevcontainersByAgentID` to ensure the fix is at
|
|
// the data storage layer, instead of just at a query level.
|
|
fetched, err := db.GetWorkspaceAgentDevcontainersByAgentID(ctx, agent.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, fetched, len(tc.validSubagent))
|
|
|
|
// Sort fetched by name to ensure consistent ordering for comparison.
|
|
slices.SortFunc(fetched, func(a, b database.WorkspaceAgentDevcontainer) int {
|
|
return strings.Compare(a.Name, b.Name)
|
|
})
|
|
|
|
for i, valid := range tc.validSubagent {
|
|
require.Equal(t, valid, fetched[i].SubagentID.Valid, "fetched devcontainer %d: subagent_id validity mismatch", i)
|
|
if valid {
|
|
require.Equal(t, subagentIDs[i], fetched[i].SubagentID.UUID, "fetched devcontainer %d: subagent_id UUID mismatch", i)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestInsertChatMessages(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
insertModelConfig := func(
|
|
t *testing.T,
|
|
store database.Store,
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
provider string,
|
|
model string,
|
|
displayName string,
|
|
isDefault bool,
|
|
) database.ChatModelConfig {
|
|
t.Helper()
|
|
|
|
modelConfig, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: provider,
|
|
Model: model,
|
|
DisplayName: displayName,
|
|
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: isDefault,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
return modelConfig
|
|
}
|
|
|
|
setupChat := func(t *testing.T) (database.Store, context.Context, database.User, database.Chat, string, database.ChatModelConfig) {
|
|
t.Helper()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
|
provider := "openai"
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: provider,
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelConfigA := insertModelConfig(
|
|
t,
|
|
store,
|
|
ctx,
|
|
user.ID,
|
|
provider,
|
|
"test-model-a-"+uuid.NewString(),
|
|
"Test Model A",
|
|
true,
|
|
)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelConfigA.ID,
|
|
Title: "test-chat-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
return store, ctx, user, chat, provider, modelConfigA
|
|
}
|
|
|
|
insertMessage := func(t *testing.T, store database.Store, ctx context.Context, chatID, userID, modelConfigID uuid.UUID, content string) {
|
|
t.Helper()
|
|
|
|
_, err := store.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chatID,
|
|
CreatedBy: []uuid.UUID{userID},
|
|
ModelConfigID: []uuid.UUID{modelConfigID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
Content: []string{fmt.Sprintf("%q", content)},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
t.Run("ModelSwitchUpdatesLastModelConfigID", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ctx, user, chat, provider, modelConfigA := setupChat(t)
|
|
modelConfigB := insertModelConfig(
|
|
t,
|
|
store,
|
|
ctx,
|
|
user.ID,
|
|
provider,
|
|
"test-model-b-"+uuid.NewString(),
|
|
"Test Model B",
|
|
false,
|
|
)
|
|
|
|
insertMessage(t, store, ctx, chat.ID, user.ID, modelConfigB.ID, "switch models")
|
|
|
|
gotChat, err := store.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, modelConfigA.ID, chat.LastModelConfigID)
|
|
require.Equal(t, modelConfigB.ID, gotChat.LastModelConfigID)
|
|
})
|
|
|
|
t.Run("SameModelDoesNotBreakAnything", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ctx, user, chat, _, modelConfigA := setupChat(t)
|
|
|
|
insertMessage(t, store, ctx, chat.ID, user.ID, modelConfigA.ID, "same model")
|
|
|
|
gotChat, err := store.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, modelConfigA.ID, gotChat.LastModelConfigID)
|
|
})
|
|
|
|
t.Run("BatchInsertMultipleMessages", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ctx, user, chat, _, modelConfigA := setupChat(t)
|
|
|
|
msgs, err := store.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chat.ID,
|
|
CreatedBy: []uuid.UUID{user.ID, uuid.Nil, uuid.Nil},
|
|
ModelConfigID: []uuid.UUID{modelConfigA.ID, modelConfigA.ID, modelConfigA.ID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleTool},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
|
|
Content: []string{`"hello"`, `"response"`, `"tool result"`},
|
|
InputTokens: []int64{10, 0, 0},
|
|
OutputTokens: []int64{0, 20, 0},
|
|
TotalTokens: []int64{10, 20, 0},
|
|
ReasoningTokens: []int64{0, 5, 0},
|
|
CacheCreationTokens: []int64{0, 0, 0},
|
|
CacheReadTokens: []int64{0, 0, 0},
|
|
ContextLimit: []int64{0, 0, 0},
|
|
Compressed: []bool{false, false, false},
|
|
TotalCostMicros: []int64{0, 100, 0},
|
|
RuntimeMs: []int64{0, 500, 0},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, msgs, 3)
|
|
|
|
// Verify ordering and roles.
|
|
require.Equal(t, database.ChatMessageRoleUser, msgs[0].Role)
|
|
require.Equal(t, database.ChatMessageRoleAssistant, msgs[1].Role)
|
|
require.Equal(t, database.ChatMessageRoleTool, msgs[2].Role)
|
|
|
|
// Verify IDs are sequential.
|
|
require.Less(t, msgs[0].ID, msgs[1].ID)
|
|
require.Less(t, msgs[1].ID, msgs[2].ID)
|
|
|
|
// Verify nullable fields: user message has CreatedBy set.
|
|
require.True(t, msgs[0].CreatedBy.Valid)
|
|
require.Equal(t, user.ID, msgs[0].CreatedBy.UUID)
|
|
// Assistant and tool messages have NULL CreatedBy.
|
|
require.False(t, msgs[1].CreatedBy.Valid)
|
|
require.False(t, msgs[2].CreatedBy.Valid)
|
|
|
|
// Verify token fields stored as NULL when zero.
|
|
require.True(t, msgs[0].InputTokens.Valid)
|
|
require.Equal(t, int64(10), msgs[0].InputTokens.Int64)
|
|
require.False(t, msgs[0].OutputTokens.Valid) // 0 → NULL
|
|
require.True(t, msgs[1].OutputTokens.Valid)
|
|
require.Equal(t, int64(20), msgs[1].OutputTokens.Int64)
|
|
|
|
// Verify cost: assistant has cost, others NULL.
|
|
require.True(t, msgs[1].TotalCostMicros.Valid)
|
|
require.Equal(t, int64(100), msgs[1].TotalCostMicros.Int64)
|
|
require.False(t, msgs[0].TotalCostMicros.Valid)
|
|
require.False(t, msgs[2].TotalCostMicros.Valid)
|
|
|
|
// Verify runtime_ms on assistant message.
|
|
require.True(t, msgs[1].RuntimeMs.Valid)
|
|
require.Equal(t, int64(500), msgs[1].RuntimeMs.Int64)
|
|
require.False(t, msgs[0].RuntimeMs.Valid)
|
|
})
|
|
}
|
|
|
|
func TestGetChatMessagesForPromptByChatID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// This test exercises a complex CTE query for prompt
|
|
// reconstruction after compaction. It requires Postgres.
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
|
|
// Helper: create a chat model config (required FK for chats).
|
|
user := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
|
|
|
// A chat_providers row is required as a FK for model configs.
|
|
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
newChat := func(t *testing.T) database.Chat {
|
|
t.Helper()
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "test-chat-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
return chat
|
|
}
|
|
|
|
insertMsg := func(
|
|
t *testing.T,
|
|
chatID uuid.UUID,
|
|
role database.ChatMessageRole,
|
|
vis database.ChatMessageVisibility,
|
|
compressed bool,
|
|
content string,
|
|
) database.ChatMessage {
|
|
t.Helper()
|
|
results, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chatID,
|
|
CreatedBy: []uuid.UUID{uuid.Nil},
|
|
ModelConfigID: []uuid.UUID{uuid.Nil},
|
|
Role: []database.ChatMessageRole{role},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
|
Visibility: []database.ChatMessageVisibility{vis},
|
|
Compressed: []bool{compressed},
|
|
Content: []string{`"` + content + `"`},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
})
|
|
require.NoError(t, err)
|
|
return results[0]
|
|
}
|
|
|
|
msgIDs := func(msgs []database.ChatMessage) []int64 {
|
|
ids := make([]int64, len(msgs))
|
|
for i, m := range msgs {
|
|
ids[i] = m.ID
|
|
}
|
|
return ids
|
|
}
|
|
|
|
t.Run("NoCompaction", func(t *testing.T) {
|
|
t.Parallel()
|
|
chat := newChat(t)
|
|
|
|
sys := insertMsg(t, chat.ID, database.ChatMessageRoleSystem, database.ChatMessageVisibilityModel, false, "system prompt")
|
|
usr := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "hello")
|
|
ast := insertMsg(t, chat.ID, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth, false, "hi there")
|
|
|
|
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []int64{sys.ID, usr.ID, ast.ID}, msgIDs(got))
|
|
})
|
|
|
|
t.Run("UserOnlyVisibilityExcluded", func(t *testing.T) {
|
|
t.Parallel()
|
|
chat := newChat(t)
|
|
|
|
// Messages with visibility=user should NOT appear in the
|
|
// prompt (they are only for the UI).
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleSystem, database.ChatMessageVisibilityModel, false, "system prompt")
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityUser, false, "user-only msg")
|
|
usr := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "hello")
|
|
|
|
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
for _, m := range got {
|
|
require.NotEqual(t, database.ChatMessageVisibilityUser, m.Visibility,
|
|
"visibility=user messages should not appear in the prompt")
|
|
}
|
|
require.Contains(t, msgIDs(got), usr.ID)
|
|
})
|
|
|
|
t.Run("AfterCompaction", func(t *testing.T) {
|
|
t.Parallel()
|
|
chat := newChat(t)
|
|
|
|
// Pre-compaction conversation.
|
|
sys := insertMsg(t, chat.ID, database.ChatMessageRoleSystem, database.ChatMessageVisibilityModel, false, "system prompt")
|
|
preUser := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "old question")
|
|
preAsst := insertMsg(t, chat.ID, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth, false, "old answer")
|
|
|
|
// Compaction messages:
|
|
// 1. Summary (role=user, visibility=model, compressed=true).
|
|
summary := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityModel, true, "compaction summary")
|
|
// 2. Compressed assistant tool-call (visibility=user).
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityUser, true, "tool call")
|
|
// 3. Compressed tool result (visibility=both).
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleTool, database.ChatMessageVisibilityBoth, true, "tool result")
|
|
|
|
// Post-compaction messages.
|
|
postUser := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "new question")
|
|
postAsst := insertMsg(t, chat.ID, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth, false, "new answer")
|
|
|
|
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
gotIDs := msgIDs(got)
|
|
|
|
// Must include: system prompt, summary, post-compaction.
|
|
require.Contains(t, gotIDs, sys.ID, "system prompt must be included")
|
|
require.Contains(t, gotIDs, summary.ID, "compaction summary must be included")
|
|
require.Contains(t, gotIDs, postUser.ID, "post-compaction user msg must be included")
|
|
require.Contains(t, gotIDs, postAsst.ID, "post-compaction assistant msg must be included")
|
|
|
|
// Must exclude: pre-compaction non-system messages.
|
|
require.NotContains(t, gotIDs, preUser.ID, "pre-compaction user msg must be excluded")
|
|
require.NotContains(t, gotIDs, preAsst.ID, "pre-compaction assistant msg must be excluded")
|
|
|
|
// Verify ordering.
|
|
require.Equal(t, []int64{sys.ID, summary.ID, postUser.ID, postAsst.ID}, gotIDs)
|
|
})
|
|
|
|
t.Run("AfterCompactionSummaryIsUserRole", func(t *testing.T) {
|
|
t.Parallel()
|
|
chat := newChat(t)
|
|
|
|
// After compaction the summary must appear as role=user so
|
|
// that LLM APIs (e.g. Anthropic) see at least one
|
|
// non-system message in the prompt.
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleSystem, database.ChatMessageVisibilityModel, false, "system prompt")
|
|
summary := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityModel, true, "summary text")
|
|
newUsr := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "new question")
|
|
|
|
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
hasNonSystem := false
|
|
for _, m := range got {
|
|
if m.Role != "system" {
|
|
hasNonSystem = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, hasNonSystem,
|
|
"prompt must contain at least one non-system message after compaction")
|
|
require.Contains(t, msgIDs(got), summary.ID)
|
|
require.Contains(t, msgIDs(got), newUsr.ID)
|
|
})
|
|
|
|
t.Run("CompressedToolResultNotPickedAsSummary", func(t *testing.T) {
|
|
t.Parallel()
|
|
chat := newChat(t)
|
|
|
|
// The CTE uses visibility='model' (exact match). If it
|
|
// used IN ('model','both'), the compressed tool result
|
|
// (visibility=both) would be picked as the "summary"
|
|
// instead of the actual summary.
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleSystem, database.ChatMessageVisibilityModel, false, "system prompt")
|
|
summary := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityModel, true, "real summary")
|
|
compressedTool := insertMsg(t, chat.ID, database.ChatMessageRoleTool, database.ChatMessageVisibilityBoth, true, "tool result")
|
|
postUser := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "follow-up")
|
|
|
|
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
gotIDs := msgIDs(got)
|
|
require.Contains(t, gotIDs, summary.ID, "real summary must be included")
|
|
require.NotContains(t, gotIDs, compressedTool.ID,
|
|
"compressed tool result must not be included")
|
|
require.Contains(t, gotIDs, postUser.ID)
|
|
})
|
|
}
|
|
|
|
func TestGetWorkspaceBuildMetricsByResourceID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("OK", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
tmpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
TemplateID: uuid.NullUUID{UUID: tmpl.ID, Valid: true},
|
|
CreatedBy: user.ID,
|
|
})
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
TemplateID: tmpl.ID,
|
|
OwnerID: user.ID,
|
|
AutomaticUpdates: database.AutomaticUpdatesNever,
|
|
})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: ws.ID,
|
|
TemplateVersionID: tv.ID,
|
|
JobID: job.ID,
|
|
InitiatorID: user.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
|
|
parentReadyAt := dbtime.Now()
|
|
parentStartedAt := parentReadyAt.Add(-time.Second)
|
|
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
StartedAt: sql.NullTime{Time: parentStartedAt, Valid: true},
|
|
ReadyAt: sql.NullTime{Time: parentReadyAt, Valid: true},
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
|
})
|
|
|
|
row, err := db.GetWorkspaceBuildMetricsByResourceID(ctx, resource.ID)
|
|
require.NoError(t, err)
|
|
require.True(t, row.AllAgentsReady)
|
|
require.True(t, parentReadyAt.Equal(row.LastAgentReadyAt))
|
|
require.Equal(t, "success", row.WorstStatus)
|
|
})
|
|
|
|
t.Run("SubAgentExcluded", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
tmpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
TemplateID: uuid.NullUUID{UUID: tmpl.ID, Valid: true},
|
|
CreatedBy: user.ID,
|
|
})
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
TemplateID: tmpl.ID,
|
|
OwnerID: user.ID,
|
|
AutomaticUpdates: database.AutomaticUpdatesNever,
|
|
})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: ws.ID,
|
|
TemplateVersionID: tv.ID,
|
|
JobID: job.ID,
|
|
InitiatorID: user.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
|
|
parentReadyAt := dbtime.Now()
|
|
parentStartedAt := parentReadyAt.Add(-time.Second)
|
|
parentAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
StartedAt: sql.NullTime{Time: parentStartedAt, Valid: true},
|
|
ReadyAt: sql.NullTime{Time: parentReadyAt, Valid: true},
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
|
})
|
|
|
|
// Sub-agent with ready_at 1 hour later should be excluded.
|
|
subAgentReadyAt := parentReadyAt.Add(time.Hour)
|
|
subAgentStartedAt := subAgentReadyAt.Add(-time.Second)
|
|
_ = dbgen.WorkspaceSubAgent(t, db, parentAgent, database.WorkspaceAgent{
|
|
StartedAt: sql.NullTime{Time: subAgentStartedAt, Valid: true},
|
|
ReadyAt: sql.NullTime{Time: subAgentReadyAt, Valid: true},
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
|
})
|
|
|
|
row, err := db.GetWorkspaceBuildMetricsByResourceID(ctx, resource.ID)
|
|
require.NoError(t, err)
|
|
require.True(t, row.AllAgentsReady)
|
|
// LastAgentReadyAt should be the parent's, not the sub-agent's.
|
|
require.True(t, parentReadyAt.Equal(row.LastAgentReadyAt))
|
|
require.Equal(t, "success", row.WorstStatus)
|
|
})
|
|
}
|
|
|
|
// TestUpsertAISeats verifies 'UpsertAISeatState' only returns true when a new
|
|
// row is inserted.
|
|
func TestUpsertAISeats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
now := dbtime.Now()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
newRow, err := db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
|
|
UserID: user.ID,
|
|
FirstUsedAt: now.Add(time.Hour * -24),
|
|
LastEventType: database.AiSeatUsageReasonTask,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, newRow)
|
|
|
|
alreadyExists, err := db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
|
|
UserID: user.ID,
|
|
FirstUsedAt: now.Add(time.Hour * -23),
|
|
LastEventType: database.AiSeatUsageReasonTask,
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, alreadyExists)
|
|
|
|
alreadyExists, err = db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
|
|
UserID: user.ID,
|
|
FirstUsedAt: now,
|
|
LastEventType: database.AiSeatUsageReasonTask,
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, alreadyExists)
|
|
}
|
|
|
|
func TestGetPRInsights(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
// setupChatInfra creates a fresh database with a user, chat provider,
|
|
// and model config. Returns the store, user ID, model config ID,
|
|
// and org ID.
|
|
setupChatInfra := func(t *testing.T) (database.Store, uuid.UUID, uuid.UUID, uuid.UUID) {
|
|
t.Helper()
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "anthropic",
|
|
DisplayName: "Anthropic",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
mc, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "anthropic",
|
|
Model: "claude-4",
|
|
DisplayName: "Claude 4",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
return store, user.ID, mc.ID, org.ID
|
|
}
|
|
|
|
type chatParams struct {
|
|
Store database.Store
|
|
UserID uuid.UUID
|
|
ModelConfigID uuid.UUID
|
|
OrgID uuid.UUID
|
|
}
|
|
|
|
createChat := func(t *testing.T, p chatParams, title string) database.Chat {
|
|
t.Helper()
|
|
chat, err := p.Store.InsertChat(context.Background(), database.InsertChatParams{
|
|
OrganizationID: p.OrgID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: p.UserID,
|
|
LastModelConfigID: p.ModelConfigID,
|
|
Title: title,
|
|
})
|
|
require.NoError(t, err)
|
|
return chat
|
|
}
|
|
|
|
// insertCostMessage inserts a single assistant message with the
|
|
// given total_cost_micros value.
|
|
insertCostMessage := func(t *testing.T, store database.Store, chatID, userID, mcID uuid.UUID, costMicros int64) {
|
|
t.Helper()
|
|
_, err := store.InsertChatMessages(context.Background(), database.InsertChatMessagesParams{
|
|
ChatID: chatID,
|
|
CreatedBy: []uuid.UUID{userID},
|
|
ModelConfigID: []uuid.UUID{mcID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
|
Content: []string{`[{"type":"text","text":"hello"}]`},
|
|
ContentVersion: []int16{1},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{costMicros},
|
|
RuntimeMs: []int64{0},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// linkPR associates a chat with a pull request via
|
|
// UpsertChatDiffStatus.
|
|
linkPR := func(t *testing.T, store database.Store, chatID uuid.UUID, prURL, state, title string, additions, deletions, changed int32) {
|
|
t.Helper()
|
|
now := time.Now()
|
|
_, err := store.UpsertChatDiffStatus(context.Background(), database.UpsertChatDiffStatusParams{
|
|
ChatID: chatID,
|
|
Url: sql.NullString{String: prURL, Valid: true},
|
|
PullRequestState: sql.NullString{String: state, Valid: true},
|
|
PullRequestTitle: title,
|
|
Additions: additions,
|
|
Deletions: deletions,
|
|
ChangedFiles: changed,
|
|
RefreshedAt: now,
|
|
StaleAt: now.Add(time.Hour),
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
startDate := time.Now().Add(-24 * time.Hour)
|
|
endDate := time.Now().Add(time.Hour)
|
|
noOwner := uuid.NullUUID{}
|
|
|
|
t.Run("MultipleChatsSamePR_CostSummed", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
chatA := createChat(t, p, "chat-A")
|
|
insertCostMessage(t, store, chatA.ID, userID, mcID, 5_000_000) // $5
|
|
|
|
chatB := createChat(t, p, "chat-B")
|
|
insertCostMessage(t, store, chatB.ID, userID, mcID, 3_000_000) // $3
|
|
|
|
prURL := "https://github.com/org/repo/pull/123"
|
|
linkPR(t, store, chatA.ID, prURL, "merged", "fix: something", 100, 20, 5)
|
|
linkPR(t, store, chatB.ID, prURL, "merged", "fix: something", 100, 20, 5)
|
|
|
|
// Both chats reference the same PR. The pr_costs CTE sums
|
|
// cost across all chats for the same PR URL, so the total
|
|
// should be $5 + $3 = $8. The PR itself is counted once.
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
assert.Equal(t, int64(8_000_000), recent[0].CostMicros)
|
|
})
|
|
|
|
t.Run("DifferentPRs_NoDuplication", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
chatA := createChat(t, p, "chat-A")
|
|
insertCostMessage(t, store, chatA.ID, userID, mcID, 5_000_000)
|
|
linkPR(t, store, chatA.ID, "https://github.com/org/repo/pull/1", "merged", "feat: A", 50, 10, 2)
|
|
|
|
chatB := createChat(t, p, "chat-B")
|
|
insertCostMessage(t, store, chatB.ID, userID, mcID, 3_000_000)
|
|
linkPR(t, store, chatB.ID, "https://github.com/org/repo/pull/2", "open", "feat: B", 80, 30, 4)
|
|
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros) // $5 + $3
|
|
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
|
|
|
// RecentPRs ordered by created_at DESC: chatB is newer.
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 2)
|
|
// Costs must not be mixed across different PRs.
|
|
assert.Equal(t, int64(3_000_000), recent[0].CostMicros) // PR 2 (newer)
|
|
assert.Equal(t, int64(5_000_000), recent[1].CostMicros) // PR 1 (older)
|
|
})
|
|
|
|
// createChildChat creates a chat with ParentChatID and RootChatID
|
|
// set, simulating a subagent/child chat in a tree.
|
|
createChildChat := func(t *testing.T, p chatParams, parentID, rootID uuid.UUID, title string) database.Chat {
|
|
t.Helper()
|
|
chat, err := p.Store.InsertChat(context.Background(), database.InsertChatParams{
|
|
OrganizationID: p.OrgID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: p.UserID,
|
|
LastModelConfigID: p.ModelConfigID,
|
|
Title: title,
|
|
ParentChatID: uuid.NullUUID{UUID: parentID, Valid: true},
|
|
RootChatID: uuid.NullUUID{UUID: rootID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
return chat
|
|
}
|
|
|
|
t.Run("DuplicatePRUrl_CountedOnce", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
prURL := "https://github.com/org/repo/pull/99"
|
|
for i := range 3 {
|
|
chat := createChat(t, p, fmt.Sprintf("chat-%d", i))
|
|
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
|
|
linkPR(t, store, chat.ID, prURL, "merged", "fix: same PR", 40, 10, 3)
|
|
}
|
|
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
})
|
|
|
|
t.Run("ChildChatCostsIncluded", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Parent chat with a $5 cost.
|
|
parent := createChat(t, p, "parent-chat")
|
|
insertCostMessage(t, store, parent.ID, userID, mcID, 5_000_000)
|
|
|
|
// Two child chats (subagents) with $2 each. Only the parent
|
|
// has a chat_diff_statuses entry, but the children's costs
|
|
// should be included via the tree join.
|
|
child1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
|
|
insertCostMessage(t, store, child1.ID, userID, mcID, 2_000_000)
|
|
|
|
child2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
|
|
insertCostMessage(t, store, child2.ID, userID, mcID, 2_000_000)
|
|
|
|
prURL := "https://github.com/org/repo/pull/42"
|
|
linkPR(t, store, parent.ID, prURL, "merged", "feat: tree cost", 60, 15, 3)
|
|
|
|
// Summary should reflect $5 + $2 + $2 = $9 total.
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
|
assert.Equal(t, int64(9_000_000), summary.TotalCostMicros)
|
|
|
|
// RecentPRs should return 1 row with the full tree cost.
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
assert.Equal(t, int64(9_000_000), recent[0].CostMicros)
|
|
})
|
|
|
|
t.Run("SiblingPRs_NoCrossContamination", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Parent chat with $10 orchestration cost.
|
|
parent := createChat(t, p, "parent")
|
|
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
|
|
|
|
// Child C1 ($5) creates PR1.
|
|
c1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
|
|
insertCostMessage(t, store, c1.ID, userID, mcID, 5_000_000)
|
|
linkPR(t, store, c1.ID, "https://github.com/org/repo/pull/10", "merged", "feat: PR1", 50, 10, 2)
|
|
|
|
// Child C2 ($3) creates PR2.
|
|
c2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
|
|
insertCostMessage(t, store, c2.ID, userID, mcID, 3_000_000)
|
|
linkPR(t, store, c2.ID, "https://github.com/org/repo/pull/11", "open", "feat: PR2", 30, 5, 1)
|
|
|
|
// With direct-branch attribution:
|
|
// PR1 cost = C1's own cost = $5 (parent NOT included — only children of C1)
|
|
// PR2 cost = C2's own cost = $3
|
|
// Total = $8 (no double-counting of parent or siblings)
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 2)
|
|
// PR2 (newer) = $3, PR1 (older) = $5.
|
|
assert.Equal(t, int64(3_000_000), recent[0].CostMicros)
|
|
assert.Equal(t, int64(5_000_000), recent[1].CostMicros)
|
|
})
|
|
|
|
t.Run("ParentAndChildDifferentPRs_NoCrossContamination", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Parent P ($10) creates PR1.
|
|
parent := createChat(t, p, "parent")
|
|
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
|
|
linkPR(t, store, parent.ID, "https://github.com/org/repo/pull/20", "merged", "feat: parent PR", 80, 20, 4)
|
|
|
|
// Child C1 ($5) has its own PR2. Because C1 has its own
|
|
// chat_diff_statuses entry, its cost should NOT be included
|
|
// under PR1 — it belongs to PR2 only.
|
|
c1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
|
|
insertCostMessage(t, store, c1.ID, userID, mcID, 5_000_000)
|
|
linkPR(t, store, c1.ID, "https://github.com/org/repo/pull/21", "open", "feat: child PR", 30, 5, 1)
|
|
|
|
// Child C2 ($2) has NO cds entry — pure subagent.
|
|
// Its cost should be included under PR1 (the parent's PR).
|
|
c2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
|
|
insertCostMessage(t, store, c2.ID, userID, mcID, 2_000_000)
|
|
|
|
// PR1 cost = parent ($10) + C2 ($2) = $12 (C1 excluded)
|
|
// PR2 cost = C1 ($5)
|
|
// Total = $17 (actual spend: $10 + $5 + $2 = $17)
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(17_000_000), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 2)
|
|
// PR2/C1 (newer) = $5, PR1/parent (older) = $12.
|
|
assert.Equal(t, int64(5_000_000), recent[0].CostMicros)
|
|
assert.Equal(t, int64(12_000_000), recent[1].CostMicros)
|
|
})
|
|
|
|
t.Run("EmptyURLNotCollapsed", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Two chats with empty-string URLs should be treated as
|
|
// separate PRs (NULLIF converts '' to NULL, falling back
|
|
// to c.id::text).
|
|
chatX := createChat(t, p, "chat-X")
|
|
insertCostMessage(t, store, chatX.ID, userID, mcID, 4_000_000)
|
|
linkPR(t, store, chatX.ID, "", "open", "draft: X", 10, 2, 1)
|
|
|
|
chatY := createChat(t, p, "chat-Y")
|
|
insertCostMessage(t, store, chatY.ID, userID, mcID, 6_000_000)
|
|
linkPR(t, store, chatY.ID, "", "merged", "draft: Y", 20, 5, 2)
|
|
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(10_000_000), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 2)
|
|
})
|
|
|
|
t.Run("ParentAndChildSameURL_DedupedWithCombinedCost", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Parent P ($10) links to a PR.
|
|
parent := createChat(t, p, "parent")
|
|
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
|
|
|
|
// Child C ($5) also links to the same PR URL.
|
|
child := createChildChat(t, p, parent.ID, parent.ID, "child")
|
|
insertCostMessage(t, store, child.ID, userID, mcID, 5_000_000)
|
|
|
|
prURL := "https://github.com/org/repo/pull/50"
|
|
linkPR(t, store, parent.ID, prURL, "merged", "feat: shared PR", 70, 15, 3)
|
|
linkPR(t, store, child.ID, prURL, "merged", "feat: shared PR", 70, 15, 3)
|
|
|
|
// Both parent and child have cds entries for the same URL.
|
|
// The PR should be counted once with combined cost $10 + $5 = $15.
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(15_000_000), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
assert.Equal(t, int64(15_000_000), recent[0].CostMicros)
|
|
})
|
|
|
|
t.Run("ZeroCostChat_StillCounted", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// A chat linked to a PR but with NO chat_messages at all.
|
|
// The PR should still appear with zero cost.
|
|
chat := createChat(t, p, "zero-cost-chat")
|
|
linkPR(t, store, chat.ID, "https://github.com/org/repo/pull/60", "open", "feat: no messages", 25, 5, 2)
|
|
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(0), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
assert.Equal(t, int64(0), recent[0].CostMicros)
|
|
})
|
|
|
|
t.Run("BlankDisplayNameFallsBackToModel", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, _, orgID := setupChatInfra(t)
|
|
|
|
const modelName = "claude-4.1"
|
|
emptyDisplayModel, err := store.InsertChatModelConfig(context.Background(), database.InsertChatModelConfigParams{
|
|
Provider: "anthropic",
|
|
Model: modelName,
|
|
DisplayName: "",
|
|
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: false,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: emptyDisplayModel.ID, OrgID: orgID}
|
|
chat := createChat(t, p, "chat-empty-display-name")
|
|
insertCostMessage(t, store, chat.ID, userID, emptyDisplayModel.ID, 1_000_000)
|
|
linkPR(t, store, chat.ID, "https://github.com/org/repo/pull/72", "merged", "fix: blank display name", 10, 2, 1)
|
|
|
|
byModel, err := store.GetPRInsightsPerModel(context.Background(), database.GetPRInsightsPerModelParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, byModel, 1)
|
|
assert.Equal(t, modelName, byModel[0].DisplayName)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
assert.Equal(t, modelName, recent[0].ModelDisplayName)
|
|
})
|
|
|
|
t.Run("MergedCostMicros_OnlyCountsMerged", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Merged PR with $5 cost.
|
|
chatMerged := createChat(t, p, "chat-merged")
|
|
insertCostMessage(t, store, chatMerged.ID, userID, mcID, 5_000_000)
|
|
linkPR(t, store, chatMerged.ID, "https://github.com/org/repo/pull/70", "merged", "fix: merged", 40, 10, 2)
|
|
|
|
// Open PR with $3 cost.
|
|
chatOpen := createChat(t, p, "chat-open")
|
|
insertCostMessage(t, store, chatOpen.ID, userID, mcID, 3_000_000)
|
|
linkPR(t, store, chatOpen.ID, "https://github.com/org/repo/pull/71", "open", "feat: open", 20, 5, 1)
|
|
|
|
// TotalCostMicros includes both ($5 + $3 = $8), but
|
|
// MergedCostMicros only includes the merged PR ($5).
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
|
assert.Equal(t, int64(5_000_000), summary.MergedCostMicros)
|
|
})
|
|
|
|
t.Run("AllPRsReturnedWithSafetyCap", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Create 25 distinct PRs — more than the old LIMIT 20 — and
|
|
// verify all are returned.
|
|
const prCount = 25
|
|
for i := range prCount {
|
|
chat := createChat(t, p, fmt.Sprintf("chat-%d", i))
|
|
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
|
|
linkPR(t, store, chat.ID,
|
|
fmt.Sprintf("https://github.com/org/repo/pull/%d", 100+i),
|
|
"merged", fmt.Sprintf("fix: pr-%d", i), 10, 2, 1)
|
|
}
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Len(t, recent, prCount, "all PRs within the date range should be returned")
|
|
})
|
|
}
|
|
|
|
func TestChatPinOrderQueries(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
setup := func(t *testing.T) (context.Context, database.Store, uuid.UUID, uuid.UUID, uuid.UUID) {
|
|
t.Helper()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
owner := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
|
|
|
// Use background context for fixture setup so the
|
|
// timed test context doesn't tick during DB init.
|
|
bg := context.Background()
|
|
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
return ctx, db, owner.ID, modelCfg.ID, org.ID
|
|
}
|
|
|
|
createChat := func(t *testing.T, ctx context.Context, db database.Store, ownerID, modelCfgID, orgID uuid.UUID, title string) database.Chat {
|
|
t.Helper()
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: orgID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: ownerID,
|
|
LastModelConfigID: modelCfgID,
|
|
Title: title,
|
|
})
|
|
require.NoError(t, err)
|
|
return chat
|
|
}
|
|
|
|
requirePinOrders := func(t *testing.T, ctx context.Context, db database.Store, want map[uuid.UUID]int32) {
|
|
t.Helper()
|
|
|
|
for chatID, wantPinOrder := range want {
|
|
chat, err := db.GetChatByID(ctx, chatID)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, wantPinOrder, chat.PinOrder)
|
|
}
|
|
}
|
|
|
|
t.Run("PinChatByIDAppendsWithinOwner", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, db, ownerID, modelCfgID, orgID := setup(t)
|
|
first := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "first")
|
|
second := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "second")
|
|
third := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "third")
|
|
|
|
otherOwner := dbgen.User(t, db, database.User{})
|
|
other := createChat(t, ctx, db, otherOwner.ID, modelCfgID, orgID, "other-owner")
|
|
|
|
require.NoError(t, db.PinChatByID(ctx, other.ID))
|
|
require.NoError(t, db.PinChatByID(ctx, first.ID))
|
|
require.NoError(t, db.PinChatByID(ctx, second.ID))
|
|
require.NoError(t, db.PinChatByID(ctx, third.ID))
|
|
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 1,
|
|
second.ID: 2,
|
|
third.ID: 3,
|
|
other.ID: 1,
|
|
})
|
|
})
|
|
|
|
t.Run("UpdateChatPinOrderShiftsNeighborsAndClamps", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, db, ownerID, modelCfgID, orgID := setup(t)
|
|
first := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "first")
|
|
second := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "second")
|
|
third := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "third")
|
|
|
|
for _, chat := range []database.Chat{first, second, third} {
|
|
require.NoError(t, db.PinChatByID(ctx, chat.ID))
|
|
}
|
|
|
|
require.NoError(t, db.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
|
ID: third.ID,
|
|
PinOrder: 1,
|
|
}))
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 2,
|
|
second.ID: 3,
|
|
third.ID: 1,
|
|
})
|
|
|
|
require.NoError(t, db.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
|
ID: third.ID,
|
|
PinOrder: 99,
|
|
}))
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 1,
|
|
second.ID: 2,
|
|
third.ID: 3,
|
|
})
|
|
})
|
|
|
|
t.Run("UnpinChatByIDCompactsPinnedChats", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, db, ownerID, modelCfgID, orgID := setup(t)
|
|
first := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "first")
|
|
second := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "second")
|
|
third := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "third")
|
|
|
|
for _, chat := range []database.Chat{first, second, third} {
|
|
require.NoError(t, db.PinChatByID(ctx, chat.ID))
|
|
}
|
|
|
|
require.NoError(t, db.UnpinChatByID(ctx, second.ID))
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 1,
|
|
second.ID: 0,
|
|
third.ID: 2,
|
|
})
|
|
})
|
|
|
|
t.Run("ArchiveClearsPinAndExcludesFromRanking", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, db, ownerID, modelCfgID, orgID := setup(t)
|
|
first := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "first")
|
|
second := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "second")
|
|
third := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "third")
|
|
|
|
for _, chat := range []database.Chat{first, second, third} {
|
|
require.NoError(t, db.PinChatByID(ctx, chat.ID))
|
|
}
|
|
|
|
// Archive the middle pin.
|
|
_, err := db.ArchiveChatByID(ctx, second.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Archived chat should have pin_order cleared. Remaining
|
|
// pins keep their original positions; the next mutation
|
|
// compacts via ROW_NUMBER().
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 1,
|
|
second.ID: 0,
|
|
third.ID: 3,
|
|
})
|
|
|
|
// Reorder among remaining active pins — archived chat
|
|
// should not interfere with position calculation.
|
|
require.NoError(t, db.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
|
ID: third.ID,
|
|
PinOrder: 1,
|
|
}))
|
|
// After reorder, ROW_NUMBER() compacts the sequence.
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 2,
|
|
second.ID: 0,
|
|
third.ID: 1,
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestChatPinOrderConstraints(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
owner := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
|
|
|
bg := context.Background()
|
|
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
t.Run("ChildChatCannotBePinned", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
parent, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusCompleted,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "parent",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
child, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusCompleted,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "child",
|
|
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
|
RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = db.PinChatByID(ctx, child.ID)
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckChatsPinOrderParentCheck))
|
|
})
|
|
|
|
t.Run("ArchivedChatCannotBePinned", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusCompleted,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "will be archived",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
err = db.PinChatByID(ctx, chat.ID)
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckChatsPinOrderArchivedCheck))
|
|
})
|
|
}
|
|
|
|
func TestChatLabels(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
owner := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
|
|
|
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
t.Run("CreateWithLabels", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
labels := database.StringMap{"github.repo": "coder/coder", "env": "prod"}
|
|
labelsJSON, err := json.Marshal(labels)
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "labeled-chat",
|
|
Labels: pqtype.NullRawMessage{
|
|
RawMessage: labelsJSON,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.StringMap{"github.repo": "coder/coder", "env": "prod"}, chat.Labels)
|
|
require.Equal(t, owner.Username, chat.OwnerUsername)
|
|
require.Equal(t, owner.Name, chat.OwnerName)
|
|
|
|
// Read back and verify.
|
|
fetched, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, chat.Labels, fetched.Labels)
|
|
require.Equal(t, owner.Username, fetched.OwnerUsername)
|
|
require.Equal(t, owner.Name, fetched.OwnerName)
|
|
})
|
|
|
|
t.Run("CreateWithoutLabels", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "no-labels-chat",
|
|
})
|
|
require.NoError(t, err)
|
|
// Default should be an empty map, not nil.
|
|
require.NotNil(t, chat.Labels)
|
|
require.Empty(t, chat.Labels)
|
|
})
|
|
|
|
t.Run("ListReturnsOwnerFields", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "owner-fields-chat-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetChats(ctx, database.GetChatsParams{OwnerID: owner.ID})
|
|
require.NoError(t, err)
|
|
|
|
chatIndex := slices.IndexFunc(rows, func(row database.GetChatsRow) bool {
|
|
return row.Chat.ID == chat.ID
|
|
})
|
|
require.NotEqual(t, -1, chatIndex, "chat not found in GetChats result")
|
|
require.Equal(t, owner.Username, rows[chatIndex].Chat.OwnerUsername)
|
|
require.Equal(t, owner.Name, rows[chatIndex].Chat.OwnerName)
|
|
})
|
|
|
|
t.Run("ChildrenReturnOwnerFields", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
parent, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "owner-fields-parent-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
child, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "owner-fields-child-" + uuid.NewString(),
|
|
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
|
RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetChildChatsByParentIDs(ctx, database.GetChildChatsByParentIDsParams{
|
|
ParentIds: []uuid.UUID{parent.ID},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 1)
|
|
require.Equal(t, child.ID, rows[0].Chat.ID)
|
|
require.Equal(t, owner.Username, rows[0].Chat.OwnerUsername)
|
|
require.Equal(t, owner.Name, rows[0].Chat.OwnerName)
|
|
})
|
|
|
|
t.Run("UpdateLabels", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "update-labels-chat",
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, chat.Labels)
|
|
|
|
// Set labels.
|
|
newLabels, err := json.Marshal(database.StringMap{"team": "backend"})
|
|
require.NoError(t, err)
|
|
updated, err := db.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{
|
|
ID: chat.ID,
|
|
Labels: newLabels,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.StringMap{"team": "backend"}, updated.Labels)
|
|
|
|
// Title should be unchanged.
|
|
require.Equal(t, "update-labels-chat", updated.Title)
|
|
|
|
// Clear labels by setting empty object.
|
|
emptyLabels, err := json.Marshal(database.StringMap{})
|
|
require.NoError(t, err)
|
|
cleared, err := db.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{
|
|
ID: chat.ID,
|
|
Labels: emptyLabels,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, cleared.Labels)
|
|
})
|
|
|
|
t.Run("UpdateTitleDoesNotAffectLabels", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
labels := database.StringMap{"pr": "1234"}
|
|
labelsJSON, err := json.Marshal(labels)
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "original-title",
|
|
Labels: pqtype.NullRawMessage{
|
|
RawMessage: labelsJSON,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Update title only — labels must survive.
|
|
updated, err := db.UpdateChatByID(ctx, database.UpdateChatByIDParams{
|
|
ID: chat.ID,
|
|
Title: "new-title",
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "new-title", updated.Title)
|
|
require.Equal(t, database.StringMap{"pr": "1234"}, updated.Labels)
|
|
require.Equal(t, owner.Username, updated.OwnerUsername)
|
|
require.Equal(t, owner.Name, updated.OwnerName)
|
|
})
|
|
|
|
t.Run("FilterByLabels", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// Create three chats with different labels.
|
|
for _, tc := range []struct {
|
|
title string
|
|
labels database.StringMap
|
|
}{
|
|
{"filter-a", database.StringMap{"env": "prod", "team": "backend"}},
|
|
{"filter-b", database.StringMap{"env": "prod", "team": "frontend"}},
|
|
{"filter-c", database.StringMap{"env": "staging"}},
|
|
} {
|
|
labelsJSON, err := json.Marshal(tc.labels)
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID, Title: tc.title,
|
|
Labels: pqtype.NullRawMessage{
|
|
RawMessage: labelsJSON,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Filter by env=prod — should match filter-a and filter-b.
|
|
filterJSON, err := json.Marshal(database.StringMap{"env": "prod"})
|
|
require.NoError(t, err)
|
|
results, err := db.GetChats(ctx, database.GetChatsParams{
|
|
OwnerID: owner.ID,
|
|
LabelFilter: pqtype.NullRawMessage{
|
|
RawMessage: filterJSON,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
titles := make([]string, 0, len(results))
|
|
for _, c := range results {
|
|
titles = append(titles, c.Chat.Title)
|
|
}
|
|
require.Contains(t, titles, "filter-a")
|
|
require.Contains(t, titles, "filter-b")
|
|
require.NotContains(t, titles, "filter-c")
|
|
|
|
// Filter by env=prod AND team=backend — should match only filter-a.
|
|
filterJSON, err = json.Marshal(database.StringMap{"env": "prod", "team": "backend"})
|
|
require.NoError(t, err)
|
|
results, err = db.GetChats(ctx, database.GetChatsParams{
|
|
OwnerID: owner.ID,
|
|
LabelFilter: pqtype.NullRawMessage{
|
|
RawMessage: filterJSON,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, results, 1)
|
|
require.Equal(t, "filter-a", results[0].Chat.Title)
|
|
// No filter — should return all chats for this owner.
|
|
allChats, err := db.GetChats(ctx, database.GetChatsParams{
|
|
OwnerID: owner.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.GreaterOrEqual(t, len(allChats), 3)
|
|
})
|
|
}
|
|
|
|
func TestUpdateChatLastTurnSummary(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
owner := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
|
|
|
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "summary-chat",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
affected, err := db.UpdateChatLastTurnSummary(ctx, database.UpdateChatLastTurnSummaryParams{
|
|
ID: chat.ID,
|
|
ExpectedUpdatedAt: chat.UpdatedAt,
|
|
LastTurnSummary: sql.NullString{String: "resolved the issue", Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, affected)
|
|
|
|
fetched, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, sql.NullString{String: "resolved the issue", Valid: true}, fetched.LastTurnSummary)
|
|
require.Equal(t, chat.UpdatedAt, fetched.UpdatedAt)
|
|
|
|
affected, err = db.UpdateChatLastTurnSummary(ctx, database.UpdateChatLastTurnSummaryParams{
|
|
ID: chat.ID,
|
|
ExpectedUpdatedAt: chat.UpdatedAt,
|
|
LastTurnSummary: sql.NullString{String: " \n\t ", Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, affected)
|
|
|
|
fetched, err = db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, fetched.LastTurnSummary.Valid)
|
|
require.Equal(t, chat.UpdatedAt, fetched.UpdatedAt)
|
|
|
|
affected, err = db.UpdateChatLastTurnSummary(ctx, database.UpdateChatLastTurnSummaryParams{
|
|
ID: chat.ID,
|
|
ExpectedUpdatedAt: chat.UpdatedAt,
|
|
LastTurnSummary: sql.NullString{String: "fresh summary", Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, affected)
|
|
|
|
advancedUpdatedAt := chat.UpdatedAt.Add(time.Second)
|
|
_, err = db.UpdateChatStatusPreserveUpdatedAt(ctx, database.UpdateChatStatusPreserveUpdatedAtParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
UpdatedAt: advancedUpdatedAt,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
affected, err = db.UpdateChatLastTurnSummary(ctx, database.UpdateChatLastTurnSummaryParams{
|
|
ID: chat.ID,
|
|
ExpectedUpdatedAt: chat.UpdatedAt,
|
|
LastTurnSummary: sql.NullString{String: "stale summary", Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Zero(t, affected)
|
|
|
|
fetched, err = db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, sql.NullString{String: "fresh summary", Valid: true}, fetched.LastTurnSummary)
|
|
require.Equal(t, advancedUpdatedAt, fetched.UpdatedAt)
|
|
}
|
|
|
|
func TestDeleteChatDebugDataAfterMessageIDIncludesTriggeredRuns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-rollback-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
const cutoff int64 = 50
|
|
|
|
affectedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff + 10, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 5, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: affectedRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
affectedByStepHistoryTipRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff - 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: affectedByStepHistoryTipRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "interrupted",
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff + 7, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// affectedByStepAssistantMsgRun: run-level fields are at/below
|
|
// the cutoff, but its step has assistant_message_id above the
|
|
// cutoff. This exercises the step.assistant_message_id > cutoff
|
|
// branch of the UNION independently of history_tip_message_id.
|
|
affectedByStepAssistantMsgRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff - 2, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 2, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: affectedByStepAssistantMsgRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
AssistantMessageID: sql.NullInt64{Int64: cutoff + 3, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
unaffectedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
unaffectedStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: unaffectedRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
AssistantMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
|
|
ChatID: chat.ID,
|
|
MessageID: cutoff,
|
|
StartedBefore: time.Now().Add(time.Minute),
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 3, deletedRows)
|
|
|
|
_, err = store.GetChatDebugRunByID(ctx, affectedRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
affectedSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedRun.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, affectedSteps)
|
|
|
|
_, err = store.GetChatDebugRunByID(ctx, affectedByStepHistoryTipRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
affectedByStepHistoryTipSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedByStepHistoryTipRun.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, affectedByStepHistoryTipSteps)
|
|
|
|
// Verify the run caught by step-level assistant_message_id is
|
|
// also deleted. This would survive if the
|
|
// step.assistant_message_id > @message_id clause were removed.
|
|
_, err = store.GetChatDebugRunByID(ctx, affectedByStepAssistantMsgRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
affectedByStepAssistantMsgSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedByStepAssistantMsgRun.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, affectedByStepAssistantMsgSteps)
|
|
|
|
remainingRuns, err := store.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
|
|
ChatID: chat.ID,
|
|
LimitVal: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, remainingRuns, 1)
|
|
require.Equal(t, unaffectedRun.ID, remainingRuns[0].ID)
|
|
|
|
remainingRun, err := store.GetChatDebugRunByID(ctx, unaffectedRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, unaffectedRun.ID, remainingRun.ID)
|
|
|
|
remainingSteps, err := store.GetChatDebugStepsByRunID(ctx, unaffectedRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, remainingSteps, 1)
|
|
require.Equal(t, unaffectedStep.ID, remainingSteps[0].ID)
|
|
}
|
|
|
|
// TestDeleteChatDebugDataAfterMessageIDStepLevelFieldBoundariesAndNulls
|
|
// verifies that DeleteChatDebugDataAfterMessageID handles step-level
|
|
// field boundaries and NULL combinations when run-level message IDs are
|
|
// below the cutoff. This complements the triggered-runs test with extra
|
|
// coverage for strict step-level comparisons and SQL NULL behavior.
|
|
func TestDeleteChatDebugDataAfterMessageIDStepLevelFieldBoundariesAndNulls(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-step-boundaries-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-step-boundaries-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
const cutoff int64 = 100
|
|
|
|
// insertRunBelowRunLevelCutoff creates a run whose run-level message
|
|
// IDs cannot match the deletion query. The step-level fields decide
|
|
// whether the run is deleted.
|
|
insertRunBelowRunLevelCutoff := func(t *testing.T) database.ChatDebugRun {
|
|
t.Helper()
|
|
run, runErr := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff - 10, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 10, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, runErr)
|
|
return run
|
|
}
|
|
|
|
// assistantAboveWithNullHistoryTipRun is deleted only through the
|
|
// step.assistant_message_id clause.
|
|
assistantAboveWithNullHistoryTipRun := insertRunBelowRunLevelCutoff(t)
|
|
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: assistantAboveWithNullHistoryTipRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
AssistantMessageID: sql.NullInt64{Int64: cutoff + 5, Valid: true},
|
|
// HistoryTipMessageID intentionally omitted (NULL).
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Add a nonmatching step to verify that one matching step is enough
|
|
// to delete the run and cascade all of its steps.
|
|
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: assistantAboveWithNullHistoryTipRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 2,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
AssistantMessageID: sql.NullInt64{Int64: cutoff - 5, Valid: true},
|
|
// HistoryTipMessageID intentionally omitted (NULL).
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// assistantAboveWithHistoryTipBelowRun is deleted through the
|
|
// step.assistant_message_id clause while the step history tip stays
|
|
// below the cutoff.
|
|
assistantAboveWithHistoryTipBelowRun := insertRunBelowRunLevelCutoff(t)
|
|
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: assistantAboveWithHistoryTipBelowRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
AssistantMessageID: sql.NullInt64{Int64: cutoff + 20, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 3, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// assistantBelowWithNullHistoryTipRun survives because its step
|
|
// assistant_message_id is below the cutoff and step history tip is
|
|
// NULL.
|
|
assistantBelowWithNullHistoryTipRun := insertRunBelowRunLevelCutoff(t)
|
|
assistantBelowWithNullHistoryTipStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: assistantBelowWithNullHistoryTipRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
AssistantMessageID: sql.NullInt64{Int64: cutoff - 3, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// assistantAtBoundaryWithNullHistoryTipRun survives because the
|
|
// query uses strict greater-than, not greater-than-or-equal.
|
|
assistantAtBoundaryWithNullHistoryTipRun := insertRunBelowRunLevelCutoff(t)
|
|
assistantAtBoundaryWithNullHistoryTipStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: assistantAtBoundaryWithNullHistoryTipRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
AssistantMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// historyTipAboveWithNullAssistantRun is deleted through the
|
|
// step.history_tip_message_id clause while assistant_message_id is
|
|
// NULL.
|
|
historyTipAboveWithNullAssistantRun := insertRunBelowRunLevelCutoff(t)
|
|
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: historyTipAboveWithNullAssistantRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff + 2, Valid: true},
|
|
// AssistantMessageID intentionally omitted (NULL).
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// historyTipAtBoundaryWithNullAssistantRun survives because the
|
|
// step history tip uses strict greater-than, not greater-than-or-equal.
|
|
historyTipAtBoundaryWithNullAssistantRun := insertRunBelowRunLevelCutoff(t)
|
|
historyTipAtBoundaryWithNullAssistantStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: historyTipAtBoundaryWithNullAssistantRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
|
// AssistantMessageID intentionally omitted (NULL).
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// bothStepMessageIDsNullRun survives because NULL > N evaluates to
|
|
// NULL, not TRUE, in SQL.
|
|
bothStepMessageIDsNullRun := insertRunBelowRunLevelCutoff(t)
|
|
bothStepMessageIDsNullStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: bothStepMessageIDsNullRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
// Both message IDs intentionally omitted (NULL).
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
|
|
ChatID: chat.ID,
|
|
MessageID: cutoff,
|
|
StartedBefore: time.Now().Add(time.Minute),
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 3, deletedRows)
|
|
|
|
_, err = store.GetChatDebugRunByID(ctx, assistantAboveWithNullHistoryTipRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"assistant above cutoff with NULL history tip must be deleted")
|
|
|
|
_, err = store.GetChatDebugRunByID(ctx, assistantAboveWithHistoryTipBelowRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"assistant above cutoff with history tip below cutoff must be deleted")
|
|
|
|
_, err = store.GetChatDebugRunByID(ctx, historyTipAboveWithNullAssistantRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"NULL assistant with history tip above cutoff must be deleted")
|
|
|
|
for _, deletedRun := range []struct {
|
|
name string
|
|
id uuid.UUID
|
|
}{
|
|
{name: "assistant above cutoff with NULL history tip", id: assistantAboveWithNullHistoryTipRun.ID},
|
|
{name: "assistant above cutoff with history tip below cutoff", id: assistantAboveWithHistoryTipBelowRun.ID},
|
|
{name: "NULL assistant with history tip above cutoff", id: historyTipAboveWithNullAssistantRun.ID},
|
|
} {
|
|
steps, stepsErr := store.GetChatDebugStepsByRunID(ctx, deletedRun.id)
|
|
require.NoError(t, stepsErr, "%s: get cascaded steps", deletedRun.name)
|
|
require.Empty(t, steps, "%s: deleted run steps must cascade", deletedRun.name)
|
|
}
|
|
|
|
remainingAssistantBelowRun, err := store.GetChatDebugRunByID(ctx, assistantBelowWithNullHistoryTipRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, assistantBelowWithNullHistoryTipRun.ID, remainingAssistantBelowRun.ID,
|
|
"assistant below cutoff with NULL history tip must survive")
|
|
|
|
remainingAssistantAtBoundaryRun, err := store.GetChatDebugRunByID(ctx, assistantAtBoundaryWithNullHistoryTipRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, assistantAtBoundaryWithNullHistoryTipRun.ID, remainingAssistantAtBoundaryRun.ID,
|
|
"assistant at cutoff boundary with NULL history tip must survive")
|
|
|
|
remainingHistoryTipAtBoundaryRun, err := store.GetChatDebugRunByID(ctx, historyTipAtBoundaryWithNullAssistantRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, historyTipAtBoundaryWithNullAssistantRun.ID, remainingHistoryTipAtBoundaryRun.ID,
|
|
"history tip at cutoff boundary with NULL assistant must survive")
|
|
|
|
remainingBothStepMessageIDsNullRun, err := store.GetChatDebugRunByID(ctx, bothStepMessageIDsNullRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, bothStepMessageIDsNullRun.ID, remainingBothStepMessageIDsNullRun.ID,
|
|
"both step message IDs NULL must survive")
|
|
|
|
assistantBelowSteps, err := store.GetChatDebugStepsByRunID(ctx, assistantBelowWithNullHistoryTipRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, assistantBelowSteps, 1)
|
|
require.Equal(t, assistantBelowWithNullHistoryTipStep.ID, assistantBelowSteps[0].ID)
|
|
|
|
assistantAtBoundarySteps, err := store.GetChatDebugStepsByRunID(ctx, assistantAtBoundaryWithNullHistoryTipRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, assistantAtBoundarySteps, 1)
|
|
require.Equal(t, assistantAtBoundaryWithNullHistoryTipStep.ID, assistantAtBoundarySteps[0].ID)
|
|
|
|
historyTipAtBoundarySteps, err := store.GetChatDebugStepsByRunID(ctx, historyTipAtBoundaryWithNullAssistantRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, historyTipAtBoundarySteps, 1)
|
|
require.Equal(t, historyTipAtBoundaryWithNullAssistantStep.ID, historyTipAtBoundarySteps[0].ID)
|
|
|
|
bothStepMessageIDsNullSteps, err := store.GetChatDebugStepsByRunID(ctx, bothStepMessageIDsNullRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, bothStepMessageIDsNullSteps, 1)
|
|
require.Equal(t, bothStepMessageIDsNullStep.ID, bothStepMessageIDsNullSteps[0].ID)
|
|
|
|
remaining, err := store.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
|
|
ChatID: chat.ID,
|
|
LimitVal: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, remaining, 4)
|
|
}
|
|
|
|
func TestFinalizeStaleChatDebugRows(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-finalize-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-finalize-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// staleTime is well before the threshold so rows stamped with it
|
|
// are considered stale. The threshold sits between staleTime and
|
|
// NOW(), letting us create rows that are stale-by-age and rows
|
|
// that are fresh-by-age in the same test.
|
|
staleTime := time.Now().Add(-2 * time.Hour)
|
|
staleThreshold := time.Now().Add(-1 * time.Hour)
|
|
|
|
// preExistingError is attached to staleStep so we can verify
|
|
// that finalization preserves pre-existing error JSON rather
|
|
// than clearing or overwriting it.
|
|
preExistingError := json.RawMessage(`{"code":"timeout","message":"upstream deadline exceeded"}`)
|
|
|
|
// --- staleRun: in_progress run with no finished_at --- should be
|
|
// finalized.
|
|
staleRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// staleStep: in_progress step attached to staleRun with a
|
|
// pre-existing error JSON payload.
|
|
staleStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: staleRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
|
Error: pqtype.NullRawMessage{
|
|
RawMessage: preExistingError,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, staleStep.Error.Valid,
|
|
"precondition: error must be stored at insertion")
|
|
|
|
// --- orphanStep: in_progress step whose run is already completed ---
|
|
// Its own updated_at is old, so it should be finalized directly.
|
|
// The step must be inserted while the run is still open because
|
|
// InsertChatDebugStep requires finished_at IS NULL on the parent
|
|
// run (atomic guard against appending steps to finalized runs).
|
|
completedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 2, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 2, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "completed",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert the step while the run is still open (finished_at IS NULL).
|
|
orphanStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: completedRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Now mark the run as completed with a finished_at timestamp,
|
|
// leaving the step orphaned in in_progress state.
|
|
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
|
ID: completedRun.ID,
|
|
ChatID: completedRun.ChatID,
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- cascadeRun: stale in_progress run with a FRESH step ---
|
|
// The run's updated_at is old so the run itself is finalized by
|
|
// age. The step's updated_at is recent (default NOW()), so it is
|
|
// NOT caught by the age predicate. It must be finalized solely
|
|
// via the cascade CTE clause: run_id IN (SELECT id FROM
|
|
// finalized_runs). Removing that clause would leave this step
|
|
// stuck in 'in_progress'.
|
|
cascadeRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 10, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 10, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// cascadeStep: recent updated_at (default NOW()), so only the
|
|
// cascade path can finalize it.
|
|
cascadeStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: cascadeRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// The InsertChatDebugStep CTE atomically bumps the parent run's
|
|
// updated_at to NOW(). Reset it back to staleTime so the run is
|
|
// still caught by the age predicate in FinalizeStaleChatDebugRows.
|
|
err = store.TouchChatDebugRunUpdatedAt(ctx, database.TouchChatDebugRunUpdatedAtParams{
|
|
ID: cascadeRun.ID,
|
|
ChatID: chat.ID,
|
|
Now: staleTime,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- alreadyDone: completed run/step --- should NOT be touched.
|
|
doneRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 3, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 3, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "completed",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert step while run is still open.
|
|
doneStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: doneRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Now finalize both run and step.
|
|
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
|
ID: doneRun.ID,
|
|
ChatID: doneRun.ChatID,
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
|
ID: doneStep.ID,
|
|
ChatID: chat.ID,
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- errorRun: error run/step --- should NOT be touched either,
|
|
// exercising the 'error' branch of the NOT IN clause.
|
|
errorRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 4, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 4, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "error",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert step while run is still open.
|
|
errorStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: errorRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "error",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Now finalize both run and step.
|
|
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
|
ID: errorRun.ID,
|
|
ChatID: errorRun.ChatID,
|
|
Status: sql.NullString{String: "error", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
|
ID: errorStep.ID,
|
|
ChatID: chat.ID,
|
|
Status: sql.NullString{String: "error", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- freshRun: recent in_progress run with current timestamp ---
|
|
// should NOT be finalized because its updated_at is after the
|
|
// threshold, exercising the age predicate (not just terminal
|
|
// status) as the survival reason.
|
|
freshRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 20, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 20, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
// UpdatedAt defaults to NOW(), which is after staleThreshold.
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
freshStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: freshRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
// UpdatedAt defaults to NOW().
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- Execute the finalization sweep. ---
|
|
// Capture the @now timestamp so we can verify finalized rows
|
|
// received exactly this value for updated_at and finished_at.
|
|
nowParam := time.Now().Truncate(time.Microsecond)
|
|
result, err := store.FinalizeStaleChatDebugRows(ctx, database.FinalizeStaleChatDebugRowsParams{
|
|
Now: nowParam,
|
|
UpdatedBefore: staleThreshold,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// staleRun + cascadeRun were finalized; completedRun and doneRun
|
|
// were already terminal, and freshRun survives because its
|
|
// updated_at is after the threshold — so only 2 runs are expected.
|
|
assert.EqualValues(t, 2, result.RunsFinalized,
|
|
"stale + cascade in_progress runs should be finalized")
|
|
// staleStep (age), orphanStep (age), cascadeStep (cascade only)
|
|
// should all be finalized.
|
|
assert.EqualValues(t, 3, result.StepsFinalized,
|
|
"stale step + orphan step + cascade step should all be finalized")
|
|
|
|
// Verify the stale run was set to interrupted with correct
|
|
// timestamps matching the @now parameter.
|
|
updatedStaleRun, err := store.GetChatDebugRunByID(ctx, staleRun.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "interrupted", updatedStaleRun.Status)
|
|
assert.True(t, updatedStaleRun.FinishedAt.Valid,
|
|
"finalized run should have a finished_at timestamp")
|
|
assert.WithinDuration(t, nowParam, updatedStaleRun.FinishedAt.Time, time.Microsecond,
|
|
"finished_at should match the @now parameter")
|
|
assert.WithinDuration(t, nowParam, updatedStaleRun.UpdatedAt, time.Microsecond,
|
|
"updated_at should match the @now parameter")
|
|
|
|
// Verify the stale step was set to interrupted and its
|
|
// pre-existing error JSON was preserved.
|
|
staleSteps, err := store.GetChatDebugStepsByRunID(ctx, staleRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, staleSteps, 1)
|
|
assert.Equal(t, staleStep.ID, staleSteps[0].ID)
|
|
assert.Equal(t, "interrupted", staleSteps[0].Status)
|
|
assert.True(t, staleSteps[0].FinishedAt.Valid,
|
|
"finalized step should have a finished_at timestamp")
|
|
assert.WithinDuration(t, nowParam, staleSteps[0].FinishedAt.Time, time.Microsecond,
|
|
"step finished_at should match the @now parameter")
|
|
assert.WithinDuration(t, nowParam, staleSteps[0].UpdatedAt, time.Microsecond,
|
|
"step updated_at should match the @now parameter")
|
|
// The error JSON that was set at insertion time must survive
|
|
// finalization. The query does not touch the error column, so
|
|
// this proves the JSONB payload is preserved.
|
|
assert.True(t, staleSteps[0].Error.Valid,
|
|
"pre-existing error JSON must be preserved after finalization")
|
|
assert.JSONEq(t, string(preExistingError), string(staleSteps[0].Error.RawMessage),
|
|
"error JSON content must match the value set at insertion")
|
|
|
|
// Verify the orphan step was also finalized with correct timestamps.
|
|
orphanSteps, err := store.GetChatDebugStepsByRunID(ctx, completedRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, orphanSteps, 1)
|
|
assert.Equal(t, orphanStep.ID, orphanSteps[0].ID)
|
|
assert.Equal(t, "interrupted", orphanSteps[0].Status)
|
|
assert.True(t, orphanSteps[0].FinishedAt.Valid,
|
|
"orphan step should have a finished_at timestamp")
|
|
assert.WithinDuration(t, nowParam, orphanSteps[0].FinishedAt.Time, time.Microsecond,
|
|
"orphan step finished_at should match the @now parameter")
|
|
assert.WithinDuration(t, nowParam, orphanSteps[0].UpdatedAt, time.Microsecond,
|
|
"orphan step updated_at should match the @now parameter")
|
|
// The orphan step had no error set; verify it remains null.
|
|
assert.False(t, orphanSteps[0].Error.Valid,
|
|
"step without pre-existing error should remain null after finalization")
|
|
|
|
// Verify the cascade run was finalized with correct timestamps.
|
|
updatedCascadeRun, err := store.GetChatDebugRunByID(ctx, cascadeRun.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "interrupted", updatedCascadeRun.Status)
|
|
assert.True(t, updatedCascadeRun.FinishedAt.Valid,
|
|
"cascade run should have a finished_at timestamp")
|
|
assert.WithinDuration(t, nowParam, updatedCascadeRun.FinishedAt.Time, time.Microsecond,
|
|
"cascade run finished_at should match the @now parameter")
|
|
assert.WithinDuration(t, nowParam, updatedCascadeRun.UpdatedAt, time.Microsecond,
|
|
"cascade run updated_at should match the @now parameter")
|
|
|
|
// Verify the cascade step was finalized despite its recent
|
|
// updated_at, proving the cascade CTE clause is required.
|
|
cascadeSteps, err := store.GetChatDebugStepsByRunID(ctx, cascadeRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, cascadeSteps, 1)
|
|
assert.Equal(t, cascadeStep.ID, cascadeSteps[0].ID)
|
|
assert.Equal(t, "interrupted", cascadeSteps[0].Status,
|
|
"fresh step should be finalized via cascade, not age")
|
|
assert.True(t, cascadeSteps[0].FinishedAt.Valid,
|
|
"cascade step should have a finished_at timestamp")
|
|
assert.WithinDuration(t, nowParam, cascadeSteps[0].FinishedAt.Time, time.Microsecond,
|
|
"cascade step finished_at should match the @now parameter")
|
|
assert.WithinDuration(t, nowParam, cascadeSteps[0].UpdatedAt, time.Microsecond,
|
|
"cascade step updated_at should match the @now parameter")
|
|
// The cascade step also had no error set.
|
|
assert.False(t, cascadeSteps[0].Error.Valid,
|
|
"cascade step without pre-existing error should remain null")
|
|
|
|
// Verify the completed run/step are untouched.
|
|
unchangedRun, err := store.GetChatDebugRunByID(ctx, doneRun.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "completed", unchangedRun.Status)
|
|
|
|
doneSteps, err := store.GetChatDebugStepsByRunID(ctx, doneRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, doneSteps, 1)
|
|
assert.Equal(t, "completed", doneSteps[0].Status)
|
|
|
|
// Verify the error run/step are untouched.
|
|
unchangedErrorRun, err := store.GetChatDebugRunByID(ctx, errorRun.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "error", unchangedErrorRun.Status)
|
|
|
|
errorSteps, err := store.GetChatDebugStepsByRunID(ctx, errorRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, errorSteps, 1)
|
|
assert.Equal(t, "error", errorSteps[0].Status)
|
|
|
|
// Verify the fresh in_progress run survived due to recency,
|
|
// not terminal status — its updated_at is after the threshold.
|
|
unchangedFreshRun, err := store.GetChatDebugRunByID(ctx, freshRun.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "in_progress", unchangedFreshRun.Status,
|
|
"fresh in_progress run must survive due to recency")
|
|
assert.False(t, unchangedFreshRun.FinishedAt.Valid,
|
|
"fresh run should not have a finished_at timestamp")
|
|
|
|
freshSteps, err := store.GetChatDebugStepsByRunID(ctx, freshRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, freshSteps, 1)
|
|
assert.Equal(t, freshStep.ID, freshSteps[0].ID)
|
|
assert.Equal(t, "in_progress", freshSteps[0].Status,
|
|
"fresh in_progress step must survive due to recency")
|
|
assert.False(t, freshSteps[0].FinishedAt.Valid,
|
|
"fresh step should not have a finished_at timestamp")
|
|
|
|
// A second sweep should be a no-op.
|
|
result2, err := store.FinalizeStaleChatDebugRows(ctx, database.FinalizeStaleChatDebugRowsParams{
|
|
Now: time.Now(),
|
|
UpdatedBefore: staleThreshold,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.EqualValues(t, 0, result2.RunsFinalized,
|
|
"second sweep should find nothing to finalize")
|
|
assert.EqualValues(t, 0, result2.StepsFinalized,
|
|
"second sweep should find nothing to finalize")
|
|
}
|
|
|
|
func TestChatDebugSQLGuards(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-guards-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatA, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-guard-A-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatB, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-guard-B-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
runA, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chatA.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
stepA, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: runA.ID,
|
|
ChatID: chatA.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// InsertChatDebugStep: valid run_id but chat_id belongs to a
|
|
// different chat. The INSERT...SELECT guard should produce zero
|
|
// rows, surfacing as sql.ErrNoRows.
|
|
t.Run("InsertChatDebugStep_MismatchedChatID", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
_, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: runA.ID,
|
|
ChatID: chatB.ID, // wrong chat
|
|
StepNumber: 2,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"InsertChatDebugStep should fail when chat_id does not match the run's chat_id")
|
|
})
|
|
|
|
// UpdateChatDebugRun: valid run ID but wrong chat_id.
|
|
t.Run("UpdateChatDebugRun_MismatchedChatID", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
_, err := store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
|
ID: runA.ID,
|
|
ChatID: chatB.ID, // wrong chat
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"UpdateChatDebugRun should fail when chat_id does not match")
|
|
})
|
|
|
|
// UpdateChatDebugStep: valid step ID but wrong chat_id.
|
|
t.Run("UpdateChatDebugStep_MismatchedChatID", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
_, err := store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
|
ID: stepA.ID,
|
|
ChatID: chatB.ID, // wrong chat
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"UpdateChatDebugStep should fail when chat_id does not match")
|
|
})
|
|
}
|
|
|
|
// TestChatDebugRunCOALESCEPreservation verifies that the COALESCE
|
|
// pattern in UpdateChatDebugRun preserves every field that was not
|
|
// explicitly supplied in the update. If COALESCE were removed from
|
|
// any column, the corresponding field would silently null out.
|
|
func TestChatDebugRunCOALESCEPreservation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-coalesce-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-coalesce-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rootChatID := uuid.New()
|
|
parentChatID := uuid.New()
|
|
|
|
// Insert a fully-populated run so every nullable field has a value.
|
|
original, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true},
|
|
ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true},
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 42, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 41, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
Summary: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"key":"val"}`), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Update only Status and FinishedAt. Every other nullable param
|
|
// is left as its Go zero value (Valid: false → SQL NULL), which
|
|
// the COALESCE pattern should interpret as "keep existing."
|
|
now := time.Now()
|
|
updated, err := store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
|
ID: original.ID,
|
|
ChatID: chat.ID,
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: now,
|
|
Valid: true,
|
|
},
|
|
Now: now,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Status and FinishedAt should be updated.
|
|
require.Equal(t, "completed", updated.Status)
|
|
require.True(t, updated.FinishedAt.Valid)
|
|
|
|
// UpdatedAt should be set to the @now value we passed in.
|
|
require.WithinDuration(t, now, updated.UpdatedAt, time.Millisecond,
|
|
"updated_at should equal the @now parameter")
|
|
|
|
// Every field not in the update call must be preserved exactly.
|
|
require.Equal(t, original.RootChatID, updated.RootChatID,
|
|
"RootChatID should survive a partial update")
|
|
require.Equal(t, original.ParentChatID, updated.ParentChatID,
|
|
"ParentChatID should survive a partial update")
|
|
require.Equal(t, original.ModelConfigID, updated.ModelConfigID,
|
|
"ModelConfigID should survive a partial update")
|
|
require.Equal(t, original.TriggerMessageID, updated.TriggerMessageID,
|
|
"TriggerMessageID should survive a partial update")
|
|
require.Equal(t, original.HistoryTipMessageID, updated.HistoryTipMessageID,
|
|
"HistoryTipMessageID should survive a partial update")
|
|
require.Equal(t, original.Provider, updated.Provider,
|
|
"Provider should survive a partial update")
|
|
require.Equal(t, original.Model, updated.Model,
|
|
"Model should survive a partial update")
|
|
require.JSONEq(t, string(original.Summary), string(updated.Summary),
|
|
"Summary should survive a partial update")
|
|
require.Equal(t, original.Kind, updated.Kind,
|
|
"Kind should survive a partial update")
|
|
require.Equal(t, original.StartedAt.UTC(), updated.StartedAt.UTC(),
|
|
"StartedAt should survive a partial update")
|
|
}
|
|
|
|
// TestChatDebugStepCOALESCEPreservation verifies that the COALESCE
|
|
// pattern in UpdateChatDebugStep preserves every field that was not
|
|
// explicitly supplied in the update. If COALESCE were removed from
|
|
// any column, the corresponding field would silently null out.
|
|
func TestChatDebugStepCOALESCEPreservation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-step-coalesce-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-step-coalesce-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
run, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert a fully-populated step so every nullable field has a value.
|
|
original, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: run.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "llm_call",
|
|
Status: "in_progress",
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 10, Valid: true},
|
|
AssistantMessageID: sql.NullInt64{Int64: 11, Valid: true},
|
|
NormalizedRequest: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"prompt":"hello"}`), Valid: true},
|
|
NormalizedResponse: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"text":"world"}`), Valid: true},
|
|
Usage: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"tokens":42}`), Valid: true},
|
|
Attempts: pqtype.NullRawMessage{RawMessage: json.RawMessage(`[{"n":1}]`), Valid: true},
|
|
Error: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"code":"transient"}`), Valid: true},
|
|
Metadata: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"trace_id":"abc"}`), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Update only Status and FinishedAt. Every other nullable param
|
|
// is left as its Go zero value (Valid: false -> SQL NULL), which
|
|
// the COALESCE pattern should interpret as "keep existing."
|
|
now := time.Now()
|
|
updated, err := store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
|
ID: original.ID,
|
|
ChatID: chat.ID,
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: now,
|
|
Valid: true,
|
|
},
|
|
Now: now,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Status and FinishedAt should be updated.
|
|
require.Equal(t, "completed", updated.Status)
|
|
require.True(t, updated.FinishedAt.Valid)
|
|
|
|
// UpdatedAt should be set to the @now value we passed in.
|
|
require.WithinDuration(t, now, updated.UpdatedAt, time.Millisecond,
|
|
"updated_at should equal the @now parameter")
|
|
|
|
// Every field not in the update call must be preserved exactly.
|
|
require.Equal(t, original.HistoryTipMessageID, updated.HistoryTipMessageID,
|
|
"HistoryTipMessageID should survive a partial update")
|
|
require.Equal(t, original.AssistantMessageID, updated.AssistantMessageID,
|
|
"AssistantMessageID should survive a partial update")
|
|
require.JSONEq(t, string(original.NormalizedRequest), string(updated.NormalizedRequest),
|
|
"NormalizedRequest should survive a partial update")
|
|
require.JSONEq(t, string(original.NormalizedResponse.RawMessage), string(updated.NormalizedResponse.RawMessage),
|
|
"NormalizedResponse should survive a partial update")
|
|
require.JSONEq(t, string(original.Usage.RawMessage), string(updated.Usage.RawMessage),
|
|
"Usage should survive a partial update")
|
|
require.JSONEq(t, string(original.Attempts), string(updated.Attempts),
|
|
"Attempts should survive a partial update")
|
|
require.JSONEq(t, string(original.Error.RawMessage), string(updated.Error.RawMessage),
|
|
"Error should survive a partial update")
|
|
require.JSONEq(t, string(original.Metadata), string(updated.Metadata),
|
|
"Metadata should survive a partial update")
|
|
require.Equal(t, original.Operation, updated.Operation,
|
|
"Operation should survive a partial update")
|
|
require.Equal(t, original.StepNumber, updated.StepNumber,
|
|
"StepNumber should survive a partial update")
|
|
require.Equal(t, original.StartedAt.UTC(), updated.StartedAt.UTC(),
|
|
"StartedAt should survive a partial update")
|
|
}
|
|
|
|
// TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive verifies
|
|
// that runs whose message ID columns are all NULL are never matched
|
|
// by DeleteChatDebugDataAfterMessageID. SQL's three-valued logic
|
|
// means NULL > N evaluates to NULL (not TRUE), so these rows must
|
|
// survive. Without this test a future change could break the
|
|
// invariant with no test failure.
|
|
func TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-null-msg-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-null-msg-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert a run with all message ID columns left as NULL (Valid: false).
|
|
nullMsgRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
// TriggerMessageID and HistoryTipMessageID intentionally
|
|
// omitted (zero-value → SQL NULL).
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Attach a step with NULL message IDs too.
|
|
nullMsgStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: nullMsgRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
// HistoryTipMessageID and AssistantMessageID intentionally
|
|
// omitted (zero-value → SQL NULL).
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Delete with an arbitrary cutoff. The run and its step should
|
|
// survive because NULL > cutoff evaluates to NULL, not TRUE.
|
|
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
|
|
ChatID: chat.ID,
|
|
MessageID: 1,
|
|
StartedBefore: time.Now().Add(time.Minute),
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 0, deletedRows, "rows with NULL message IDs must not be deleted")
|
|
|
|
// Verify run still exists.
|
|
remaining, err := store.GetChatDebugRunByID(ctx, nullMsgRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, nullMsgRun.ID, remaining.ID)
|
|
|
|
// Verify step still exists.
|
|
remainingSteps, err := store.GetChatDebugStepsByRunID(ctx, nullMsgRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, remainingSteps, 1)
|
|
require.Equal(t, nullMsgStep.ID, remainingSteps[0].ID)
|
|
}
|
|
|
|
// TestDeleteChatDebugDataAfterMessageIDStartedBeforeFiltersNewerRuns
|
|
// verifies the started_before bound on DeleteChatDebugDataAfterMessageID.
|
|
// The bound exists so that retried cleanup (e.g. after edit or archive)
|
|
// cannot delete runs started by a replacement turn that races ahead of
|
|
// the retry window. Without this filter, a stale cleanup would wipe
|
|
// fresh debug rows.
|
|
func TestDeleteChatDebugDataAfterMessageIDStartedBeforeFiltersNewerRuns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-started-before-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-started-before-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
const cutoff int64 = 50
|
|
|
|
// oldRun started an hour ago: must be deleted because it started
|
|
// before the bound.
|
|
oldStartedAt := time.Now().Add(-1 * time.Hour).UTC().
|
|
Truncate(time.Microsecond)
|
|
oldRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff + 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff + 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
StartedAt: sql.NullTime{Time: oldStartedAt, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: oldStartedAt, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Bound sits between the two runs. Any run whose started_at is at
|
|
// or after this instant must survive.
|
|
cutoffTime := time.Now().Add(-30 * time.Minute).UTC().
|
|
Truncate(time.Microsecond)
|
|
|
|
// newRun started after cutoffTime with identical message_id values
|
|
// that would otherwise match the delete predicate. It must survive
|
|
// because started_before excludes it.
|
|
newStartedAt := time.Now().UTC().Truncate(time.Microsecond)
|
|
newRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff + 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff + 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
StartedAt: sql.NullTime{Time: newStartedAt, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: newStartedAt, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
|
|
ChatID: chat.ID,
|
|
MessageID: cutoff,
|
|
StartedBefore: cutoffTime,
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, deletedRows,
|
|
"only the pre-cutoff run should be deleted")
|
|
|
|
// oldRun must be gone.
|
|
_, err = store.GetChatDebugRunByID(ctx, oldRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
// newRun must survive the retry window.
|
|
remaining, err := store.GetChatDebugRunByID(ctx, newRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, newRun.ID, remaining.ID)
|
|
}
|
|
|
|
// TestDeleteChatDebugDataByChatIDStartedBeforeFiltersNewerRuns verifies
|
|
// the started_before bound on DeleteChatDebugDataByChatID. Archive
|
|
// cleanup retries rely on this bound to avoid deleting runs created
|
|
// by a replacement turn that starts after an unarchive races ahead of
|
|
// the retry window.
|
|
func TestDeleteChatDebugDataByChatIDStartedBeforeFiltersNewerRuns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-by-chat-started-before-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-by-chat-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
oldStartedAt := time.Now().Add(-1 * time.Hour).UTC().
|
|
Truncate(time.Microsecond)
|
|
oldRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
StartedAt: sql.NullTime{Time: oldStartedAt, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: oldStartedAt, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
cutoffTime := time.Now().Add(-30 * time.Minute).UTC().
|
|
Truncate(time.Microsecond)
|
|
|
|
newStartedAt := time.Now().UTC().Truncate(time.Microsecond)
|
|
newRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
StartedAt: sql.NullTime{Time: newStartedAt, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: newStartedAt, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
deletedRows, err := store.DeleteChatDebugDataByChatID(ctx, database.DeleteChatDebugDataByChatIDParams{
|
|
ChatID: chat.ID,
|
|
StartedBefore: cutoffTime,
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, deletedRows,
|
|
"only the pre-cutoff run should be deleted")
|
|
|
|
_, err = store.GetChatDebugRunByID(ctx, oldRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
remaining, err := store.GetChatDebugRunByID(ctx, newRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, newRun.ID, remaining.ID)
|
|
}
|
|
|
|
func TestChatHasUnread(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model-" + uuid.NewString(),
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "test-chat-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
getHasUnread := func() bool {
|
|
rows, err := store.GetChats(ctx, database.GetChatsParams{
|
|
OwnerID: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
for _, row := range rows {
|
|
if row.Chat.ID == chat.ID {
|
|
return row.HasUnread
|
|
}
|
|
}
|
|
t.Fatal("chat not found in GetChats result")
|
|
return false
|
|
}
|
|
|
|
// New chat with no messages: not unread.
|
|
require.False(t, getHasUnread(), "new chat with no messages should not be unread")
|
|
|
|
// Helper to insert a single chat message.
|
|
insertMsg := func(role database.ChatMessageRole, text string) {
|
|
t.Helper()
|
|
_, err := store.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chat.ID,
|
|
CreatedBy: []uuid.UUID{user.ID},
|
|
ModelConfigID: []uuid.UUID{modelCfg.ID},
|
|
Role: []database.ChatMessageRole{role},
|
|
Content: []string{fmt.Sprintf(`[{"type":"text","text":%q}]`, text)},
|
|
ContentVersion: []int16{0},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
ProviderResponseID: []string{""},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Insert an assistant message: becomes unread.
|
|
insertMsg(database.ChatMessageRoleAssistant, "hello")
|
|
require.True(t, getHasUnread(), "chat with unread assistant message should be unread")
|
|
|
|
// Mark as read: no longer unread.
|
|
lastMsg, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
|
ChatID: chat.ID,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
})
|
|
require.NoError(t, err)
|
|
err = store.UpdateChatLastReadMessageID(ctx, database.UpdateChatLastReadMessageIDParams{
|
|
ID: chat.ID,
|
|
LastReadMessageID: lastMsg.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, getHasUnread(), "chat should not be unread after marking as read")
|
|
|
|
// Insert another assistant message: becomes unread again.
|
|
insertMsg(database.ChatMessageRoleAssistant, "new message")
|
|
require.True(t, getHasUnread(), "new assistant message after read should be unread")
|
|
|
|
// Mark as read again, then verify user messages don't
|
|
// trigger unread.
|
|
lastMsg, err = store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
|
ChatID: chat.ID,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
})
|
|
require.NoError(t, err)
|
|
err = store.UpdateChatLastReadMessageID(ctx, database.UpdateChatLastReadMessageIDParams{
|
|
ID: chat.ID,
|
|
LastReadMessageID: lastMsg.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
insertMsg(database.ChatMessageRoleUser, "user msg")
|
|
require.False(t, getHasUnread(), "user messages should not trigger unread")
|
|
}
|