mirror of
https://github.com/coder/coder.git
synced 2026-06-03 04:58:23 +00:00
0bfb9f6f13
Persists the agent-generated turn-end summary on `chats` and shows it as the Agents sidebar subtitle when present, falling back to the model name. Errors still take precedence. > Mux is acting on Mike's behalf. ## What changes **Storage.** New nullable `last_turn_summary` column on `chats` (migration `000486`). New `UpdateChatLastTurnSummary` query normalizes blank/whitespace input to `NULL`, preserves `updated_at` (so the chat does not jump to the top of the sidebar on summary writes), and uses an `expected_updated_at` stale-write guard so an older async summary cannot overwrite a newer turn. **Backend.** `coderd/x/chatd/chatd.go` decouples summary generation from webpush. Generated summaries persist for completed parent turns even when webpush is unconfigured or has no subscriptions. The same generated text is reused as the webpush body when webpush is configured, so the summary model is not called twice. Generic fallback push text is no longer persisted; it clears any stale summary instead. Error/interrupt/pending-action terminal paths clear `last_turn_summary` for the latest turn. **Frontend.** `AgentsSidebar.tsx` subtitle priority is now `errorReason || lastTurnSummary || modelName`, normalized via the existing `asNonEmptyString` helper from `blockUtils.ts`. ## Tests - `TestUpdateChatLastTurnSummary` (database): success, whitespace-to-NULL, stale guard rejects, `updated_at` preserved. - `TestUpdateLastTurnSummaryRejectsStaleWrites` (chatd internal): direct stale-`expected_updated_at` test. - `TestSuccessfulChatPersistsTurnSummaryWithoutWebPush`: persistence works without webpush subscriptions. - `TestSuccessfulChatSendsWebPushWithSummary`: same generated text drives both DB and push body. - `TestSuccessfulChatSendsWebPushFallbackWithoutSummaryForEmptyAssistantText`: fallback text is not persisted. - `TestErroredChatClearsLastTurnSummaryAndSendsWebPush`: error path clears the field. 
- `TestInterruptChatDoesNotSendWebPushNotification`: interrupt path clears the field, no push fires. - `AgentsSidebar.test.tsx`: subtitle priority for summary-present, error-wins, no-summary fallback, whitespace fallback. - `AgentsSidebar.stories.tsx`: `ChatWithTurnSummary` and `ChatWithTurnSummaryAndError`. ## Notes - No backfill. Existing chats keep showing the model name until their next turn completes. - Parent chats only in this iteration; the field is rendered on any `Chat` if a future change extends generation to children. - Decoupling generation from webpush adds quickgen model calls for completed parent turns that previously skipped generation when no subscriptions existed. Existing parent-only, assistant-text-present, `PushSummaryModel` configured, and bounded-timeout gates keep this behavior bounded.
12928 lines
429 KiB
Go
12928 lines
429 KiB
Go
package database_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"slices"
|
|
"sort"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/lib/pq"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/database/dbfake"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
|
"github.com/coder/coder/v2/coderd/database/migrations"
|
|
"github.com/coder/coder/v2/coderd/httpmw"
|
|
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
|
"github.com/coder/coder/v2/coderd/rbac"
|
|
"github.com/coder/coder/v2/coderd/rbac/policy"
|
|
"github.com/coder/coder/v2/coderd/util/slice"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/provisionersdk"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestGetDeploymentWorkspaceAgentStats(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
t.Run("Aggregates", func(t *testing.T) {
|
|
t.Parallel()
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 2,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
stats, err := db.GetDeploymentWorkspaceAgentStats(ctx, dbtime.Now().Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, int64(2), stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(2), stats.WorkspaceRxBytes)
|
|
require.Equal(t, 1.5, stats.WorkspaceConnectionLatency50)
|
|
require.Equal(t, 1.95, stats.WorkspaceConnectionLatency95)
|
|
require.Equal(t, int64(2), stats.SessionCountVSCode)
|
|
})
|
|
|
|
t.Run("GroupsByAgentID", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
agentID := uuid.New()
|
|
insertTime := dbtime.Now()
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Second),
|
|
AgentID: agentID,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
// Ensure this stat is newer!
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 2,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
stats, err := db.GetDeploymentWorkspaceAgentStats(ctx, dbtime.Now().Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, int64(2), stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(2), stats.WorkspaceRxBytes)
|
|
require.Equal(t, 1.5, stats.WorkspaceConnectionLatency50)
|
|
require.Equal(t, 1.95, stats.WorkspaceConnectionLatency95)
|
|
require.Equal(t, int64(1), stats.SessionCountVSCode)
|
|
})
|
|
}
|
|
|
|
func TestGetDeploymentWorkspaceAgentUsageStats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Aggregates", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
db = dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
ctx := context.Background()
|
|
agentID := uuid.New()
|
|
// Since the queries exclude the current minute
|
|
insertTime := dbtime.Now().Add(-time.Minute)
|
|
|
|
// Old stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agentID,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountSSH: 4,
|
|
SessionCountVSCode: 3,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agentID,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agentID,
|
|
SessionCountReconnectingPTY: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
// Latest stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 2,
|
|
// Should be ignored
|
|
SessionCountSSH: 3,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
SessionCountSSH: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
stats, err := db.GetDeploymentWorkspaceAgentUsageStats(ctx, dbtime.Now().Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, int64(2), stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(2), stats.WorkspaceRxBytes)
|
|
require.Equal(t, 1.5, stats.WorkspaceConnectionLatency50)
|
|
require.Equal(t, 1.95, stats.WorkspaceConnectionLatency95)
|
|
require.Equal(t, int64(1), stats.SessionCountVSCode)
|
|
require.Equal(t, int64(1), stats.SessionCountSSH)
|
|
require.Equal(t, int64(0), stats.SessionCountReconnectingPTY)
|
|
require.Equal(t, int64(0), stats.SessionCountJetBrains)
|
|
})
|
|
|
|
t.Run("NoUsage", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
db = dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
ctx := context.Background()
|
|
agentID := uuid.New()
|
|
// Since the queries exclude the current minute
|
|
insertTime := dbtime.Now().Add(-time.Minute)
|
|
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
TxBytes: 3,
|
|
RxBytes: 4,
|
|
ConnectionMedianLatencyMS: 2,
|
|
// Should be ignored
|
|
SessionCountSSH: 3,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
|
|
stats, err := db.GetDeploymentWorkspaceAgentUsageStats(ctx, dbtime.Now().Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, int64(3), stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(4), stats.WorkspaceRxBytes)
|
|
require.Equal(t, int64(0), stats.SessionCountVSCode)
|
|
require.Equal(t, int64(0), stats.SessionCountSSH)
|
|
require.Equal(t, int64(0), stats.SessionCountReconnectingPTY)
|
|
require.Equal(t, int64(0), stats.SessionCountJetBrains)
|
|
})
|
|
}
|
|
|
|
func TestGetEligibleProvisionerDaemonsByProvisionerJobIDs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("NoJobsReturnsEmpty", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
daemons, err := db.GetEligibleProvisionerDaemonsByProvisionerJobIDs(context.Background(), []uuid.UUID{})
|
|
require.NoError(t, err)
|
|
require.Empty(t, daemons)
|
|
})
|
|
|
|
t.Run("MatchesProvisionerType", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
matchingDaemon := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "matching-daemon",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "non-matching-daemon",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeTerraform},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetEligibleProvisionerDaemonsByProvisionerJobIDs(context.Background(), []uuid.UUID{job.ID})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 1)
|
|
require.Equal(t, matchingDaemon.ID, daemons[0].ProvisionerDaemon.ID)
|
|
})
|
|
|
|
t.Run("MatchesOrganizationScope", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
provisionersdk.TagOwner: "",
|
|
},
|
|
})
|
|
|
|
orgDaemon := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "org-daemon",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
provisionersdk.TagOwner: "",
|
|
},
|
|
})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "user-daemon",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeUser,
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetEligibleProvisionerDaemonsByProvisionerJobIDs(context.Background(), []uuid.UUID{job.ID})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 1)
|
|
require.Equal(t, orgDaemon.ID, daemons[0].ProvisionerDaemon.ID)
|
|
})
|
|
|
|
t.Run("MatchesMultipleProvisioners", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
daemon1 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "daemon-1",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
daemon2 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "daemon-2",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "daemon-3",
|
|
OrganizationID: org.ID,
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeTerraform},
|
|
Tags: database.StringMap{
|
|
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetEligibleProvisionerDaemonsByProvisionerJobIDs(context.Background(), []uuid.UUID{job.ID})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 2)
|
|
|
|
daemonIDs := []uuid.UUID{daemons[0].ProvisionerDaemon.ID, daemons[1].ProvisionerDaemon.ID}
|
|
require.ElementsMatch(t, []uuid.UUID{daemon1.ID, daemon2.ID}, daemonIDs)
|
|
})
|
|
}
|
|
|
|
func TestGetProvisionerDaemonsWithStatusByOrganization(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("NoDaemonsInOrgReturnsEmpty", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
otherOrg := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "non-matching-daemon",
|
|
OrganizationID: otherOrg.ID,
|
|
})
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, daemons)
|
|
})
|
|
|
|
t.Run("MatchesProvisionerIDs", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
matchingDaemon0 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "matching-daemon0",
|
|
OrganizationID: org.ID,
|
|
})
|
|
matchingDaemon1 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "matching-daemon1",
|
|
OrganizationID: org.ID,
|
|
})
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "non-matching-daemon",
|
|
OrganizationID: org.ID,
|
|
})
|
|
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
IDs: []uuid.UUID{matchingDaemon0.ID, matchingDaemon1.ID},
|
|
Offline: sql.NullBool{Bool: true, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 2)
|
|
if daemons[0].ProvisionerDaemon.ID != matchingDaemon0.ID {
|
|
daemons[0], daemons[1] = daemons[1], daemons[0]
|
|
}
|
|
require.Equal(t, matchingDaemon0.ID, daemons[0].ProvisionerDaemon.ID)
|
|
require.Equal(t, matchingDaemon1.ID, daemons[1].ProvisionerDaemon.ID)
|
|
})
|
|
|
|
t.Run("MatchesTags", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
fooDaemon := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "foo-daemon",
|
|
OrganizationID: org.ID,
|
|
Tags: database.StringMap{
|
|
"foo": "bar",
|
|
},
|
|
})
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "baz-daemon",
|
|
OrganizationID: org.ID,
|
|
Tags: database.StringMap{
|
|
"baz": "qux",
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
Tags: database.StringMap{"foo": "bar"},
|
|
Offline: sql.NullBool{Bool: true, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 1)
|
|
require.Equal(t, fooDaemon.ID, daemons[0].ProvisionerDaemon.ID)
|
|
})
|
|
|
|
t.Run("UsesStaleInterval", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
daemon1 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "stale-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-time.Hour),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-time.Hour),
|
|
},
|
|
})
|
|
daemon2 := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "idle-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(30 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(30 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
StaleIntervalMS: 45 * time.Minute.Milliseconds(),
|
|
Offline: sql.NullBool{Bool: true, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 2)
|
|
|
|
if daemons[0].ProvisionerDaemon.ID != daemon1.ID {
|
|
daemons[0], daemons[1] = daemons[1], daemons[0]
|
|
}
|
|
require.Equal(t, daemon1.ID, daemons[0].ProvisionerDaemon.ID)
|
|
require.Equal(t, daemon2.ID, daemons[1].ProvisionerDaemon.ID)
|
|
require.Equal(t, database.ProvisionerDaemonStatusOffline, daemons[0].Status)
|
|
require.Equal(t, database.ProvisionerDaemonStatusIdle, daemons[1].Status)
|
|
})
|
|
|
|
t.Run("ExcludeOffline", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "offline-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-time.Hour),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-time.Hour),
|
|
},
|
|
})
|
|
fooDaemon := dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "foo-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(30 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(30 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
StaleIntervalMS: 45 * time.Minute.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 1)
|
|
|
|
require.Equal(t, fooDaemon.ID, daemons[0].ProvisionerDaemon.ID)
|
|
require.Equal(t, database.ProvisionerDaemonStatusIdle, daemons[0].Status)
|
|
})
|
|
|
|
t.Run("IncludeOffline", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "offline-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-time.Hour),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-time.Hour),
|
|
},
|
|
})
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "foo-daemon",
|
|
OrganizationID: org.ID,
|
|
Tags: database.StringMap{
|
|
"foo": "bar",
|
|
},
|
|
})
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "bar-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(30 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(30 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
StaleIntervalMS: 45 * time.Minute.Milliseconds(),
|
|
Offline: sql.NullBool{Bool: true, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, 3)
|
|
|
|
statusCounts := make(map[database.ProvisionerDaemonStatus]int)
|
|
for _, daemon := range daemons {
|
|
statusCounts[daemon.Status]++
|
|
}
|
|
|
|
require.Equal(t, 2, statusCounts[database.ProvisionerDaemonStatusIdle])
|
|
require.Equal(t, 1, statusCounts[database.ProvisionerDaemonStatusOffline])
|
|
})
|
|
|
|
t.Run("MatchesStatuses", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "offline-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-time.Hour),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-time.Hour),
|
|
},
|
|
})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "foo-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(30 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(30 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
type testCase struct {
|
|
name string
|
|
statuses []database.ProvisionerDaemonStatus
|
|
expectedNum int
|
|
}
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "Get idle and offline",
|
|
statuses: []database.ProvisionerDaemonStatus{
|
|
database.ProvisionerDaemonStatusOffline,
|
|
database.ProvisionerDaemonStatusIdle,
|
|
},
|
|
expectedNum: 2,
|
|
},
|
|
{
|
|
name: "Get offline",
|
|
statuses: []database.ProvisionerDaemonStatus{
|
|
database.ProvisionerDaemonStatusOffline,
|
|
},
|
|
expectedNum: 1,
|
|
},
|
|
// Offline daemons should not be included without Offline param
|
|
{
|
|
name: "Get idle - empty statuses",
|
|
statuses: []database.ProvisionerDaemonStatus{},
|
|
expectedNum: 1,
|
|
},
|
|
{
|
|
name: "Get idle - nil statuses",
|
|
statuses: nil,
|
|
expectedNum: 1,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
//nolint:tparallel,paralleltest
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
StaleIntervalMS: 45 * time.Minute.Milliseconds(),
|
|
Statuses: tc.statuses,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, tc.expectedNum)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("FilterByMaxAge", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "foo-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(45 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(45 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "bar-daemon",
|
|
OrganizationID: org.ID,
|
|
CreatedAt: dbtime.Now().Add(-(25 * time.Minute)),
|
|
LastSeenAt: sql.NullTime{
|
|
Valid: true,
|
|
Time: dbtime.Now().Add(-(25 * time.Minute)),
|
|
},
|
|
})
|
|
|
|
type testCase struct {
|
|
name string
|
|
maxAge sql.NullInt64
|
|
expectedNum int
|
|
}
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "Max age 1 hour",
|
|
maxAge: sql.NullInt64{Int64: time.Hour.Milliseconds(), Valid: true},
|
|
expectedNum: 2,
|
|
},
|
|
{
|
|
name: "Max age 30 minutes",
|
|
maxAge: sql.NullInt64{Int64: (30 * time.Minute).Milliseconds(), Valid: true},
|
|
expectedNum: 1,
|
|
},
|
|
{
|
|
name: "Max age 15 minutes",
|
|
maxAge: sql.NullInt64{Int64: (15 * time.Minute).Milliseconds(), Valid: true},
|
|
expectedNum: 0,
|
|
},
|
|
{
|
|
name: "No max age",
|
|
maxAge: sql.NullInt64{Valid: false},
|
|
expectedNum: 2,
|
|
},
|
|
}
|
|
for _, tc := range tests {
|
|
//nolint:tparallel,paralleltest
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
daemons, err := db.GetProvisionerDaemonsWithStatusByOrganization(context.Background(), database.GetProvisionerDaemonsWithStatusByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
StaleIntervalMS: 60 * time.Minute.Milliseconds(),
|
|
MaxAgeMs: tc.maxAge,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, daemons, tc.expectedNum)
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestGetWorkspaceAgentUsageStats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Aggregates", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
db = dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
ctx := context.Background()
|
|
// Since the queries exclude the current minute
|
|
insertTime := dbtime.Now().Add(-time.Minute)
|
|
|
|
agentID1 := uuid.New()
|
|
agentID2 := uuid.New()
|
|
workspaceID1 := uuid.New()
|
|
workspaceID2 := uuid.New()
|
|
templateID1 := uuid.New()
|
|
templateID2 := uuid.New()
|
|
userID1 := uuid.New()
|
|
userID2 := uuid.New()
|
|
|
|
// Old workspace 1 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agentID1,
|
|
WorkspaceID: workspaceID1,
|
|
TemplateID: templateID1,
|
|
UserID: userID1,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 3,
|
|
SessionCountSSH: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agentID1,
|
|
WorkspaceID: workspaceID1,
|
|
TemplateID: templateID1,
|
|
UserID: userID1,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
// Latest workspace 1 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID1,
|
|
WorkspaceID: workspaceID1,
|
|
TemplateID: templateID1,
|
|
UserID: userID1,
|
|
TxBytes: 2,
|
|
RxBytes: 2,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 3,
|
|
SessionCountSSH: 4,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID1,
|
|
WorkspaceID: workspaceID1,
|
|
TemplateID: templateID1,
|
|
UserID: userID1,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID1,
|
|
WorkspaceID: workspaceID1,
|
|
TemplateID: templateID1,
|
|
UserID: userID1,
|
|
SessionCountJetBrains: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
// Latest workspace 2 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID2,
|
|
WorkspaceID: workspaceID2,
|
|
TemplateID: templateID2,
|
|
UserID: userID2,
|
|
TxBytes: 4,
|
|
RxBytes: 8,
|
|
ConnectionMedianLatencyMS: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID2,
|
|
WorkspaceID: workspaceID2,
|
|
TemplateID: templateID2,
|
|
UserID: userID2,
|
|
TxBytes: 2,
|
|
RxBytes: 3,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 3,
|
|
SessionCountSSH: 4,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID2,
|
|
WorkspaceID: workspaceID2,
|
|
TemplateID: templateID2,
|
|
UserID: userID2,
|
|
SessionCountSSH: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID2,
|
|
WorkspaceID: workspaceID2,
|
|
TemplateID: templateID2,
|
|
UserID: userID2,
|
|
SessionCountJetBrains: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
reqTime := dbtime.Now().Add(-time.Hour)
|
|
stats, err := db.GetWorkspaceAgentUsageStats(ctx, reqTime)
|
|
require.NoError(t, err)
|
|
|
|
ws1Stats, ws2Stats := stats[0], stats[1]
|
|
if ws1Stats.WorkspaceID != workspaceID1 {
|
|
ws1Stats, ws2Stats = ws2Stats, ws1Stats
|
|
}
|
|
require.Equal(t, int64(3), ws1Stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(3), ws1Stats.WorkspaceRxBytes)
|
|
require.Equal(t, int64(1), ws1Stats.SessionCountVSCode)
|
|
require.Equal(t, int64(1), ws1Stats.SessionCountJetBrains)
|
|
require.Equal(t, int64(0), ws1Stats.SessionCountSSH)
|
|
require.Equal(t, int64(0), ws1Stats.SessionCountReconnectingPTY)
|
|
|
|
require.Equal(t, int64(6), ws2Stats.WorkspaceTxBytes)
|
|
require.Equal(t, int64(11), ws2Stats.WorkspaceRxBytes)
|
|
require.Equal(t, int64(1), ws2Stats.SessionCountSSH)
|
|
require.Equal(t, int64(1), ws2Stats.SessionCountJetBrains)
|
|
require.Equal(t, int64(0), ws2Stats.SessionCountVSCode)
|
|
require.Equal(t, int64(0), ws2Stats.SessionCountReconnectingPTY)
|
|
})
|
|
|
|
t.Run("NoUsage", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
db = dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
ctx := context.Background()
|
|
// Since the queries exclude the current minute
|
|
insertTime := dbtime.Now().Add(-time.Minute)
|
|
|
|
agentID := uuid.New()
|
|
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agentID,
|
|
TxBytes: 3,
|
|
RxBytes: 4,
|
|
ConnectionMedianLatencyMS: 2,
|
|
// Should be ignored
|
|
SessionCountSSH: 3,
|
|
SessionCountVSCode: 1,
|
|
})
|
|
|
|
stats, err := db.GetWorkspaceAgentUsageStats(ctx, dbtime.Now().Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, stats, 1)
|
|
require.Equal(t, int64(3), stats[0].WorkspaceTxBytes)
|
|
require.Equal(t, int64(4), stats[0].WorkspaceRxBytes)
|
|
require.Equal(t, int64(0), stats[0].SessionCountVSCode)
|
|
require.Equal(t, int64(0), stats[0].SessionCountSSH)
|
|
require.Equal(t, int64(0), stats[0].SessionCountReconnectingPTY)
|
|
require.Equal(t, int64(0), stats[0].SessionCountJetBrains)
|
|
})
|
|
}
|
|
|
|
func TestGetWorkspaceAgentUsageStatsAndLabels(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Aggregates", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
insertTime := dbtime.Now()
|
|
|
|
// Insert user, agent, template, workspace
|
|
user1 := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
job1 := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource1 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job1.ID,
|
|
})
|
|
agent1 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource1.ID,
|
|
})
|
|
template1 := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user1.ID,
|
|
})
|
|
workspace1 := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user1.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: template1.ID,
|
|
})
|
|
user2 := dbgen.User(t, db, database.User{})
|
|
job2 := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource2 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job2.ID,
|
|
})
|
|
agent2 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource2.ID,
|
|
})
|
|
template2 := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user1.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
workspace2 := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user2.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: template2.ID,
|
|
})
|
|
|
|
// Old workspace 1 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agent1.ID,
|
|
WorkspaceID: workspace1.ID,
|
|
TemplateID: template1.ID,
|
|
UserID: user1.ID,
|
|
TxBytes: 1,
|
|
RxBytes: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 3,
|
|
SessionCountSSH: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agent1.ID,
|
|
WorkspaceID: workspace1.ID,
|
|
TemplateID: template1.ID,
|
|
UserID: user1.ID,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
// Latest workspace 1 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent1.ID,
|
|
WorkspaceID: workspace1.ID,
|
|
TemplateID: template1.ID,
|
|
UserID: user1.ID,
|
|
TxBytes: 2,
|
|
RxBytes: 2,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 4,
|
|
SessionCountSSH: 3,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent1.ID,
|
|
WorkspaceID: workspace1.ID,
|
|
TemplateID: template1.ID,
|
|
UserID: user1.ID,
|
|
SessionCountJetBrains: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent1.ID,
|
|
WorkspaceID: workspace1.ID,
|
|
TemplateID: template1.ID,
|
|
UserID: user1.ID,
|
|
SessionCountReconnectingPTY: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
// Latest workspace 2 stats
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent2.ID,
|
|
WorkspaceID: workspace2.ID,
|
|
TemplateID: template2.ID,
|
|
UserID: user2.ID,
|
|
TxBytes: 4,
|
|
RxBytes: 8,
|
|
ConnectionMedianLatencyMS: 1,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent2.ID,
|
|
WorkspaceID: workspace2.ID,
|
|
TemplateID: template2.ID,
|
|
UserID: user2.ID,
|
|
SessionCountVSCode: 1,
|
|
Usage: true,
|
|
})
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime,
|
|
AgentID: agent2.ID,
|
|
WorkspaceID: workspace2.ID,
|
|
TemplateID: template2.ID,
|
|
UserID: user2.ID,
|
|
SessionCountSSH: 1,
|
|
Usage: true,
|
|
})
|
|
|
|
stats, err := db.GetWorkspaceAgentUsageStatsAndLabels(ctx, insertTime.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, stats, 2)
|
|
require.Contains(t, stats, database.GetWorkspaceAgentUsageStatsAndLabelsRow{
|
|
Username: user1.Username,
|
|
AgentName: agent1.Name,
|
|
WorkspaceName: workspace1.Name,
|
|
TxBytes: 3,
|
|
RxBytes: 3,
|
|
SessionCountJetBrains: 1,
|
|
SessionCountReconnectingPTY: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
})
|
|
|
|
require.Contains(t, stats, database.GetWorkspaceAgentUsageStatsAndLabelsRow{
|
|
Username: user2.Username,
|
|
AgentName: agent2.Name,
|
|
WorkspaceName: workspace2.Name,
|
|
RxBytes: 8,
|
|
TxBytes: 4,
|
|
SessionCountVSCode: 1,
|
|
SessionCountSSH: 1,
|
|
ConnectionMedianLatencyMS: 1,
|
|
})
|
|
})
|
|
|
|
t.Run("NoUsage", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
insertTime := dbtime.Now()
|
|
// Insert user, agent, template, workspace
|
|
user := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
|
|
dbgen.WorkspaceAgentStat(t, db, database.WorkspaceAgentStat{
|
|
CreatedAt: insertTime.Add(-time.Minute),
|
|
AgentID: agent.ID,
|
|
WorkspaceID: workspace.ID,
|
|
TemplateID: template.ID,
|
|
UserID: user.ID,
|
|
RxBytes: 4,
|
|
TxBytes: 5,
|
|
ConnectionMedianLatencyMS: 1,
|
|
// Should be ignored
|
|
SessionCountVSCode: 3,
|
|
SessionCountSSH: 1,
|
|
})
|
|
|
|
stats, err := db.GetWorkspaceAgentUsageStatsAndLabels(ctx, insertTime.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, stats, 1)
|
|
require.Contains(t, stats, database.GetWorkspaceAgentUsageStatsAndLabelsRow{
|
|
Username: user.Username,
|
|
AgentName: agent.Name,
|
|
WorkspaceName: workspace.Name,
|
|
RxBytes: 4,
|
|
TxBytes: 5,
|
|
ConnectionMedianLatencyMS: 1,
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestGetAuthorizedWorkspacesAndAgentsByOwnerID(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
owner := dbgen.User(t, db, database.User{
|
|
RBACRoles: []string{rbac.RoleOwner().String()},
|
|
})
|
|
user := dbgen.User(t, db, database.User{})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: owner.ID,
|
|
})
|
|
|
|
pendingID := uuid.New()
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusPending,
|
|
CreateWorkspace: true,
|
|
WorkspaceID: pendingID,
|
|
CreateAgent: true,
|
|
})
|
|
failedID := uuid.New()
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusFailed,
|
|
CreateWorkspace: true,
|
|
CreateAgent: true,
|
|
WorkspaceID: failedID,
|
|
})
|
|
succeededID := uuid.New()
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
WorkspaceTransition: database.WorkspaceTransitionStart,
|
|
CreateWorkspace: true,
|
|
WorkspaceID: succeededID,
|
|
CreateAgent: true,
|
|
ExtraAgents: 1,
|
|
ExtraBuilds: 2,
|
|
})
|
|
deletedID := uuid.New()
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
WorkspaceTransition: database.WorkspaceTransitionDelete,
|
|
CreateWorkspace: true,
|
|
WorkspaceID: deletedID,
|
|
CreateAgent: false,
|
|
})
|
|
|
|
ownerCheckFn := func(ownerRows []database.GetWorkspacesAndAgentsByOwnerIDRow) {
|
|
require.Len(t, ownerRows, 4)
|
|
for _, row := range ownerRows {
|
|
switch row.ID {
|
|
case pendingID:
|
|
require.Len(t, row.Agents, 1)
|
|
require.Equal(t, database.ProvisionerJobStatusPending, row.JobStatus)
|
|
case failedID:
|
|
require.Len(t, row.Agents, 1)
|
|
require.Equal(t, database.ProvisionerJobStatusFailed, row.JobStatus)
|
|
case succeededID:
|
|
require.Len(t, row.Agents, 2)
|
|
require.Equal(t, database.ProvisionerJobStatusSucceeded, row.JobStatus)
|
|
require.Equal(t, database.WorkspaceTransitionStart, row.Transition)
|
|
case deletedID:
|
|
require.Len(t, row.Agents, 0)
|
|
require.Equal(t, database.ProvisionerJobStatusSucceeded, row.JobStatus)
|
|
require.Equal(t, database.WorkspaceTransitionDelete, row.Transition)
|
|
default:
|
|
t.Fatalf("unexpected workspace ID: %s", row.ID)
|
|
}
|
|
}
|
|
}
|
|
t.Run("sqlQuerier", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
userSubject, _, err := httpmw.UserRBACSubject(ctx, db, user.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedUser, err := authorizer.Prepare(ctx, userSubject, policy.ActionRead, rbac.ResourceWorkspace.Type)
|
|
require.NoError(t, err)
|
|
userCtx := dbauthz.As(ctx, userSubject)
|
|
userRows, err := db.GetAuthorizedWorkspacesAndAgentsByOwnerID(userCtx, owner.ID, preparedUser)
|
|
require.NoError(t, err)
|
|
require.Len(t, userRows, 0)
|
|
|
|
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, db, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedOwner, err := authorizer.Prepare(ctx, ownerSubject, policy.ActionRead, rbac.ResourceWorkspace.Type)
|
|
require.NoError(t, err)
|
|
ownerCtx := dbauthz.As(ctx, ownerSubject)
|
|
ownerRows, err := db.GetAuthorizedWorkspacesAndAgentsByOwnerID(ownerCtx, owner.ID, preparedOwner)
|
|
require.NoError(t, err)
|
|
ownerCheckFn(ownerRows)
|
|
})
|
|
|
|
t.Run("dbauthz", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
authzdb := dbauthz.New(db, authorizer, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
|
|
userSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, user.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
userCtx := dbauthz.As(ctx, userSubject)
|
|
|
|
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
ownerCtx := dbauthz.As(ctx, ownerSubject)
|
|
|
|
userRows, err := authzdb.GetWorkspacesAndAgentsByOwnerID(userCtx, owner.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, userRows, 0)
|
|
|
|
ownerRows, err := authzdb.GetWorkspacesAndAgentsByOwnerID(ownerCtx, owner.ID)
|
|
require.NoError(t, err)
|
|
ownerCheckFn(ownerRows)
|
|
})
|
|
}
|
|
|
|
func TestGetAuthorizedChats(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
|
|
|
// Create users with different roles.
|
|
owner := dbgen.User(t, db, database.User{
|
|
RBACRoles: []string{rbac.RoleOwner().String()},
|
|
})
|
|
member := dbgen.User(t, db, database.User{})
|
|
secondMember := dbgen.User(t, db, database.User{})
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: member.ID, OrganizationID: org.ID, Roles: []string{rbac.RoleAgentsAccess()}})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: secondMember.ID, OrganizationID: org.ID, Roles: []string{rbac.RoleAgentsAccess()}})
|
|
|
|
// Create FK dependencies: a chat provider and model config.
|
|
_ = dbgen.ChatProvider(t, db, database.ChatProvider{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
})
|
|
modelCfg := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
IsDefault: true,
|
|
CompressionThreshold: 80,
|
|
})
|
|
|
|
// Create 3 chats owned by owner.
|
|
for i := range 3 {
|
|
dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: fmt.Sprintf("owner chat %d", i+1),
|
|
})
|
|
}
|
|
|
|
// Create 2 chats owned by member.
|
|
for i := range 2 {
|
|
dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: member.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: fmt.Sprintf("member chat %d", i+1),
|
|
})
|
|
}
|
|
|
|
t.Run("sqlQuerier", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// Member should only see their own 2 chats.
|
|
memberSubject, _, err := httpmw.UserRBACSubject(ctx, db, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedMember, err := authorizer.Prepare(ctx, memberSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
memberRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedMember)
|
|
require.NoError(t, err)
|
|
require.Len(t, memberRows, 2)
|
|
for _, row := range memberRows {
|
|
require.Equal(t, member.ID, row.Chat.OwnerID, "member should only see own chats")
|
|
}
|
|
|
|
// Owner should see at least the 5 pre-created chats (site-wide
|
|
// access). Parallel subtests may add more, so use GreaterOrEqual.
|
|
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, db, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedOwner, err := authorizer.Prepare(ctx, ownerSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
ownerRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOwner)
|
|
require.NoError(t, err)
|
|
require.GreaterOrEqual(t, len(ownerRows), 5)
|
|
|
|
// secondMember has no chats and should see 0.
|
|
secondSubject, _, err := httpmw.UserRBACSubject(ctx, db, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedSecond, err := authorizer.Prepare(ctx, secondSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
secondRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedSecond)
|
|
require.NoError(t, err)
|
|
require.Len(t, secondRows, 0)
|
|
|
|
// Org admin should NOT see other users' chats when they are
|
|
// in a different org than the chat owner.
|
|
orgs, err := db.GetOrganizations(ctx, database.GetOrganizationsParams{})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, orgs)
|
|
orgAdmin := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: orgAdmin.ID,
|
|
OrganizationID: orgs[0].ID,
|
|
Roles: []string{rbac.RoleOrgAdmin()},
|
|
})
|
|
orgAdminSubject, _, err := httpmw.UserRBACSubject(ctx, db, orgAdmin.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedOrgAdmin, err := authorizer.Prepare(ctx, orgAdminSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
orgAdminRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedOrgAdmin)
|
|
require.NoError(t, err)
|
|
require.Len(t, orgAdminRows, 0, "org admin with no chats should see 0 chats")
|
|
|
|
// Org admin in SAME org should see all chats in that org.
|
|
sameOrgAdmin := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: sameOrgAdmin.ID,
|
|
OrganizationID: org.ID,
|
|
Roles: []string{rbac.RoleOrgAdmin()},
|
|
})
|
|
sameOrgAdminSubject, _, err := httpmw.UserRBACSubject(ctx, db, sameOrgAdmin.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedSameOrgAdmin, err := authorizer.Prepare(ctx, sameOrgAdminSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
sameOrgAdminRows, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{}, preparedSameOrgAdmin)
|
|
require.NoError(t, err)
|
|
require.GreaterOrEqual(t, len(sameOrgAdminRows), 5, "same-org admin should see all chats in their org")
|
|
|
|
// OwnerID filter: member queries their own chats.
|
|
memberFilterSelf, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
|
OwnerID: member.ID,
|
|
}, preparedMember)
|
|
require.NoError(t, err)
|
|
require.Len(t, memberFilterSelf, 2)
|
|
|
|
// OwnerID filter: member queries owner's chats → sees 0.
|
|
memberFilterOwner, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
|
OwnerID: owner.ID,
|
|
}, preparedMember)
|
|
require.NoError(t, err)
|
|
require.Len(t, memberFilterOwner, 0)
|
|
})
|
|
|
|
t.Run("dbauthz", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
authzdb := dbauthz.New(db, authorizer, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
|
|
// As member: should see only own 2 chats.
|
|
memberSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, member.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
memberCtx := dbauthz.As(ctx, memberSubject)
|
|
memberRows, err := authzdb.GetChats(memberCtx, database.GetChatsParams{})
|
|
require.NoError(t, err)
|
|
require.Len(t, memberRows, 2)
|
|
for _, row := range memberRows {
|
|
require.Equal(t, member.ID, row.Chat.OwnerID, "member should only see own chats")
|
|
}
|
|
|
|
// As owner: should see at least the 5 pre-created chats.
|
|
ownerSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, owner.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
ownerCtx := dbauthz.As(ctx, ownerSubject)
|
|
ownerRows, err := authzdb.GetChats(ownerCtx, database.GetChatsParams{})
|
|
require.NoError(t, err)
|
|
require.GreaterOrEqual(t, len(ownerRows), 5)
|
|
|
|
// As secondMember: should see 0 chats.
|
|
secondSubject, _, err := httpmw.UserRBACSubject(ctx, authzdb, secondMember.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
secondCtx := dbauthz.As(ctx, secondSubject)
|
|
secondRows, err := authzdb.GetChats(secondCtx, database.GetChatsParams{})
|
|
require.NoError(t, err)
|
|
require.Len(t, secondRows, 0)
|
|
})
|
|
|
|
t.Run("pagination", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// Use a dedicated user for pagination to avoid interference
|
|
// with the other parallel subtests.
|
|
paginationUser := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: paginationUser.ID, OrganizationID: org.ID, Roles: []string{rbac.RoleAgentsAccess()}})
|
|
for i := range 7 {
|
|
dbgen.Chat(t, db, database.Chat{
|
|
OrganizationID: org.ID,
|
|
OwnerID: paginationUser.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: fmt.Sprintf("pagination chat %d", i+1),
|
|
})
|
|
}
|
|
|
|
pagUserSubject, _, err := httpmw.UserRBACSubject(ctx, db, paginationUser.ID, rbac.ExpandableScope(rbac.ScopeAll))
|
|
require.NoError(t, err)
|
|
preparedMember, err := authorizer.Prepare(ctx, pagUserSubject, policy.ActionRead, rbac.ResourceChat.Type)
|
|
require.NoError(t, err)
|
|
|
|
// Fetch first page with limit=2.
|
|
page1, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
|
LimitOpt: 2,
|
|
}, preparedMember)
|
|
require.NoError(t, err)
|
|
require.Len(t, page1, 2)
|
|
for _, row := range page1 {
|
|
require.Equal(t, paginationUser.ID, row.Chat.OwnerID, "paginated results must belong to pagination user")
|
|
}
|
|
|
|
// Fetch remaining pages and collect all chat IDs.
|
|
allIDs := make(map[uuid.UUID]struct{})
|
|
for _, row := range page1 {
|
|
allIDs[row.Chat.ID] = struct{}{}
|
|
}
|
|
offset := int32(2)
|
|
for {
|
|
page, err := db.GetAuthorizedChats(ctx, database.GetChatsParams{
|
|
LimitOpt: 2,
|
|
OffsetOpt: offset,
|
|
}, preparedMember)
|
|
require.NoError(t, err)
|
|
for _, row := range page {
|
|
require.Equal(t, paginationUser.ID, row.Chat.OwnerID, "paginated results must belong to pagination user")
|
|
allIDs[row.Chat.ID] = struct{}{}
|
|
}
|
|
if len(page) < 2 {
|
|
break
|
|
}
|
|
offset += int32(len(page)) //nolint:gosec // Test code, pagination values are small.
|
|
}
|
|
|
|
// All 7 member chats should be accounted for with no leakage.
|
|
require.Len(t, allIDs, 7, "pagination should return all member chats exactly once")
|
|
})
|
|
}
|
|
|
|
func TestInsertWorkspaceAgentLogs(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
sqlDB := testSQLDB(t)
|
|
ctx := context.Background()
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
source := dbgen.WorkspaceAgentLogSource(t, db, database.WorkspaceAgentLogSource{
|
|
WorkspaceAgentID: agent.ID,
|
|
})
|
|
logs, err := db.InsertWorkspaceAgentLogs(ctx, database.InsertWorkspaceAgentLogsParams{
|
|
AgentID: agent.ID,
|
|
CreatedAt: dbtime.Now(),
|
|
Output: []string{"first"},
|
|
Level: []database.LogLevel{database.LogLevelInfo},
|
|
LogSourceID: source.ID,
|
|
// 1 MB is the max
|
|
OutputLength: 1 << 20,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), logs[0].ID)
|
|
|
|
_, err = db.InsertWorkspaceAgentLogs(ctx, database.InsertWorkspaceAgentLogsParams{
|
|
AgentID: agent.ID,
|
|
CreatedAt: dbtime.Now(),
|
|
Output: []string{"second"},
|
|
Level: []database.LogLevel{database.LogLevelInfo},
|
|
LogSourceID: source.ID,
|
|
OutputLength: 1,
|
|
})
|
|
require.True(t, database.IsWorkspaceAgentLogsLimitError(err))
|
|
}
|
|
|
|
func TestProxyByHostname(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
// Insert a bunch of different proxies.
|
|
proxies := []struct {
|
|
name string
|
|
accessURL string
|
|
wildcardHostname string
|
|
}{
|
|
{
|
|
name: "one",
|
|
accessURL: "https://one.coder.com",
|
|
wildcardHostname: "*.wildcard.one.coder.com",
|
|
},
|
|
{
|
|
name: "two",
|
|
accessURL: "https://two.coder.com",
|
|
wildcardHostname: "*--suffix.two.coder.com",
|
|
},
|
|
}
|
|
for _, p := range proxies {
|
|
dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{
|
|
Name: p.name,
|
|
Url: p.accessURL,
|
|
WildcardHostname: p.wildcardHostname,
|
|
})
|
|
}
|
|
|
|
cases := []struct {
|
|
name string
|
|
testHostname string
|
|
allowAccessURL bool
|
|
allowWildcardHost bool
|
|
matchProxyName string
|
|
}{
|
|
{
|
|
name: "NoMatch",
|
|
testHostname: "test.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "",
|
|
},
|
|
{
|
|
name: "MatchAccessURL",
|
|
testHostname: "one.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "one",
|
|
},
|
|
{
|
|
name: "MatchWildcard",
|
|
testHostname: "something.wildcard.one.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "one",
|
|
},
|
|
{
|
|
name: "MatchSuffix",
|
|
testHostname: "something--suffix.two.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "two",
|
|
},
|
|
{
|
|
name: "ValidateHostname/1",
|
|
testHostname: ".*ne.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "",
|
|
},
|
|
{
|
|
name: "ValidateHostname/2",
|
|
testHostname: "https://one.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "",
|
|
},
|
|
{
|
|
name: "ValidateHostname/3",
|
|
testHostname: "one.coder.com:8080/hello",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "",
|
|
},
|
|
{
|
|
name: "IgnoreAccessURLMatch",
|
|
testHostname: "one.coder.com",
|
|
allowAccessURL: false,
|
|
allowWildcardHost: true,
|
|
matchProxyName: "",
|
|
},
|
|
{
|
|
name: "IgnoreWildcardMatch",
|
|
testHostname: "hi.wildcard.one.coder.com",
|
|
allowAccessURL: true,
|
|
allowWildcardHost: false,
|
|
matchProxyName: "",
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
proxy, err := db.GetWorkspaceProxyByHostname(context.Background(), database.GetWorkspaceProxyByHostnameParams{
|
|
Hostname: c.testHostname,
|
|
AllowAccessUrl: c.allowAccessURL,
|
|
AllowWildcardHostname: c.allowWildcardHost,
|
|
})
|
|
if c.matchProxyName == "" {
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
require.Empty(t, proxy)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, proxy)
|
|
require.Equal(t, c.matchProxyName, proxy.Name)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefaultProxy(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
depID := uuid.NewString()
|
|
err = db.InsertDeploymentID(ctx, depID)
|
|
require.NoError(t, err, "insert deployment id")
|
|
|
|
// Fetch empty proxy values
|
|
defProxy, err := db.GetDefaultProxyConfig(ctx)
|
|
require.NoError(t, err, "get def proxy")
|
|
|
|
require.Equal(t, defProxy.DisplayName, "Default")
|
|
require.Equal(t, defProxy.IconURL, "/emojis/1f3e1.png")
|
|
|
|
// Set the proxy values
|
|
args := database.UpsertDefaultProxyParams{
|
|
DisplayName: "displayname",
|
|
IconURL: "/icon.png",
|
|
}
|
|
err = db.UpsertDefaultProxy(ctx, args)
|
|
require.NoError(t, err, "insert def proxy")
|
|
|
|
defProxy, err = db.GetDefaultProxyConfig(ctx)
|
|
require.NoError(t, err, "get def proxy")
|
|
require.Equal(t, defProxy.DisplayName, args.DisplayName)
|
|
require.Equal(t, defProxy.IconURL, args.IconURL)
|
|
|
|
// Upsert values
|
|
args = database.UpsertDefaultProxyParams{
|
|
DisplayName: "newdisplayname",
|
|
IconURL: "/newicon.png",
|
|
}
|
|
err = db.UpsertDefaultProxy(ctx, args)
|
|
require.NoError(t, err, "upsert def proxy")
|
|
|
|
defProxy, err = db.GetDefaultProxyConfig(ctx)
|
|
require.NoError(t, err, "get def proxy")
|
|
require.Equal(t, defProxy.DisplayName, args.DisplayName)
|
|
require.Equal(t, defProxy.IconURL, args.IconURL)
|
|
|
|
// Ensure other site configs are the same
|
|
found, err := db.GetDeploymentID(ctx)
|
|
require.NoError(t, err, "get deployment id")
|
|
require.Equal(t, depID, found)
|
|
}
|
|
|
|
func TestQueuePosition(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
jobCount := 10
|
|
jobs := []database.ProvisionerJob{}
|
|
jobIDs := []uuid.UUID{}
|
|
for i := 0; i < jobCount; i++ {
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Tags: database.StringMap{},
|
|
})
|
|
jobs = append(jobs, job)
|
|
jobIDs = append(jobIDs, job.ID)
|
|
|
|
// We need a slight amount of time between each insertion to ensure that
|
|
// the queue position is correct... it's sorted by `created_at`.
|
|
time.Sleep(time.Millisecond)
|
|
}
|
|
|
|
// Create default provisioner daemon:
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "default_provisioner",
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
// Ensure the `tags` field is NOT NULL for the default provisioner;
|
|
// otherwise, it won't be able to pick up any jobs.
|
|
Tags: database.StringMap{},
|
|
})
|
|
|
|
queued, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: jobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, jobCount)
|
|
sort.Slice(queued, func(i, j int) bool {
|
|
return queued[i].QueuePosition < queued[j].QueuePosition
|
|
})
|
|
// Ensure that the queue positions are correct based on insertion ID!
|
|
for index, job := range queued {
|
|
require.Equal(t, job.QueuePosition, int64(index+1))
|
|
require.Equal(t, job.ProvisionerJob.ID, jobs[index].ID)
|
|
}
|
|
|
|
job, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
OrganizationID: org.ID,
|
|
StartedAt: sql.NullTime{
|
|
Time: dbtime.Now(),
|
|
Valid: true,
|
|
},
|
|
Types: database.AllProvisionerTypeValues(),
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
ProvisionerTags: json.RawMessage("{}"),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, jobs[0].ID, job.ID)
|
|
|
|
queued, err = db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: jobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, queued, jobCount)
|
|
sort.Slice(queued, func(i, j int) bool {
|
|
return queued[i].QueuePosition < queued[j].QueuePosition
|
|
})
|
|
// Ensure that queue positions are updated now that the first job has been acquired!
|
|
for index, job := range queued {
|
|
if index == 0 {
|
|
require.Equal(t, job.QueuePosition, int64(0))
|
|
continue
|
|
}
|
|
require.Equal(t, job.QueuePosition, int64(index))
|
|
require.Equal(t, job.ProvisionerJob.ID, jobs[index].ID)
|
|
}
|
|
}
|
|
|
|
func TestAcquireProvisionerJob(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("HumanInitiatedJobsFirst", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db, _ = dbtestutil.NewDB(t)
|
|
ctx = testutil.Context(t, testutil.WaitMedium)
|
|
org = dbgen.Organization(t, db, database.Organization{})
|
|
_ = dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{}) // Required for queue position
|
|
now = dbtime.Now()
|
|
numJobs = 10
|
|
humanIDs = make([]uuid.UUID, 0, numJobs/2)
|
|
prebuildIDs = make([]uuid.UUID, 0, numJobs/2)
|
|
)
|
|
|
|
// Given: a number of jobs in the queue, with prebuilds and non-prebuilds interleaved
|
|
for idx := range numJobs {
|
|
var initiator uuid.UUID
|
|
if idx%2 == 0 {
|
|
initiator = database.PrebuildsSystemUserID
|
|
} else {
|
|
initiator = uuid.MustParse("c0dec0de-c0de-c0de-c0de-c0dec0dec0de")
|
|
}
|
|
pj, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.MustParse(fmt.Sprintf("00000000-0000-0000-0000-00000000000%x", idx+1)),
|
|
CreatedAt: time.Now().Add(-time.Second * time.Duration(idx)),
|
|
UpdatedAt: time.Now().Add(-time.Second * time.Duration(idx)),
|
|
InitiatorID: initiator,
|
|
OrganizationID: org.ID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
FileID: uuid.New(),
|
|
Input: json.RawMessage(`{}`),
|
|
Tags: database.StringMap{},
|
|
TraceMetadata: pqtype.NullRawMessage{},
|
|
})
|
|
require.NoError(t, err)
|
|
// We expected prebuilds to be acquired after human-initiated jobs.
|
|
if initiator == database.PrebuildsSystemUserID {
|
|
prebuildIDs = append([]uuid.UUID{pj.ID}, prebuildIDs...)
|
|
} else {
|
|
humanIDs = append([]uuid.UUID{pj.ID}, humanIDs...)
|
|
}
|
|
t.Logf("created job id=%q initiator=%q created_at=%q", pj.ID.String(), pj.InitiatorID.String(), pj.CreatedAt.String())
|
|
}
|
|
|
|
expectedIDs := append(humanIDs, prebuildIDs...) //nolint:gocritic // not the same slice
|
|
|
|
// When: we query the queue positions for the jobs
|
|
qjs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: expectedIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, qjs, numJobs)
|
|
// Ensure the jobs are sorted by queue position.
|
|
sort.Slice(qjs, func(i, j int) bool {
|
|
return qjs[i].QueuePosition < qjs[j].QueuePosition
|
|
})
|
|
|
|
// Then: the queue positions for the jobs should indicate the order in which
|
|
// they will be acquired, with human-initiated jobs first.
|
|
for idx, qj := range qjs {
|
|
t.Logf("queued job %d/%d id=%q initiator=%q created_at=%q queue_position=%d", idx+1, numJobs, qj.ProvisionerJob.ID.String(), qj.ProvisionerJob.InitiatorID.String(), qj.ProvisionerJob.CreatedAt.String(), qj.QueuePosition)
|
|
require.Equal(t, expectedIDs[idx].String(), qj.ProvisionerJob.ID.String(), "job %d/%d should match expected id", idx+1, numJobs)
|
|
require.Equal(t, int64(idx+1), qj.QueuePosition, "job %d/%d should have queue position %d", idx+1, numJobs, idx+1)
|
|
}
|
|
|
|
// When: the jobs are acquired
|
|
// Then: human-initiated jobs are prioritized first.
|
|
for idx := range numJobs {
|
|
acquired, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
OrganizationID: org.ID,
|
|
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
ProvisionerTags: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, expectedIDs[idx].String(), acquired.ID.String(), "acquired job %d/%d with initiator %q", idx+1, numJobs, acquired.InitiatorID.String())
|
|
t.Logf("acquired job id=%q initiator=%q created_at=%q", acquired.ID.String(), acquired.InitiatorID.String(), acquired.CreatedAt.String())
|
|
err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
|
ID: acquired.ID,
|
|
UpdatedAt: now,
|
|
CompletedAt: sql.NullTime{Time: now, Valid: true},
|
|
Error: sql.NullString{},
|
|
ErrorCode: sql.NullString{},
|
|
})
|
|
require.NoError(t, err, "mark job %d/%d as complete", idx+1, numJobs)
|
|
}
|
|
})
|
|
|
|
t.Run("SkipsCanceledPendingJobs", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db, _ = dbtestutil.NewDB(t)
|
|
ctx = testutil.Context(t, testutil.WaitMedium)
|
|
org = dbgen.Organization(t, db, database.Organization{})
|
|
now = dbtime.Now()
|
|
)
|
|
|
|
// Insert a pending job (started_at is NULL).
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.New(),
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
InitiatorID: uuid.New(),
|
|
OrganizationID: org.ID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
FileID: uuid.New(),
|
|
Input: json.RawMessage(`{}`),
|
|
Tags: database.StringMap{},
|
|
TraceMetadata: pqtype.NullRawMessage{},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Cancel it while still pending. In production (workspacebuilds.go), canceling
|
|
// a pending build sets completed_at but leaves started_at NULL since no
|
|
// provisioner ever started the job.
|
|
err = db.UpdateProvisionerJobWithCancelByID(ctx, database.UpdateProvisionerJobWithCancelByIDParams{
|
|
ID: job.ID,
|
|
CanceledAt: sql.NullTime{Time: now, Valid: true},
|
|
CompletedAt: sql.NullTime{Time: now, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// AcquireProvisionerJob should skip this job since it's already completed.
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
OrganizationID: org.ID,
|
|
StartedAt: sql.NullTime{Time: now, Valid: true},
|
|
WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
ProvisionerTags: json.RawMessage(`{}`),
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
})
|
|
}
|
|
|
|
func TestUserLastSeenFilter(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
t.Run("Before", func(t *testing.T) {
|
|
t.Parallel()
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
now := dbtime.Now()
|
|
|
|
yesterday := dbgen.User(t, db, database.User{
|
|
LastSeenAt: now.Add(time.Hour * -25),
|
|
})
|
|
today := dbgen.User(t, db, database.User{
|
|
LastSeenAt: now,
|
|
})
|
|
lastWeek := dbgen.User(t, db, database.User{
|
|
LastSeenAt: now.Add((time.Hour * -24 * 7) + (-1 * time.Hour)),
|
|
})
|
|
|
|
beforeToday, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
LastSeenBefore: now.Add(time.Hour * -24),
|
|
})
|
|
require.NoError(t, err)
|
|
database.ConvertUserRows(beforeToday)
|
|
|
|
requireUsersMatch(t, []database.User{yesterday, lastWeek}, beforeToday, "before today")
|
|
|
|
justYesterday, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
LastSeenBefore: now.Add(time.Hour * -24),
|
|
LastSeenAfter: now.Add(time.Hour * -24 * 2),
|
|
})
|
|
require.NoError(t, err)
|
|
requireUsersMatch(t, []database.User{yesterday}, justYesterday, "just yesterday")
|
|
|
|
all, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
LastSeenBefore: now.Add(time.Hour),
|
|
})
|
|
require.NoError(t, err)
|
|
requireUsersMatch(t, []database.User{today, yesterday, lastWeek}, all, "all")
|
|
|
|
allAfterLastWeek, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
LastSeenAfter: now.Add(time.Hour * -24 * 7),
|
|
})
|
|
require.NoError(t, err)
|
|
requireUsersMatch(t, []database.User{today, yesterday}, allAfterLastWeek, "after last week")
|
|
})
|
|
}
|
|
|
|
func TestGetUsers_IncludeSystem(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
includeSystem bool
|
|
wantSystemUser bool
|
|
}{
|
|
{
|
|
name: "include system users",
|
|
includeSystem: true,
|
|
wantSystemUser: true,
|
|
},
|
|
{
|
|
name: "exclude system users",
|
|
includeSystem: false,
|
|
wantSystemUser: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Given: a system user
|
|
// postgres: introduced by migration coderd/database/migrations/00030*_system_user.up.sql
|
|
db, _ := dbtestutil.NewDB(t)
|
|
other := dbgen.User(t, db, database.User{})
|
|
users, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
IncludeSystem: tt.includeSystem,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Should always find the regular user
|
|
foundRegularUser := false
|
|
foundSystemUser := false
|
|
|
|
for _, u := range users {
|
|
if u.IsSystem {
|
|
foundSystemUser = true
|
|
require.Equal(t, database.PrebuildsSystemUserID, u.ID)
|
|
} else {
|
|
foundRegularUser = true
|
|
require.Equalf(t, other.ID.String(), u.ID.String(), "found unexpected regular user")
|
|
}
|
|
}
|
|
|
|
require.True(t, foundRegularUser, "regular user should always be found")
|
|
require.Equal(t, tt.wantSystemUser, foundSystemUser, "system user presence should match includeSystem setting")
|
|
require.Equal(t, tt.wantSystemUser, len(users) == 2, "should have 2 users when including system user, 1 otherwise")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdateSystemUser(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// TODO (sasswart): We've disabled the protection that prevents updates to system users
|
|
// while we reassess the mechanism to do so. Rather than skip the test, we've just inverted
|
|
// the assertions to ensure that the behavior is as desired.
|
|
// Once we've re-enabeld the system user protection, we'll revert the assertions.
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Given: a system user introduced by migration coderd/database/migrations/00030*_system_user.up.sql
|
|
db, _ := dbtestutil.NewDB(t)
|
|
users, err := db.GetUsers(ctx, database.GetUsersParams{
|
|
IncludeSystem: true,
|
|
})
|
|
require.NoError(t, err)
|
|
var systemUser database.GetUsersRow
|
|
for _, u := range users {
|
|
if u.IsSystem {
|
|
systemUser = u
|
|
}
|
|
}
|
|
require.NotNil(t, systemUser)
|
|
|
|
// When: attempting to update a system user's name.
|
|
_, err = db.UpdateUserProfile(ctx, database.UpdateUserProfileParams{
|
|
ID: systemUser.ID,
|
|
Email: systemUser.Email,
|
|
Username: systemUser.Username,
|
|
AvatarURL: systemUser.AvatarURL,
|
|
Name: "not prebuilds",
|
|
})
|
|
// Then: the attempt is rejected by a postgres trigger.
|
|
// require.ErrorContains(t, err, "Cannot modify or delete system users")
|
|
require.NoError(t, err)
|
|
|
|
// When: attempting to delete a system user.
|
|
err = db.UpdateUserDeletedByID(ctx, systemUser.ID)
|
|
// Then: the attempt is rejected by a postgres trigger.
|
|
// require.ErrorContains(t, err, "Cannot modify or delete system users")
|
|
require.NoError(t, err)
|
|
|
|
// When: attempting to update a user's roles.
|
|
_, err = db.UpdateUserRoles(ctx, database.UpdateUserRolesParams{
|
|
ID: systemUser.ID,
|
|
GrantedRoles: []string{rbac.RoleAuditor().String()},
|
|
})
|
|
// Then: the attempt is rejected by a postgres trigger.
|
|
// require.ErrorContains(t, err, "Cannot modify or delete system users")
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestInsertUserServiceAccountConstraints(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
// Happy path: should succeed.
|
|
t.Run("ServiceAccountWithEmptyEmailAndLoginNone", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
user, err := db.InsertUser(ctx, database.InsertUserParams{
|
|
Email: "",
|
|
LoginType: database.LoginTypeNone,
|
|
ID: uuid.New(),
|
|
Username: "sa-ok",
|
|
RBACRoles: []string{},
|
|
IsServiceAccount: true,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, user.IsServiceAccount)
|
|
require.Empty(t, user.Email)
|
|
})
|
|
|
|
// Service account with a non-empty email should be rejected
|
|
// by the users_email_not_empty constraint.
|
|
t.Run("ServiceAccountWithNonEmptyEmail", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
_, err := db.InsertUser(ctx, database.InsertUserParams{
|
|
Email: "sa@coder.com",
|
|
LoginType: database.LoginTypeNone,
|
|
ID: uuid.New(),
|
|
Username: "sa-with-email",
|
|
RBACRoles: []string{},
|
|
IsServiceAccount: true,
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckUsersEmailNotEmpty))
|
|
})
|
|
|
|
// A non-service-account with empty email should be rejected
|
|
// by the users_email_not_empty constraint.
|
|
t.Run("RegularUserWithEmptyEmail", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
_, err := db.InsertUser(ctx, database.InsertUserParams{
|
|
Email: "",
|
|
LoginType: database.LoginTypePassword,
|
|
ID: uuid.New(),
|
|
Username: "regular-no-email",
|
|
RBACRoles: []string{},
|
|
IsServiceAccount: false,
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckUsersEmailNotEmpty))
|
|
})
|
|
|
|
// Service account with login_type!=none should be rejected
|
|
// by the users_service_account_login_type constraint.
|
|
t.Run("ServiceAccountWithPasswordLoginType", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
_, err := db.InsertUser(ctx, database.InsertUserParams{
|
|
Email: "",
|
|
LoginType: database.LoginTypePassword,
|
|
ID: uuid.New(),
|
|
Username: "sa-with-password",
|
|
RBACRoles: []string{},
|
|
IsServiceAccount: true,
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckUsersServiceAccountLoginType))
|
|
})
|
|
}
|
|
|
|
func TestUserChangeLoginType(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
|
|
alice := dbgen.User(t, db, database.User{
|
|
LoginType: database.LoginTypePassword,
|
|
})
|
|
bob := dbgen.User(t, db, database.User{
|
|
LoginType: database.LoginTypePassword,
|
|
})
|
|
bobExpPass := bob.HashedPassword
|
|
require.NotEmpty(t, alice.HashedPassword, "hashed password should not start empty")
|
|
require.NotEmpty(t, bob.HashedPassword, "hashed password should not start empty")
|
|
|
|
alice, err = db.UpdateUserLoginType(ctx, database.UpdateUserLoginTypeParams{
|
|
NewLoginType: database.LoginTypeOIDC,
|
|
UserID: alice.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Empty(t, alice.HashedPassword, "hashed password should be empty")
|
|
|
|
// First check other users are not affected
|
|
bob, err = db.GetUserByID(ctx, bob.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, bobExpPass, bob.HashedPassword, "hashed password should not change")
|
|
|
|
// Then check password -> password is a noop
|
|
bob, err = db.UpdateUserLoginType(ctx, database.UpdateUserLoginTypeParams{
|
|
NewLoginType: database.LoginTypePassword,
|
|
UserID: bob.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
bob, err = db.GetUserByID(ctx, bob.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, bobExpPass, bob.HashedPassword, "hashed password should not change")
|
|
}
|
|
|
|
func TestDefaultOrg(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
|
|
// Should start with the default org
|
|
all, err := db.GetOrganizations(ctx, database.GetOrganizationsParams{})
|
|
require.NoError(t, err)
|
|
require.Len(t, all, 1)
|
|
require.True(t, all[0].IsDefault, "first org should always be default")
|
|
}
|
|
|
|
func TestAuditLogDefaultLimit(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
for i := 0; i < 110; i++ {
|
|
dbgen.AuditLog(t, db, database.AuditLog{})
|
|
}
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
rows, err := db.GetAuditLogsOffset(ctx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// The length should match the default limit of the SQL query.
|
|
// Updating the sql query requires changing the number below to match.
|
|
require.Len(t, rows, 100)
|
|
}
|
|
|
|
func TestAuditLogCount(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
dbgen.AuditLog(t, db, database.AuditLog{})
|
|
|
|
count, err := db.CountAuditLogs(ctx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), count)
|
|
}
|
|
|
|
func TestWorkspaceQuotas(t *testing.T) {
|
|
t.Parallel()
|
|
orgMemberIDs := func(o database.OrganizationMember) uuid.UUID {
|
|
return o.UserID
|
|
}
|
|
groupMemberIDs := func(m database.GroupMember) uuid.UUID {
|
|
return m.UserID
|
|
}
|
|
|
|
t.Run("CorruptedEveryone", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
// Create an extra org as a distraction
|
|
distract := dbgen.Organization(t, db, database.Organization{})
|
|
_, err := db.InsertAllUsersGroup(ctx, distract.ID)
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.UpdateGroupByID(ctx, database.UpdateGroupByIDParams{
|
|
QuotaAllowance: 15,
|
|
ID: distract.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Create an org with 2 users
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
everyoneGroup, err := db.InsertAllUsersGroup(ctx, org.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Add a quota to the everyone group
|
|
_, err = db.UpdateGroupByID(ctx, database.UpdateGroupByIDParams{
|
|
QuotaAllowance: 50,
|
|
ID: everyoneGroup.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Add people to the org
|
|
one := dbgen.User(t, db, database.User{})
|
|
two := dbgen.User(t, db, database.User{})
|
|
memOne := dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: one.ID,
|
|
})
|
|
memTwo := dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: two.ID,
|
|
})
|
|
|
|
// Fetch the 'Everyone' group members
|
|
everyoneMembers, err := db.GetGroupMembersByGroupID(ctx, database.GetGroupMembersByGroupIDParams{
|
|
GroupID: everyoneGroup.ID,
|
|
IncludeSystem: false,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.ElementsMatch(t, slice.List(everyoneMembers, groupMemberIDs),
|
|
slice.List([]database.OrganizationMember{memOne, memTwo}, orgMemberIDs))
|
|
|
|
// Check the quota is correct.
|
|
allowance, err := db.GetQuotaAllowanceForUser(ctx, database.GetQuotaAllowanceForUserParams{
|
|
UserID: one.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(50), allowance)
|
|
|
|
// Now try to corrupt the DB
|
|
// Insert rows into the everyone group
|
|
err = db.InsertGroupMember(ctx, database.InsertGroupMemberParams{
|
|
UserID: memOne.UserID,
|
|
GroupID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Ensure allowance remains the same
|
|
allowance, err = db.GetQuotaAllowanceForUser(ctx, database.GetQuotaAllowanceForUserParams{
|
|
UserID: one.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(50), allowance)
|
|
})
|
|
}
|
|
|
|
// TestReadCustomRoles tests the input params returns the correct set of roles.
|
|
func TestReadCustomRoles(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
|
|
db := database.New(sqlDB)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
// Make a few site roles, and a few org roles
|
|
orgIDs := make([]uuid.UUID, 3)
|
|
for i := range orgIDs {
|
|
orgIDs[i] = uuid.New()
|
|
}
|
|
|
|
allRoles := make([]database.CustomRole, 0)
|
|
siteRoles := make([]database.CustomRole, 0)
|
|
orgRoles := make([]database.CustomRole, 0)
|
|
for i := 0; i < 15; i++ {
|
|
orgID := uuid.NullUUID{
|
|
UUID: orgIDs[i%len(orgIDs)],
|
|
Valid: true,
|
|
}
|
|
if i%4 == 0 {
|
|
// Some should be site wide
|
|
orgID = uuid.NullUUID{}
|
|
}
|
|
|
|
role, err := db.InsertCustomRole(ctx, database.InsertCustomRoleParams{
|
|
Name: fmt.Sprintf("role-%d", i),
|
|
OrganizationID: orgID,
|
|
})
|
|
require.NoError(t, err)
|
|
allRoles = append(allRoles, role)
|
|
if orgID.Valid {
|
|
orgRoles = append(orgRoles, role)
|
|
} else {
|
|
siteRoles = append(siteRoles, role)
|
|
}
|
|
}
|
|
|
|
// normalizedRoleName allows for the simple ElementsMatch to work properly.
|
|
normalizedRoleName := func(role database.CustomRole) string {
|
|
return role.Name + ":" + role.OrganizationID.UUID.String()
|
|
}
|
|
|
|
roleToLookup := func(role database.CustomRole) database.NameOrganizationPair {
|
|
return database.NameOrganizationPair{
|
|
Name: role.Name,
|
|
OrganizationID: role.OrganizationID.UUID,
|
|
}
|
|
}
|
|
|
|
testCases := []struct {
|
|
Name string
|
|
Params database.CustomRolesParams
|
|
Match func(role database.CustomRole) bool
|
|
}{
|
|
{
|
|
Name: "NilRoles",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: nil,
|
|
ExcludeOrgRoles: false,
|
|
OrganizationID: uuid.UUID{},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return true
|
|
},
|
|
},
|
|
{
|
|
// Empty params should return all roles
|
|
Name: "Empty",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{},
|
|
ExcludeOrgRoles: false,
|
|
OrganizationID: uuid.UUID{},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return true
|
|
},
|
|
},
|
|
{
|
|
Name: "Organization",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{},
|
|
ExcludeOrgRoles: false,
|
|
OrganizationID: orgIDs[1],
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return role.OrganizationID.UUID == orgIDs[1]
|
|
},
|
|
},
|
|
{
|
|
Name: "SpecificOrgRole",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: orgRoles[0].Name,
|
|
OrganizationID: orgRoles[0].OrganizationID.UUID,
|
|
},
|
|
},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return role.Name == orgRoles[0].Name && role.OrganizationID.UUID == orgRoles[0].OrganizationID.UUID
|
|
},
|
|
},
|
|
{
|
|
Name: "SpecificSiteRole",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: siteRoles[0].Name,
|
|
OrganizationID: siteRoles[0].OrganizationID.UUID,
|
|
},
|
|
},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return role.Name == siteRoles[0].Name && role.OrganizationID.UUID == siteRoles[0].OrganizationID.UUID
|
|
},
|
|
},
|
|
{
|
|
Name: "FewSpecificRoles",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: orgRoles[0].Name,
|
|
OrganizationID: orgRoles[0].OrganizationID.UUID,
|
|
},
|
|
{
|
|
Name: orgRoles[1].Name,
|
|
OrganizationID: orgRoles[1].OrganizationID.UUID,
|
|
},
|
|
{
|
|
Name: siteRoles[0].Name,
|
|
OrganizationID: siteRoles[0].OrganizationID.UUID,
|
|
},
|
|
},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return (role.Name == orgRoles[0].Name && role.OrganizationID.UUID == orgRoles[0].OrganizationID.UUID) ||
|
|
(role.Name == orgRoles[1].Name && role.OrganizationID.UUID == orgRoles[1].OrganizationID.UUID) ||
|
|
(role.Name == siteRoles[0].Name && role.OrganizationID.UUID == siteRoles[0].OrganizationID.UUID)
|
|
},
|
|
},
|
|
{
|
|
Name: "AllRolesByLookup",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: slice.List(allRoles, roleToLookup),
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return true
|
|
},
|
|
},
|
|
{
|
|
Name: "NotExists",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: "not-exists",
|
|
OrganizationID: uuid.New(),
|
|
},
|
|
{
|
|
Name: "not-exists",
|
|
OrganizationID: uuid.Nil,
|
|
},
|
|
},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return false
|
|
},
|
|
},
|
|
{
|
|
Name: "Mixed",
|
|
Params: database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: "not-exists",
|
|
OrganizationID: uuid.New(),
|
|
},
|
|
{
|
|
Name: "not-exists",
|
|
OrganizationID: uuid.Nil,
|
|
},
|
|
{
|
|
Name: orgRoles[0].Name,
|
|
OrganizationID: orgRoles[0].OrganizationID.UUID,
|
|
},
|
|
{
|
|
Name: siteRoles[0].Name,
|
|
},
|
|
},
|
|
},
|
|
Match: func(role database.CustomRole) bool {
|
|
return (role.Name == orgRoles[0].Name && role.OrganizationID.UUID == orgRoles[0].OrganizationID.UUID) ||
|
|
(role.Name == siteRoles[0].Name && role.OrganizationID.UUID == siteRoles[0].OrganizationID.UUID)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.Name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
found, err := db.CustomRoles(ctx, tc.Params)
|
|
require.NoError(t, err)
|
|
filtered := make([]database.CustomRole, 0)
|
|
for _, role := range allRoles {
|
|
if tc.Match(role) {
|
|
filtered = append(filtered, role)
|
|
}
|
|
}
|
|
|
|
a := slice.List(filtered, normalizedRoleName)
|
|
b := slice.List(found, normalizedRoleName)
|
|
require.Equal(t, a, b)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDeleteCustomRoleDoesNotDeleteSystemRole(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
systemRole, err := db.InsertCustomRole(ctx, database.InsertCustomRoleParams{
|
|
Name: "test-system-role",
|
|
DisplayName: "",
|
|
OrganizationID: uuid.NullUUID{
|
|
UUID: org.ID,
|
|
Valid: true,
|
|
},
|
|
SitePermissions: database.CustomRolePermissions{},
|
|
OrgPermissions: database.CustomRolePermissions{},
|
|
UserPermissions: database.CustomRolePermissions{},
|
|
MemberPermissions: database.CustomRolePermissions{},
|
|
IsSystem: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
nonSystemRole, err := db.InsertCustomRole(ctx, database.InsertCustomRoleParams{
|
|
Name: "test-custom-role",
|
|
DisplayName: "",
|
|
OrganizationID: uuid.NullUUID{
|
|
UUID: org.ID,
|
|
Valid: true,
|
|
},
|
|
SitePermissions: database.CustomRolePermissions{},
|
|
OrgPermissions: database.CustomRolePermissions{},
|
|
UserPermissions: database.CustomRolePermissions{},
|
|
MemberPermissions: database.CustomRolePermissions{},
|
|
IsSystem: false,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = db.DeleteCustomRole(ctx, database.DeleteCustomRoleParams{
|
|
Name: systemRole.Name,
|
|
OrganizationID: uuid.NullUUID{
|
|
UUID: org.ID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = db.DeleteCustomRole(ctx, database.DeleteCustomRoleParams{
|
|
Name: nonSystemRole.Name,
|
|
OrganizationID: uuid.NullUUID{
|
|
UUID: org.ID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
roles, err := db.CustomRoles(ctx, database.CustomRolesParams{
|
|
LookupRoles: []database.NameOrganizationPair{
|
|
{
|
|
Name: systemRole.Name,
|
|
OrganizationID: org.ID,
|
|
},
|
|
{
|
|
Name: nonSystemRole.Name,
|
|
OrganizationID: org.ID,
|
|
},
|
|
},
|
|
IncludeSystemRoles: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, roles, 1)
|
|
require.Equal(t, systemRole.Name, roles[0].Name)
|
|
require.True(t, roles[0].IsSystem)
|
|
}
|
|
|
|
func TestGetAuthorizationUserRolesImpliedOrgRole(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
regularUser := dbgen.User(t, db, database.User{})
|
|
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: regularUser.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: saUser.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
wantMember := rbac.RoleOrgMember() + ":" + org.ID.String()
|
|
wantSA := rbac.RoleOrgServiceAccount() + ":" + org.ID.String()
|
|
|
|
// Regular users get the implied organization-member role.
|
|
regularRoles, err := db.GetAuthorizationUserRoles(ctx, regularUser.ID)
|
|
require.NoError(t, err)
|
|
require.Contains(t, regularRoles.Roles, wantMember)
|
|
require.NotContains(t, regularRoles.Roles, wantSA)
|
|
|
|
// Service accounts get the implied organization-service-account role.
|
|
saRoles, err := db.GetAuthorizationUserRoles(ctx, saUser.ID)
|
|
require.NoError(t, err)
|
|
require.Contains(t, saRoles.Roles, wantSA)
|
|
require.NotContains(t, saRoles.Roles, wantMember)
|
|
}
|
|
|
|
func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
updated, err := db.UpdateOrganizationWorkspaceSharingSettings(ctx, database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
|
ID: org.ID,
|
|
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
|
|
UpdatedAt: dbtime.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ShareableWorkspaceOwnersNone, updated.ShareableWorkspaceOwners)
|
|
|
|
got, err := db.GetOrganizationByID(ctx, org.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.ShareableWorkspaceOwnersNone, got.ShareableWorkspaceOwners)
|
|
}
|
|
|
|
func TestDeleteWorkspaceACLsByOrganization(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("DeletesAll", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org1 := dbgen.Organization(t, db, database.Organization{})
|
|
org2 := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
owner1 := dbgen.User(t, db, database.User{})
|
|
owner2 := dbgen.User(t, db, database.User{})
|
|
sharedUser := dbgen.User(t, db, database.User{})
|
|
sharedGroup := dbgen.Group(t, db, database.Group{
|
|
OrganizationID: org1.ID,
|
|
})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org1.ID,
|
|
UserID: owner1.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org2.ID,
|
|
UserID: owner2.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org1.ID,
|
|
UserID: sharedUser.ID,
|
|
})
|
|
|
|
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OwnerID: owner1.ID,
|
|
OrganizationID: org1.ID,
|
|
UserACL: database.WorkspaceACL{
|
|
sharedUser.ID.String(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
},
|
|
GroupACL: database.WorkspaceACL{
|
|
sharedGroup.ID.String(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
},
|
|
}).Do().Workspace
|
|
|
|
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OwnerID: owner2.ID,
|
|
OrganizationID: org2.ID,
|
|
UserACL: database.WorkspaceACL{
|
|
uuid.NewString(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
},
|
|
}).Do().Workspace
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
|
|
OrganizationID: org1.ID,
|
|
ExcludeServiceAccounts: false,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, got1.UserACL)
|
|
require.Empty(t, got1.GroupACL)
|
|
|
|
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, got2.UserACL)
|
|
})
|
|
|
|
t.Run("ExcludesServiceAccounts", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
regularUser := dbgen.User(t, db, database.User{})
|
|
saUser := dbgen.User(t, db, database.User{IsServiceAccount: true})
|
|
sharedUser := dbgen.User(t, db, database.User{})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: regularUser.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: saUser.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org.ID,
|
|
UserID: sharedUser.ID,
|
|
})
|
|
|
|
regularWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OwnerID: regularUser.ID,
|
|
OrganizationID: org.ID,
|
|
UserACL: database.WorkspaceACL{
|
|
sharedUser.ID.String(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
},
|
|
}).Do().Workspace
|
|
|
|
saWS := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OwnerID: saUser.ID,
|
|
OrganizationID: org.ID,
|
|
UserACL: database.WorkspaceACL{
|
|
sharedUser.ID.String(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
},
|
|
}).Do().Workspace
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
err := db.DeleteWorkspaceACLsByOrganization(ctx, database.DeleteWorkspaceACLsByOrganizationParams{
|
|
OrganizationID: org.ID,
|
|
ExcludeServiceAccounts: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Regular user workspace ACLs should be cleared.
|
|
gotRegular, err := db.GetWorkspaceByID(ctx, regularWS.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, gotRegular.UserACL)
|
|
|
|
// Service account workspace ACLs should be preserved.
|
|
gotSA, err := db.GetWorkspaceByID(ctx, saWS.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.WorkspaceACL{
|
|
sharedUser.ID.String(): {
|
|
Permissions: []policy.Action{policy.ActionRead},
|
|
},
|
|
}, gotSA.UserACL)
|
|
})
|
|
}
|
|
|
|
func TestAuthorizedAuditLogs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var allLogs []database.AuditLog
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
db = dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
|
|
siteWideIDs := []uuid.UUID{uuid.New(), uuid.New()}
|
|
for _, id := range siteWideIDs {
|
|
allLogs = append(allLogs, dbgen.AuditLog(t, db, database.AuditLog{
|
|
ID: id,
|
|
OrganizationID: uuid.Nil,
|
|
}))
|
|
}
|
|
|
|
// This map is a simple way to insert a given number of organizations
|
|
// and audit logs for each organization.
|
|
// map[orgID][]AuditLogID
|
|
orgAuditLogs := map[uuid.UUID][]uuid.UUID{
|
|
uuid.New(): {uuid.New(), uuid.New()},
|
|
uuid.New(): {uuid.New(), uuid.New()},
|
|
}
|
|
orgIDs := make([]uuid.UUID, 0, len(orgAuditLogs))
|
|
for orgID := range orgAuditLogs {
|
|
orgIDs = append(orgIDs, orgID)
|
|
}
|
|
for orgID, ids := range orgAuditLogs {
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
for _, id := range ids {
|
|
allLogs = append(allLogs, dbgen.AuditLog(t, db, database.AuditLog{
|
|
ID: id,
|
|
OrganizationID: orgID,
|
|
}))
|
|
}
|
|
}
|
|
|
|
// Now fetch all the logs
|
|
auditorRole, err := rbac.RoleByName(rbac.RoleAuditor())
|
|
require.NoError(t, err)
|
|
|
|
memberRole, err := rbac.RoleByName(rbac.RoleMember())
|
|
require.NoError(t, err)
|
|
|
|
orgAuditorRoles := func(t *testing.T, orgID uuid.UUID) rbac.Role {
|
|
t.Helper()
|
|
|
|
role, err := rbac.RoleByName(rbac.ScopedRoleOrgAuditor(orgID))
|
|
require.NoError(t, err)
|
|
return role
|
|
}
|
|
|
|
t.Run("NoAccess", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A user who is a member of 0 organizations
|
|
memberCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "member",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{memberRole},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for audit logs
|
|
count, err := db.CountAuditLogs(memberCtx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
logs, err := db.GetAuditLogsOffset(memberCtx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
|
|
// Then: No logs returned and count is 0
|
|
require.Equal(t, int64(0), count, "count should be 0")
|
|
require.Len(t, logs, 0, "no logs should be returned")
|
|
})
|
|
|
|
t.Run("SiteWideAuditor", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A site wide auditor
|
|
siteAuditorCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "owner",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{auditorRole},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: the auditor queries for audit logs
|
|
count, err := db.CountAuditLogs(siteAuditorCtx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
logs, err := db.GetAuditLogsOffset(siteAuditorCtx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
|
|
// Then: All logs are returned and count matches
|
|
require.Equal(t, int64(len(allLogs)), count, "count should match total number of logs")
|
|
require.ElementsMatch(t, auditOnlyIDs(allLogs), auditOnlyIDs(logs), "all logs should be returned")
|
|
})
|
|
|
|
t.Run("SingleOrgAuditor", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
orgID := orgIDs[0]
|
|
// Given: An organization scoped auditor
|
|
orgAuditCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, orgID)},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The auditor queries for audit logs
|
|
count, err := db.CountAuditLogs(orgAuditCtx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
logs, err := db.GetAuditLogsOffset(orgAuditCtx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
|
|
// Then: Only the logs for the organization are returned and count matches
|
|
require.Equal(t, int64(len(orgAuditLogs[orgID])), count, "count should match organization logs")
|
|
require.ElementsMatch(t, orgAuditLogs[orgID], auditOnlyIDs(logs), "only organization logs should be returned")
|
|
})
|
|
|
|
t.Run("TwoOrgAuditors", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
first := orgIDs[0]
|
|
second := orgIDs[1]
|
|
// Given: A user who is an auditor for two organizations
|
|
multiOrgAuditCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, first), orgAuditorRoles(t, second)},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for audit logs
|
|
count, err := db.CountAuditLogs(multiOrgAuditCtx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
logs, err := db.GetAuditLogsOffset(multiOrgAuditCtx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
|
|
// Then: All logs for both organizations are returned and count matches
|
|
expectedLogs := append([]uuid.UUID{}, orgAuditLogs[first]...)
|
|
expectedLogs = append(expectedLogs, orgAuditLogs[second]...)
|
|
require.Equal(t, int64(len(expectedLogs)), count, "count should match sum of both organizations")
|
|
require.ElementsMatch(t, expectedLogs, auditOnlyIDs(logs), "logs from both organizations should be returned")
|
|
})
|
|
|
|
t.Run("ErroneousOrg", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A user who is an auditor for an organization that has 0 logs
|
|
userCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, uuid.New())},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for audit logs
|
|
count, err := db.CountAuditLogs(userCtx, database.CountAuditLogsParams{})
|
|
require.NoError(t, err)
|
|
logs, err := db.GetAuditLogsOffset(userCtx, database.GetAuditLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
|
|
// Then: No logs are returned and count is 0
|
|
require.Equal(t, int64(0), count, "count should be 0")
|
|
require.Len(t, logs, 0, "no logs should be returned")
|
|
})
|
|
}
|
|
|
|
func auditOnlyIDs[T database.AuditLog | database.GetAuditLogsOffsetRow](logs []T) []uuid.UUID {
|
|
ids := make([]uuid.UUID, 0, len(logs))
|
|
for _, log := range logs {
|
|
switch log := any(log).(type) {
|
|
case database.AuditLog:
|
|
ids = append(ids, log.ID)
|
|
case database.GetAuditLogsOffsetRow:
|
|
ids = append(ids, log.AuditLog.ID)
|
|
default:
|
|
panic("unreachable")
|
|
}
|
|
}
|
|
return ids
|
|
}
|
|
|
|
func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var allLogs []database.ConnectionLog
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
|
|
authDb := dbauthz.New(db, authz, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
orgB := dbfake.Organization(t, db).Do()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: orgA.Org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
|
|
wsID := uuid.New()
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
WorkspaceTransition: database.WorkspaceTransitionStart,
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
CreateWorkspace: true,
|
|
WorkspaceID: wsID,
|
|
})
|
|
|
|
// This map is a simple way to insert a given number of organizations
|
|
// and audit logs for each organization.
|
|
// map[orgID][]ConnectionLogID
|
|
orgConnectionLogs := map[uuid.UUID][]uuid.UUID{
|
|
orgA.Org.ID: {uuid.New(), uuid.New()},
|
|
orgB.Org.ID: {uuid.New(), uuid.New()},
|
|
}
|
|
orgIDs := make([]uuid.UUID, 0, len(orgConnectionLogs))
|
|
for orgID := range orgConnectionLogs {
|
|
orgIDs = append(orgIDs, orgID)
|
|
}
|
|
for orgID, ids := range orgConnectionLogs {
|
|
for _, id := range ids {
|
|
allLogs = append(allLogs, dbgen.ConnectionLog(t, authDb, database.UpsertConnectionLogParams{
|
|
WorkspaceID: wsID,
|
|
WorkspaceOwnerID: user.ID,
|
|
ID: id,
|
|
OrganizationID: orgID,
|
|
}))
|
|
}
|
|
}
|
|
|
|
// Now fetch all the logs
|
|
auditorRole, err := rbac.RoleByName(rbac.RoleAuditor())
|
|
require.NoError(t, err)
|
|
|
|
memberRole, err := rbac.RoleByName(rbac.RoleMember())
|
|
require.NoError(t, err)
|
|
|
|
orgAuditorRoles := func(t *testing.T, orgID uuid.UUID) rbac.Role {
|
|
t.Helper()
|
|
|
|
role, err := rbac.RoleByName(rbac.ScopedRoleOrgAuditor(orgID))
|
|
require.NoError(t, err)
|
|
return role
|
|
}
|
|
|
|
t.Run("NoAccess", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A user who is a member of 0 organizations
|
|
memberCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "member",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{memberRole},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for connection logs
|
|
logs, err := authDb.GetConnectionLogsOffset(memberCtx, database.GetConnectionLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// Then: No logs returned
|
|
require.Len(t, logs, 0, "no logs should be returned")
|
|
// And: The count matches the number of logs returned
|
|
count, err := authDb.CountConnectionLogs(memberCtx, database.CountConnectionLogsParams{})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, len(logs), count)
|
|
})
|
|
|
|
t.Run("SiteWideAuditor", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A site wide auditor
|
|
siteAuditorCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "owner",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{auditorRole},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: the auditor queries for connection logs
|
|
logs, err := authDb.GetConnectionLogsOffset(siteAuditorCtx, database.GetConnectionLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// Then: All logs are returned
|
|
require.ElementsMatch(t, connectionOnlyIDs(allLogs), connectionOnlyIDs(logs))
|
|
// And: The count matches the number of logs returned
|
|
count, err := authDb.CountConnectionLogs(siteAuditorCtx, database.CountConnectionLogsParams{})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, len(logs), count)
|
|
})
|
|
|
|
t.Run("SingleOrgAuditor", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
orgID := orgIDs[0]
|
|
// Given: An organization scoped auditor
|
|
orgAuditCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, orgID)},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The auditor queries for connection logs
|
|
logs, err := authDb.GetConnectionLogsOffset(orgAuditCtx, database.GetConnectionLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// Then: Only the logs for the organization are returned
|
|
require.ElementsMatch(t, orgConnectionLogs[orgID], connectionOnlyIDs(logs))
|
|
// And: The count matches the number of logs returned
|
|
count, err := authDb.CountConnectionLogs(orgAuditCtx, database.CountConnectionLogsParams{})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, len(logs), count)
|
|
})
|
|
|
|
t.Run("TwoOrgAuditors", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
first := orgIDs[0]
|
|
second := orgIDs[1]
|
|
// Given: A user who is an auditor for two organizations
|
|
multiOrgAuditCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, first), orgAuditorRoles(t, second)},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for connection logs
|
|
logs, err := authDb.GetConnectionLogsOffset(multiOrgAuditCtx, database.GetConnectionLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// Then: All logs for both organizations are returned
|
|
require.ElementsMatch(t, append(orgConnectionLogs[first], orgConnectionLogs[second]...), connectionOnlyIDs(logs))
|
|
// And: The count matches the number of logs returned
|
|
count, err := authDb.CountConnectionLogs(multiOrgAuditCtx, database.CountConnectionLogsParams{})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, len(logs), count)
|
|
})
|
|
|
|
t.Run("ErroneousOrg", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A user who is an auditor for an organization that has 0 logs
|
|
userCtx := dbauthz.As(ctx, rbac.Subject{
|
|
FriendlyName: "org-auditor",
|
|
ID: uuid.NewString(),
|
|
Roles: rbac.Roles{orgAuditorRoles(t, uuid.New())},
|
|
Scope: rbac.ScopeAll,
|
|
})
|
|
|
|
// When: The user queries for audit logs
|
|
logs, err := authDb.GetConnectionLogsOffset(userCtx, database.GetConnectionLogsOffsetParams{})
|
|
require.NoError(t, err)
|
|
// Then: No logs are returned
|
|
require.Len(t, logs, 0, "no logs should be returned")
|
|
// And: The count matches the number of logs returned
|
|
count, err := authDb.CountConnectionLogs(userCtx, database.CountConnectionLogsParams{})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, len(logs), count)
|
|
})
|
|
}
|
|
|
|
func TestCountConnectionLogs(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
userA := dbgen.User(t, db, database.User{})
|
|
tplA := dbgen.Template(t, db, database.Template{OrganizationID: orgA.Org.ID, CreatedBy: userA.ID})
|
|
wsA := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userA.ID, OrganizationID: orgA.Org.ID, TemplateID: tplA.ID})
|
|
|
|
orgB := dbfake.Organization(t, db).Do()
|
|
userB := dbgen.User(t, db, database.User{})
|
|
tplB := dbgen.Template(t, db, database.Template{OrganizationID: orgB.Org.ID, CreatedBy: userB.ID})
|
|
wsB := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userB.ID, OrganizationID: orgB.Org.ID, TemplateID: tplB.ID})
|
|
|
|
// Create logs for two different orgs.
|
|
for i := 0; i < 20; i++ {
|
|
dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
OrganizationID: wsA.OrganizationID,
|
|
WorkspaceOwnerID: wsA.OwnerID,
|
|
WorkspaceID: wsA.ID,
|
|
Type: database.ConnectionTypeSsh,
|
|
})
|
|
}
|
|
for i := 0; i < 10; i++ {
|
|
dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
OrganizationID: wsB.OrganizationID,
|
|
WorkspaceOwnerID: wsB.OwnerID,
|
|
WorkspaceID: wsB.ID,
|
|
Type: database.ConnectionTypeSsh,
|
|
})
|
|
}
|
|
|
|
// Count with a filter for orgA.
|
|
countParams := database.CountConnectionLogsParams{
|
|
OrganizationID: orgA.Org.ID,
|
|
}
|
|
totalCount, err := db.CountConnectionLogs(ctx, countParams)
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(20), totalCount)
|
|
|
|
// Get a paginated result for the same filter.
|
|
getParams := database.GetConnectionLogsOffsetParams{
|
|
OrganizationID: orgA.Org.ID,
|
|
LimitOpt: 5,
|
|
OffsetOpt: 10,
|
|
}
|
|
logs, err := db.GetConnectionLogsOffset(ctx, getParams)
|
|
require.NoError(t, err)
|
|
require.Len(t, logs, 5)
|
|
|
|
// The count with the filter should remain the same, independent of pagination.
|
|
countAfterGet, err := db.CountConnectionLogs(ctx, countParams)
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(20), countAfterGet)
|
|
}
|
|
|
|
func TestConnectionLogsOffsetFilters(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
orgB := dbfake.Organization(t, db).Do()
|
|
|
|
user1 := dbgen.User(t, db, database.User{
|
|
Username: "user1",
|
|
Email: "user1@test.com",
|
|
})
|
|
user2 := dbgen.User(t, db, database.User{
|
|
Username: "user2",
|
|
Email: "user2@test.com",
|
|
})
|
|
user3 := dbgen.User(t, db, database.User{
|
|
Username: "user3",
|
|
Email: "user3@test.com",
|
|
})
|
|
|
|
ws1Tpl := dbgen.Template(t, db, database.Template{OrganizationID: orgA.Org.ID, CreatedBy: user1.ID})
|
|
ws1 := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user1.ID,
|
|
OrganizationID: orgA.Org.ID,
|
|
TemplateID: ws1Tpl.ID,
|
|
})
|
|
ws2Tpl := dbgen.Template(t, db, database.Template{OrganizationID: orgB.Org.ID, CreatedBy: user2.ID})
|
|
ws2 := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user2.ID,
|
|
OrganizationID: orgB.Org.ID,
|
|
TemplateID: ws2Tpl.ID,
|
|
})
|
|
|
|
now := dbtime.Now()
|
|
log1ConnID := uuid.New()
|
|
log1 := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-4 * time.Hour),
|
|
OrganizationID: ws1.OrganizationID,
|
|
WorkspaceOwnerID: ws1.OwnerID,
|
|
WorkspaceID: ws1.ID,
|
|
WorkspaceName: ws1.Name,
|
|
Type: database.ConnectionTypeWorkspaceApp,
|
|
ConnectionStatus: database.ConnectionStatusConnected,
|
|
UserID: uuid.NullUUID{UUID: user1.ID, Valid: true},
|
|
UserAgent: sql.NullString{String: "Mozilla/5.0", Valid: true},
|
|
SlugOrPort: sql.NullString{String: "code-server", Valid: true},
|
|
ConnectionID: uuid.NullUUID{UUID: log1ConnID, Valid: true},
|
|
})
|
|
|
|
log2ConnID := uuid.New()
|
|
log2 := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-3 * time.Hour),
|
|
OrganizationID: ws1.OrganizationID,
|
|
WorkspaceOwnerID: ws1.OwnerID,
|
|
WorkspaceID: ws1.ID,
|
|
WorkspaceName: ws1.Name,
|
|
Type: database.ConnectionTypeVscode,
|
|
ConnectionStatus: database.ConnectionStatusConnected,
|
|
ConnectionID: uuid.NullUUID{UUID: log2ConnID, Valid: true},
|
|
})
|
|
|
|
// Mark log2 as disconnected
|
|
log2 = dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-2 * time.Hour),
|
|
ConnectionID: log2.ConnectionID,
|
|
WorkspaceID: ws1.ID,
|
|
WorkspaceOwnerID: ws1.OwnerID,
|
|
AgentName: log2.AgentName,
|
|
ConnectionStatus: database.ConnectionStatusDisconnected,
|
|
|
|
OrganizationID: log2.OrganizationID,
|
|
})
|
|
|
|
log3ConnID := uuid.New()
|
|
log3 := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-2 * time.Hour),
|
|
OrganizationID: ws2.OrganizationID,
|
|
WorkspaceOwnerID: ws2.OwnerID,
|
|
WorkspaceID: ws2.ID,
|
|
WorkspaceName: ws2.Name,
|
|
Type: database.ConnectionTypeSsh,
|
|
ConnectionStatus: database.ConnectionStatusConnected,
|
|
UserID: uuid.NullUUID{UUID: user2.ID, Valid: true},
|
|
ConnectionID: uuid.NullUUID{UUID: log3ConnID, Valid: true},
|
|
})
|
|
|
|
// Mark log3 as disconnected
|
|
log3 = dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-1 * time.Hour),
|
|
ConnectionID: log3.ConnectionID,
|
|
WorkspaceOwnerID: log3.WorkspaceOwnerID,
|
|
WorkspaceID: ws2.ID,
|
|
AgentName: log3.AgentName,
|
|
ConnectionStatus: database.ConnectionStatusDisconnected,
|
|
|
|
OrganizationID: log3.OrganizationID,
|
|
})
|
|
|
|
log4 := dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
|
|
Time: now.Add(-1 * time.Hour),
|
|
OrganizationID: ws2.OrganizationID,
|
|
WorkspaceOwnerID: ws2.OwnerID,
|
|
WorkspaceID: ws2.ID,
|
|
WorkspaceName: ws2.Name,
|
|
Type: database.ConnectionTypeVscode,
|
|
ConnectionStatus: database.ConnectionStatusConnected,
|
|
UserID: uuid.NullUUID{UUID: user3.ID, Valid: true},
|
|
})
|
|
|
|
testCases := []struct {
|
|
name string
|
|
params database.GetConnectionLogsOffsetParams
|
|
expectedLogIDs []uuid.UUID
|
|
}{
|
|
{
|
|
name: "NoFilter",
|
|
params: database.GetConnectionLogsOffsetParams{},
|
|
expectedLogIDs: []uuid.UUID{
|
|
log1.ID, log2.ID, log3.ID, log4.ID,
|
|
},
|
|
},
|
|
{
|
|
name: "OrganizationID",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
OrganizationID: orgB.Org.ID,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log3.ID, log4.ID},
|
|
},
|
|
{
|
|
name: "WorkspaceOwner",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
WorkspaceOwner: user1.Username,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID, log2.ID},
|
|
},
|
|
{
|
|
name: "WorkspaceOwnerID",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
WorkspaceOwnerID: user1.ID,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID, log2.ID},
|
|
},
|
|
{
|
|
name: "WorkspaceOwnerEmail",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
WorkspaceOwnerEmail: user2.Email,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log3.ID, log4.ID},
|
|
},
|
|
{
|
|
name: "Type",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
Type: string(database.ConnectionTypeVscode),
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log2.ID, log4.ID},
|
|
},
|
|
{
|
|
name: "UserID",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
UserID: user1.ID,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID},
|
|
},
|
|
{
|
|
name: "Username",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
Username: user1.Username,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID},
|
|
},
|
|
{
|
|
name: "UserEmail",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
UserEmail: user3.Email,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log4.ID},
|
|
},
|
|
{
|
|
name: "ConnectedAfter",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
ConnectedAfter: now.Add(-90 * time.Minute), // 1.5 hours ago
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log4.ID},
|
|
},
|
|
{
|
|
name: "ConnectedBefore",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
ConnectedBefore: now.Add(-150 * time.Minute),
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID, log2.ID},
|
|
},
|
|
{
|
|
name: "WorkspaceID",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
WorkspaceID: ws2.ID,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log3.ID, log4.ID},
|
|
},
|
|
{
|
|
name: "ConnectionID",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
ConnectionID: log1.ConnectionID.UUID,
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log1.ID},
|
|
},
|
|
{
|
|
name: "StatusOngoing",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
Status: string(codersdk.ConnectionLogStatusOngoing),
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log4.ID},
|
|
},
|
|
{
|
|
name: "StatusCompleted",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
Status: string(codersdk.ConnectionLogStatusCompleted),
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log2.ID, log3.ID},
|
|
},
|
|
{
|
|
name: "OrganizationAndTypeAndStatus",
|
|
params: database.GetConnectionLogsOffsetParams{
|
|
OrganizationID: orgA.Org.ID,
|
|
Type: string(database.ConnectionTypeVscode),
|
|
Status: string(codersdk.ConnectionLogStatusCompleted),
|
|
},
|
|
expectedLogIDs: []uuid.UUID{log2.ID},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
logs, err := db.GetConnectionLogsOffset(ctx, tc.params)
|
|
require.NoError(t, err)
|
|
count, err := db.CountConnectionLogs(ctx, database.CountConnectionLogsParams{
|
|
OrganizationID: tc.params.OrganizationID,
|
|
WorkspaceOwner: tc.params.WorkspaceOwner,
|
|
Type: tc.params.Type,
|
|
UserID: tc.params.UserID,
|
|
Username: tc.params.Username,
|
|
UserEmail: tc.params.UserEmail,
|
|
ConnectedAfter: tc.params.ConnectedAfter,
|
|
ConnectedBefore: tc.params.ConnectedBefore,
|
|
WorkspaceID: tc.params.WorkspaceID,
|
|
ConnectionID: tc.params.ConnectionID,
|
|
Status: tc.params.Status,
|
|
WorkspaceOwnerID: tc.params.WorkspaceOwnerID,
|
|
WorkspaceOwnerEmail: tc.params.WorkspaceOwnerEmail,
|
|
})
|
|
require.NoError(t, err)
|
|
require.ElementsMatch(t, tc.expectedLogIDs, connectionOnlyIDs(logs))
|
|
require.Equal(t, len(tc.expectedLogIDs), int(count), "CountConnectionLogs should match the number of returned logs (no offset or limit)")
|
|
})
|
|
}
|
|
}
|
|
|
|
func connectionOnlyIDs[T database.ConnectionLog | database.GetConnectionLogsOffsetRow](logs []T) []uuid.UUID {
|
|
ids := make([]uuid.UUID, 0, len(logs))
|
|
for _, log := range logs {
|
|
switch log := any(log).(type) {
|
|
case database.ConnectionLog:
|
|
ids = append(ids, log.ID)
|
|
case database.GetConnectionLogsOffsetRow:
|
|
ids = append(ids, log.ConnectionLog.ID)
|
|
default:
|
|
panic("unreachable")
|
|
}
|
|
}
|
|
return ids
|
|
}
|
|
|
|
func TestBatchUpsertConnectionLogs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
createWorkspace := func(t *testing.T, db database.Store) database.WorkspaceTable {
|
|
t.Helper()
|
|
u := dbgen.User(t, db, database.User{})
|
|
o := dbgen.Organization(t, db, database.Organization{})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: o.ID,
|
|
CreatedBy: u.ID,
|
|
})
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
ID: uuid.New(),
|
|
OwnerID: u.ID,
|
|
OrganizationID: o.ID,
|
|
AutomaticUpdates: database.AutomaticUpdatesNever,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
}
|
|
|
|
// zeroTime is the sentinel value that the SQL treats as "no
|
|
// connect/disconnect time provided".
|
|
zeroTime := time.Time{}
|
|
|
|
defaultIP := pqtype.Inet{
|
|
IPNet: net.IPNet{
|
|
IP: net.IPv4(127, 0, 0, 1),
|
|
Mask: net.IPv4Mask(255, 255, 255, 255),
|
|
},
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("SingleConnect", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
connectTime := dbtime.Now()
|
|
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{connectTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{false},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 1)
|
|
require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime))
|
|
require.False(t, rows[0].ConnectionLog.DisconnectTime.Valid,
|
|
"disconnect_time should be NULL for a connect-only event")
|
|
})
|
|
|
|
t.Run("ConnectThenDisconnect", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
connectTime := dbtime.Now()
|
|
|
|
// Insert connect.
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{connectTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{false},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert disconnect for same connection.
|
|
disconnectTime := connectTime.Add(time.Second)
|
|
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{zeroTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{1},
|
|
CodeValid: []bool{true},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{"test disconnect"},
|
|
DisconnectTime: []time.Time{disconnectTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 1)
|
|
row := rows[0].ConnectionLog
|
|
require.True(t, connectTime.Equal(row.ConnectTime))
|
|
require.True(t, row.DisconnectTime.Valid)
|
|
require.True(t, disconnectTime.Equal(row.DisconnectTime.Time))
|
|
require.Equal(t, "test disconnect", row.DisconnectReason.String)
|
|
require.Equal(t, int32(1), row.Code.Int32)
|
|
})
|
|
|
|
t.Run("DuplicateConnectIsNoOp", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
connectTime := dbtime.Now()
|
|
|
|
mkParams := func(ct time.Time, ip pqtype.Inet) database.BatchUpsertConnectionLogsParams {
|
|
return database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{ct},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{false},
|
|
Ip: []pqtype.Inet{ip},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
}
|
|
}
|
|
|
|
err := db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime, defaultIP))
|
|
require.NoError(t, err)
|
|
|
|
rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows1, 1)
|
|
|
|
// Second connect with later time and different IP.
|
|
otherIP := pqtype.Inet{
|
|
IPNet: net.IPNet{
|
|
IP: net.IPv4(10, 0, 0, 1),
|
|
Mask: net.IPv4Mask(255, 255, 255, 255),
|
|
},
|
|
Valid: true,
|
|
}
|
|
err = db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime.Add(time.Second), otherIP))
|
|
require.NoError(t, err)
|
|
|
|
rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows2, 1)
|
|
|
|
// The LEAST logic should pick the earlier connect_time; IP and
|
|
// other fields are not updated on conflict.
|
|
require.True(t, connectTime.Equal(rows2[0].ConnectionLog.ConnectTime),
|
|
"connect_time should remain the original (earlier) value")
|
|
})
|
|
|
|
t.Run("OrderIndependentConnectTime", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
disconnectTime := dbtime.Now()
|
|
connectTime := disconnectTime.Add(-5 * time.Second)
|
|
|
|
// Disconnect arrives first.
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{disconnectTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{true},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{"bye"},
|
|
DisconnectTime: []time.Time{disconnectTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Connect arrives second with the real (earlier) connect_time.
|
|
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{connectTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{false},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 1)
|
|
require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime),
|
|
"LEAST should pick the earlier connect_time")
|
|
})
|
|
|
|
t.Run("DisconnectFieldsAreWriteOnce", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
disconnectTime := dbtime.Now()
|
|
|
|
mkDisconnect := func(reason string, code int32) database.BatchUpsertConnectionLogsParams {
|
|
return database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{disconnectTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{code},
|
|
CodeValid: []bool{true},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{reason},
|
|
DisconnectTime: []time.Time{disconnectTime},
|
|
}
|
|
}
|
|
|
|
err := db.BatchUpsertConnectionLogs(ctx, mkDisconnect("first reason", 1))
|
|
require.NoError(t, err)
|
|
|
|
// Second disconnect with different reason and code.
|
|
err = db.BatchUpsertConnectionLogs(ctx, mkDisconnect("second reason", 2))
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 1)
|
|
row := rows[0].ConnectionLog
|
|
require.Equal(t, "first reason", row.DisconnectReason.String,
|
|
"disconnect_reason should not be overwritten")
|
|
require.Equal(t, int32(1), row.Code.Int32,
|
|
"code should not be overwritten")
|
|
})
|
|
|
|
t.Run("ConnectAfterDisconnectIsNoOp", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
disconnectTime := dbtime.Now()
|
|
|
|
// Insert disconnect first.
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{disconnectTime},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{42},
|
|
CodeValid: []bool{true},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{"server shutdown"},
|
|
DisconnectTime: []time.Time{disconnectTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows1, 1)
|
|
require.True(t, rows1[0].ConnectionLog.DisconnectTime.Valid)
|
|
require.Equal(t, "server shutdown", rows1[0].ConnectionLog.DisconnectReason.String)
|
|
require.Equal(t, int32(42), rows1[0].ConnectionLog.Code.Int32)
|
|
|
|
// Insert connect for same connection_id.
|
|
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{disconnectTime.Add(time.Second)},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{false},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows2, 1)
|
|
row := rows2[0].ConnectionLog
|
|
require.True(t, row.DisconnectTime.Valid,
|
|
"disconnect_time should not be cleared by a later connect")
|
|
require.Equal(t, "server shutdown", row.DisconnectReason.String,
|
|
"disconnect_reason should not be cleared")
|
|
require.Equal(t, int32(42), row.Code.Int32,
|
|
"code should not be cleared")
|
|
})
|
|
|
|
t.Run("CodeZeroPreserved", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
now := dbtime.Now()
|
|
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{now},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{0},
|
|
CodeValid: []bool{true},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{"normal"},
|
|
DisconnectTime: []time.Time{now},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 1)
|
|
require.True(t, rows[0].ConnectionLog.Code.Valid, "code should be non-NULL")
|
|
require.Equal(t, int32(0), rows[0].ConnectionLog.Code.Int32,
|
|
"code=0 should be preserved, not treated as NULL")
|
|
})
|
|
|
|
t.Run("CodeNullWhenInvalid", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
connID := uuid.New()
|
|
now := dbtime.Now()
|
|
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{now},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{99},
|
|
CodeValid: []bool{false},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{""},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{""},
|
|
ConnectionID: []uuid.UUID{connID},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 1)
|
|
require.False(t, rows[0].ConnectionLog.Code.Valid,
|
|
"code should be NULL when code_valid is false")
|
|
})
|
|
|
|
t.Run("NullConnectionIDEvents", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
now := dbtime.Now()
|
|
|
|
// Insert two web events with NULL connection_id (uuid.Nil →
|
|
// NULL via NULLIF) for the same workspace/agent.
|
|
for i := range 2 {
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: []uuid.UUID{uuid.New()},
|
|
ConnectTime: []time.Time{now.Add(time.Duration(i) * time.Second)},
|
|
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
|
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
|
WorkspaceID: []uuid.UUID{ws.ID},
|
|
WorkspaceName: []string{ws.Name},
|
|
AgentName: []string{"agent"},
|
|
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
|
Code: []int32{200},
|
|
CodeValid: []bool{true},
|
|
Ip: []pqtype.Inet{defaultIP},
|
|
UserAgent: []string{"Mozilla/5.0"},
|
|
UserID: []uuid.UUID{uuid.Nil},
|
|
SlugOrPort: []string{"web-terminal"},
|
|
ConnectionID: []uuid.UUID{uuid.Nil},
|
|
DisconnectReason: []string{""},
|
|
DisconnectTime: []time.Time{zeroTime},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, 2,
|
|
"NULL connection_id rows should not conflict with each other")
|
|
})
|
|
|
|
t.Run("MultipleIndependentConnections", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
ws := createWorkspace(t, db)
|
|
now := dbtime.Now()
|
|
|
|
n := 5
|
|
ids := make([]uuid.UUID, n)
|
|
connectTimes := make([]time.Time, n)
|
|
orgIDs := make([]uuid.UUID, n)
|
|
ownerIDs := make([]uuid.UUID, n)
|
|
wsIDs := make([]uuid.UUID, n)
|
|
wsNames := make([]string, n)
|
|
agentNames := make([]string, n)
|
|
types := make([]database.ConnectionType, n)
|
|
codes := make([]int32, n)
|
|
codeValids := make([]bool, n)
|
|
ips := make([]pqtype.Inet, n)
|
|
userAgents := make([]string, n)
|
|
userIDs := make([]uuid.UUID, n)
|
|
slugOrPorts := make([]string, n)
|
|
connIDs := make([]uuid.UUID, n)
|
|
disconnectReasons := make([]string, n)
|
|
disconnectTimes := make([]time.Time, n)
|
|
|
|
for i := range n {
|
|
ids[i] = uuid.New()
|
|
connectTimes[i] = now.Add(time.Duration(i) * time.Second)
|
|
orgIDs[i] = ws.OrganizationID
|
|
ownerIDs[i] = ws.OwnerID
|
|
wsIDs[i] = ws.ID
|
|
wsNames[i] = ws.Name
|
|
agentNames[i] = "agent"
|
|
types[i] = database.ConnectionTypeSsh
|
|
codes[i] = 0
|
|
codeValids[i] = false
|
|
ips[i] = defaultIP
|
|
userAgents[i] = ""
|
|
userIDs[i] = uuid.Nil
|
|
slugOrPorts[i] = ""
|
|
connIDs[i] = uuid.New()
|
|
disconnectReasons[i] = ""
|
|
disconnectTimes[i] = zeroTime
|
|
}
|
|
|
|
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
|
ID: ids,
|
|
ConnectTime: connectTimes,
|
|
OrganizationID: orgIDs,
|
|
WorkspaceOwnerID: ownerIDs,
|
|
WorkspaceID: wsIDs,
|
|
WorkspaceName: wsNames,
|
|
AgentName: agentNames,
|
|
Type: types,
|
|
Code: codes,
|
|
CodeValid: codeValids,
|
|
Ip: ips,
|
|
UserAgent: userAgents,
|
|
UserID: userIDs,
|
|
SlugOrPort: slugOrPorts,
|
|
ConnectionID: connIDs,
|
|
DisconnectReason: disconnectReasons,
|
|
DisconnectTime: disconnectTimes,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
|
require.NoError(t, err)
|
|
require.Len(t, rows, n, "each unique connection_id should produce its own row")
|
|
})
|
|
}
|
|
|
|
// tvArgs configures createTemplateVersion. Apart from Status, every field only
// has an effect when CreateWorkspace is true.
type tvArgs struct {
	// Status is applied (via setJobStatus) to the template version import job
	// and to every workspace build job that gets created.
	Status database.ProvisionerJobStatus
	// CreateWorkspace is true if we should create a workspace for the template version
	CreateWorkspace bool
	// WorkspaceID is used as the ID of the created workspace. NOTE(review):
	// presumably the zero value lets dbgen pick a random ID — confirm in dbgen.
	WorkspaceID uuid.UUID
	// CreateAgent adds one workspace agent to the resource of the latest build.
	CreateAgent bool
	// WorkspaceTransition is used for every build; empty means
	// database.WorkspaceTransitionStart.
	WorkspaceTransition database.WorkspaceTransition
	// ExtraAgents adds this many more agents to the latest build's resource,
	// independently of CreateAgent.
	ExtraAgents int
	// ExtraBuilds adds this many builds after the initial one (build numbers
	// 2, 3, ...), each with its own provisioner job and resource.
	ExtraBuilds int
}
|
|
|
|
// createTemplateVersion is a helper function to create a version with its dependencies.
|
|
func createTemplateVersion(t testing.TB, db database.Store, tpl database.Template, args tvArgs) database.TemplateVersion {
|
|
t.Helper()
|
|
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{
|
|
UUID: tpl.ID,
|
|
Valid: true,
|
|
},
|
|
OrganizationID: tpl.OrganizationID,
|
|
CreatedAt: dbtime.Now(),
|
|
UpdatedAt: dbtime.Now(),
|
|
CreatedBy: tpl.CreatedBy,
|
|
})
|
|
|
|
latestJob := database.ProvisionerJob{
|
|
ID: version.JobID,
|
|
Error: sql.NullString{},
|
|
OrganizationID: tpl.OrganizationID,
|
|
InitiatorID: tpl.CreatedBy,
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
}
|
|
setJobStatus(t, args.Status, &latestJob)
|
|
dbgen.ProvisionerJob(t, db, nil, latestJob)
|
|
if args.CreateWorkspace {
|
|
wrk := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
ID: args.WorkspaceID,
|
|
CreatedAt: time.Time{},
|
|
UpdatedAt: time.Time{},
|
|
OwnerID: tpl.CreatedBy,
|
|
OrganizationID: tpl.OrganizationID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
trans := database.WorkspaceTransitionStart
|
|
if args.WorkspaceTransition != "" {
|
|
trans = args.WorkspaceTransition
|
|
}
|
|
latestJob = database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: tpl.CreatedBy,
|
|
OrganizationID: tpl.OrganizationID,
|
|
}
|
|
setJobStatus(t, args.Status, &latestJob)
|
|
latestJob = dbgen.ProvisionerJob(t, db, nil, latestJob)
|
|
latestResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: latestJob.ID,
|
|
})
|
|
dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: wrk.ID,
|
|
TemplateVersionID: version.ID,
|
|
BuildNumber: 1,
|
|
Transition: trans,
|
|
InitiatorID: tpl.CreatedBy,
|
|
JobID: latestJob.ID,
|
|
})
|
|
for i := 0; i < args.ExtraBuilds; i++ {
|
|
latestJob = database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: tpl.CreatedBy,
|
|
OrganizationID: tpl.OrganizationID,
|
|
}
|
|
setJobStatus(t, args.Status, &latestJob)
|
|
latestJob = dbgen.ProvisionerJob(t, db, nil, latestJob)
|
|
latestResource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: latestJob.ID,
|
|
})
|
|
dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: wrk.ID,
|
|
TemplateVersionID: version.ID,
|
|
// #nosec G115 - Safe conversion as build number is expected to be within int32 range
|
|
BuildNumber: int32(i) + 2,
|
|
Transition: trans,
|
|
InitiatorID: tpl.CreatedBy,
|
|
JobID: latestJob.ID,
|
|
})
|
|
}
|
|
|
|
if args.CreateAgent {
|
|
dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: latestResource.ID,
|
|
})
|
|
}
|
|
for i := 0; i < args.ExtraAgents; i++ {
|
|
dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: latestResource.ID,
|
|
})
|
|
}
|
|
}
|
|
return version
|
|
}
|
|
|
|
func setJobStatus(t testing.TB, status database.ProvisionerJobStatus, j *database.ProvisionerJob) {
|
|
t.Helper()
|
|
|
|
earlier := sql.NullTime{
|
|
Time: dbtime.Now().Add(time.Second * -30),
|
|
Valid: true,
|
|
}
|
|
now := sql.NullTime{
|
|
Time: dbtime.Now(),
|
|
Valid: true,
|
|
}
|
|
switch status {
|
|
case database.ProvisionerJobStatusRunning:
|
|
j.StartedAt = earlier
|
|
case database.ProvisionerJobStatusPending:
|
|
case database.ProvisionerJobStatusFailed:
|
|
j.StartedAt = earlier
|
|
j.CompletedAt = now
|
|
j.Error = sql.NullString{
|
|
String: "failed",
|
|
Valid: true,
|
|
}
|
|
j.ErrorCode = sql.NullString{
|
|
String: "failed",
|
|
Valid: true,
|
|
}
|
|
case database.ProvisionerJobStatusSucceeded:
|
|
j.StartedAt = earlier
|
|
j.CompletedAt = now
|
|
default:
|
|
t.Fatalf("invalid status: %s", status)
|
|
}
|
|
}
|
|
|
|
func TestArchiveVersions(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
t.Run("ArchiveFailedVersions", func(t *testing.T) {
|
|
t.Parallel()
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
// Create some versions
|
|
failed := createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusFailed,
|
|
CreateWorkspace: false,
|
|
})
|
|
unused := createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
CreateWorkspace: false,
|
|
})
|
|
createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
CreateWorkspace: true,
|
|
})
|
|
deleted := createTemplateVersion(t, db, tpl, tvArgs{
|
|
Status: database.ProvisionerJobStatusSucceeded,
|
|
CreateWorkspace: true,
|
|
WorkspaceTransition: database.WorkspaceTransitionDelete,
|
|
})
|
|
|
|
// Now archive failed versions
|
|
archived, err := db.ArchiveUnusedTemplateVersions(ctx, database.ArchiveUnusedTemplateVersionsParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
TemplateID: tpl.ID,
|
|
// All versions
|
|
TemplateVersionID: uuid.Nil,
|
|
JobStatus: database.NullProvisionerJobStatus{
|
|
ProvisionerJobStatus: database.ProvisionerJobStatusFailed,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err, "archive failed versions")
|
|
require.Len(t, archived, 1, "should only archive one version")
|
|
require.Equal(t, failed.ID, archived[0], "should archive failed version")
|
|
|
|
// Archive all unused versions
|
|
archived, err = db.ArchiveUnusedTemplateVersions(ctx, database.ArchiveUnusedTemplateVersionsParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
TemplateID: tpl.ID,
|
|
// All versions
|
|
TemplateVersionID: uuid.Nil,
|
|
})
|
|
require.NoError(t, err, "archive failed versions")
|
|
require.Len(t, archived, 2)
|
|
require.ElementsMatch(t, []uuid.UUID{deleted.ID, unused.ID}, archived, "should archive unused versions")
|
|
})
|
|
}
|
|
|
|
func TestExpectOne(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
t.Run("ErrNoRows", func(t *testing.T) {
|
|
t.Parallel()
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
|
|
_, err = database.ExpectOne(db.GetUsers(ctx, database.GetUsersParams{}))
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
})
|
|
|
|
t.Run("TooMany", func(t *testing.T) {
|
|
t.Parallel()
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := context.Background()
|
|
|
|
// Create 2 organizations so the query returns >1
|
|
dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.Organization(t, db, database.Organization{})
|
|
|
|
// Organizations is an easy table without foreign key dependencies
|
|
_, err = database.ExpectOne(db.GetOrganizations(ctx, database.GetOrganizationsParams{}))
|
|
require.ErrorContains(t, err, "too many rows returned")
|
|
})
|
|
}
|
|
|
|
func TestGetProvisionerJobsByIDsWithQueuePosition(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
jobTags []database.StringMap
|
|
daemonTags []database.StringMap
|
|
queueSizes []int64
|
|
queuePositions []int64
|
|
// GetProvisionerJobsByIDsWithQueuePosition takes jobIDs as a parameter.
|
|
// If skipJobIDs is empty, all jobs are passed to the function; otherwise, the specified jobs are skipped.
|
|
// NOTE: Skipping job IDs means they will be excluded from the result,
|
|
// but this should not affect the queue position or queue size of other jobs.
|
|
skipJobIDs map[int]struct{}
|
|
}{
|
|
// Baseline test case
|
|
{
|
|
name: "test-case-1",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
},
|
|
queueSizes: []int64{2, 2, 0},
|
|
queuePositions: []int64{1, 1, 0},
|
|
},
|
|
// Includes an additional provisioner
|
|
{
|
|
name: "test-case-2",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{3, 3, 3},
|
|
queuePositions: []int64{1, 1, 3},
|
|
},
|
|
// Skips job at index 0
|
|
{
|
|
name: "test-case-3",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{3, 3},
|
|
queuePositions: []int64{1, 3},
|
|
skipJobIDs: map[int]struct{}{
|
|
0: {},
|
|
},
|
|
},
|
|
// Skips job at index 1
|
|
{
|
|
name: "test-case-4",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{3, 3},
|
|
queuePositions: []int64{1, 3},
|
|
skipJobIDs: map[int]struct{}{
|
|
1: {},
|
|
},
|
|
},
|
|
// Skips job at index 2
|
|
{
|
|
name: "test-case-5",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{3, 3},
|
|
queuePositions: []int64{1, 1},
|
|
skipJobIDs: map[int]struct{}{
|
|
2: {},
|
|
},
|
|
},
|
|
// Skips jobs at indexes 0 and 2
|
|
{
|
|
name: "test-case-6",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{3},
|
|
queuePositions: []int64{1},
|
|
skipJobIDs: map[int]struct{}{
|
|
0: {},
|
|
2: {},
|
|
},
|
|
},
|
|
// Includes two additional jobs that any provisioner can execute.
|
|
{
|
|
name: "test-case-7",
|
|
jobTags: []database.StringMap{
|
|
{},
|
|
{},
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{5, 5, 5, 5, 5},
|
|
queuePositions: []int64{1, 2, 3, 3, 5},
|
|
},
|
|
// Includes two additional jobs that any provisioner can execute, but they are intentionally skipped.
|
|
{
|
|
name: "test-case-8",
|
|
jobTags: []database.StringMap{
|
|
{},
|
|
{},
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "c": "3"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2", "c": "3"},
|
|
},
|
|
queueSizes: []int64{5, 5, 5},
|
|
queuePositions: []int64{3, 3, 5},
|
|
skipJobIDs: map[int]struct{}{
|
|
0: {},
|
|
1: {},
|
|
},
|
|
},
|
|
// N jobs (1 job with 0 tags) & 0 provisioners exist
|
|
{
|
|
name: "test-case-9",
|
|
jobTags: []database.StringMap{
|
|
{},
|
|
{"a": "1"},
|
|
{"b": "2"},
|
|
},
|
|
daemonTags: []database.StringMap{},
|
|
queueSizes: []int64{0, 0, 0},
|
|
queuePositions: []int64{0, 0, 0},
|
|
},
|
|
// N jobs (1 job with 0 tags) & N provisioners
|
|
{
|
|
name: "test-case-10",
|
|
jobTags: []database.StringMap{
|
|
{},
|
|
{"a": "1"},
|
|
{"b": "2"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{},
|
|
{"a": "1"},
|
|
{"b": "2"},
|
|
},
|
|
queueSizes: []int64{2, 2, 2},
|
|
queuePositions: []int64{1, 2, 2},
|
|
},
|
|
// (N + 1) jobs (1 job with 0 tags) & N provisioners
|
|
// 1 job not matching any provisioner (first in the list)
|
|
{
|
|
name: "test-case-11",
|
|
jobTags: []database.StringMap{
|
|
{"c": "3"},
|
|
{},
|
|
{"a": "1"},
|
|
{"b": "2"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{},
|
|
{"a": "1"},
|
|
{"b": "2"},
|
|
},
|
|
queueSizes: []int64{0, 2, 2, 2},
|
|
queuePositions: []int64{0, 1, 2, 2},
|
|
},
|
|
// 0 jobs & 0 provisioners
|
|
{
|
|
name: "test-case-12",
|
|
jobTags: []database.StringMap{},
|
|
daemonTags: []database.StringMap{},
|
|
queueSizes: nil, // TODO(yevhenii): should it be empty array instead?
|
|
queuePositions: nil,
|
|
},
|
|
// Many daemons with identical tags should produce same results as one.
|
|
{
|
|
name: "duplicate-daemons-same-tags",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1"},
|
|
{"a": "1", "b": "2"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1", "b": "2"},
|
|
{"a": "1", "b": "2"},
|
|
},
|
|
queueSizes: []int64{2, 2},
|
|
queuePositions: []int64{1, 2},
|
|
},
|
|
// Jobs that don't match any queried job's daemon should still
|
|
// have correct queue positions.
|
|
{
|
|
name: "irrelevant-daemons-filtered",
|
|
jobTags: []database.StringMap{
|
|
{"a": "1"},
|
|
{"x": "9"},
|
|
},
|
|
daemonTags: []database.StringMap{
|
|
{"a": "1"},
|
|
{"x": "9"},
|
|
},
|
|
queueSizes: []int64{1},
|
|
queuePositions: []int64{1},
|
|
skipJobIDs: map[int]struct{}{1: {}},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
now := dbtime.Now()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Create provisioner jobs based on provided tags:
|
|
allJobs := make([]database.ProvisionerJob, len(tc.jobTags))
|
|
for idx, tags := range tc.jobTags {
|
|
// Make sure jobs are stored in correct order, first job should have the earliest createdAt timestamp.
|
|
// Example for 3 jobs:
|
|
// job_1 createdAt: now - 3 minutes
|
|
// job_2 createdAt: now - 2 minutes
|
|
// job_3 createdAt: now - 1 minute
|
|
timeOffsetInMinutes := len(tc.jobTags) - idx
|
|
timeOffset := time.Duration(timeOffsetInMinutes) * time.Minute
|
|
createdAt := now.Add(-timeOffset)
|
|
|
|
allJobs[idx] = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: createdAt,
|
|
Tags: tags,
|
|
})
|
|
}
|
|
|
|
// Create provisioner daemons based on provided tags:
|
|
for idx, tags := range tc.daemonTags {
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: fmt.Sprintf("prov_%v", idx),
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: tags,
|
|
})
|
|
}
|
|
|
|
// Assert invariant: the jobs are in pending status
|
|
for idx, job := range allJobs {
|
|
require.Equal(t, database.ProvisionerJobStatusPending, job.JobStatus, "expected job %d to have status %s", idx, database.ProvisionerJobStatusPending)
|
|
}
|
|
|
|
filteredJobs := make([]database.ProvisionerJob, 0)
|
|
filteredJobIDs := make([]uuid.UUID, 0)
|
|
for idx, job := range allJobs {
|
|
if _, skip := tc.skipJobIDs[idx]; skip {
|
|
continue
|
|
}
|
|
|
|
filteredJobs = append(filteredJobs, job)
|
|
filteredJobIDs = append(filteredJobIDs, job.ID)
|
|
}
|
|
|
|
// When: we fetch the jobs by their IDs
|
|
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: filteredJobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, actualJobs, len(filteredJobs), "should return all unskipped jobs")
|
|
|
|
// Then: the jobs should be returned in the correct order (sorted by createdAt)
|
|
sort.Slice(filteredJobs, func(i, j int) bool {
|
|
return filteredJobs[i].CreatedAt.Before(filteredJobs[j].CreatedAt)
|
|
})
|
|
for idx, job := range actualJobs {
|
|
assert.EqualValues(t, filteredJobs[idx], job.ProvisionerJob)
|
|
}
|
|
|
|
// Then: the queue size should be set correctly
|
|
var queueSizes []int64
|
|
for _, job := range actualJobs {
|
|
queueSizes = append(queueSizes, job.QueueSize)
|
|
}
|
|
assert.EqualValues(t, tc.queueSizes, queueSizes, "expected queue positions to be set correctly")
|
|
|
|
// Then: the queue position should be set correctly:
|
|
var queuePositions []int64
|
|
for _, job := range actualJobs {
|
|
queuePositions = append(queuePositions, job.QueuePosition)
|
|
}
|
|
assert.EqualValues(t, tc.queuePositions, queuePositions, "expected queue positions to be set correctly")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetProvisionerJobsByIDsWithQueuePosition_MixedStatuses(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
now := dbtime.Now()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Create the following provisioner jobs:
|
|
allJobs := []database.ProvisionerJob{
|
|
// Pending. This will be the last in the queue because
|
|
// it was created most recently.
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-time.Minute),
|
|
StartedAt: sql.NullTime{},
|
|
CanceledAt: sql.NullTime{},
|
|
CompletedAt: sql.NullTime{},
|
|
Error: sql.NullString{},
|
|
// Ensure the `tags` field is NOT NULL for both provisioner jobs and provisioner daemons;
|
|
// otherwise, provisioner daemons won't be able to pick up any jobs.
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Another pending. This will come first in the queue
|
|
// because it was created before the previous job.
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-2 * time.Minute),
|
|
StartedAt: sql.NullTime{},
|
|
CanceledAt: sql.NullTime{},
|
|
CompletedAt: sql.NullTime{},
|
|
Error: sql.NullString{},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Running
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-3 * time.Minute),
|
|
StartedAt: sql.NullTime{Valid: true, Time: now},
|
|
CanceledAt: sql.NullTime{},
|
|
CompletedAt: sql.NullTime{},
|
|
Error: sql.NullString{},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Succeeded
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-4 * time.Minute),
|
|
StartedAt: sql.NullTime{Valid: true, Time: now},
|
|
CanceledAt: sql.NullTime{},
|
|
CompletedAt: sql.NullTime{Valid: true, Time: now},
|
|
Error: sql.NullString{},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Canceling
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-5 * time.Minute),
|
|
StartedAt: sql.NullTime{},
|
|
CanceledAt: sql.NullTime{Valid: true, Time: now},
|
|
CompletedAt: sql.NullTime{},
|
|
Error: sql.NullString{},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Canceled
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-6 * time.Minute),
|
|
StartedAt: sql.NullTime{},
|
|
CanceledAt: sql.NullTime{Valid: true, Time: now},
|
|
CompletedAt: sql.NullTime{Valid: true, Time: now},
|
|
Error: sql.NullString{},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
// Failed
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-7 * time.Minute),
|
|
StartedAt: sql.NullTime{},
|
|
CanceledAt: sql.NullTime{},
|
|
CompletedAt: sql.NullTime{},
|
|
Error: sql.NullString{String: "failed", Valid: true},
|
|
Tags: database.StringMap{},
|
|
}),
|
|
}
|
|
|
|
// Create default provisioner daemon:
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "default_provisioner",
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{},
|
|
})
|
|
|
|
// Assert invariant: the jobs are in the expected order
|
|
require.Len(t, allJobs, 7, "expected 7 jobs")
|
|
for idx, status := range []database.ProvisionerJobStatus{
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusRunning,
|
|
database.ProvisionerJobStatusSucceeded,
|
|
database.ProvisionerJobStatusCanceling,
|
|
database.ProvisionerJobStatusCanceled,
|
|
database.ProvisionerJobStatusFailed,
|
|
} {
|
|
require.Equal(t, status, allJobs[idx].JobStatus, "expected job %d to have status %s", idx, status)
|
|
}
|
|
|
|
var jobIDs []uuid.UUID
|
|
for _, job := range allJobs {
|
|
jobIDs = append(jobIDs, job.ID)
|
|
}
|
|
|
|
// When: we fetch the jobs by their IDs
|
|
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: jobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, actualJobs, len(allJobs), "should return all jobs")
|
|
|
|
// Then: the jobs should be returned in the correct order (sorted by createdAt)
|
|
sort.Slice(allJobs, func(i, j int) bool {
|
|
return allJobs[i].CreatedAt.Before(allJobs[j].CreatedAt)
|
|
})
|
|
for idx, job := range actualJobs {
|
|
assert.EqualValues(t, allJobs[idx], job.ProvisionerJob)
|
|
}
|
|
|
|
// Then: the queue size should be set correctly
|
|
var queueSizes []int64
|
|
for _, job := range actualJobs {
|
|
queueSizes = append(queueSizes, job.QueueSize)
|
|
}
|
|
assert.EqualValues(t, []int64{0, 0, 0, 0, 0, 2, 2}, queueSizes, "expected queue positions to be set correctly")
|
|
|
|
// Then: the queue position should be set correctly:
|
|
var queuePositions []int64
|
|
for _, job := range actualJobs {
|
|
queuePositions = append(queuePositions, job.QueuePosition)
|
|
}
|
|
assert.EqualValues(t, []int64{0, 0, 0, 0, 0, 1, 2}, queuePositions, "expected queue positions to be set correctly")
|
|
}
|
|
|
|
func TestGetProvisionerJobsByIDsWithQueuePosition_OrderValidation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
now := dbtime.Now()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Create the following provisioner jobs:
|
|
allJobs := []database.ProvisionerJob{
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-4 * time.Minute),
|
|
// Ensure the `tags` field is NOT NULL for both provisioner jobs and provisioner daemons;
|
|
// otherwise, provisioner daemons won't be able to pick up any jobs.
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-5 * time.Minute),
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-6 * time.Minute),
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-3 * time.Minute),
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-2 * time.Minute),
|
|
Tags: database.StringMap{},
|
|
}),
|
|
|
|
dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-1 * time.Minute),
|
|
Tags: database.StringMap{},
|
|
}),
|
|
}
|
|
|
|
// Create default provisioner daemon:
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: "default_provisioner",
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{},
|
|
})
|
|
|
|
// Assert invariant: the jobs are in the expected order
|
|
require.Len(t, allJobs, 6, "expected 7 jobs")
|
|
for idx, status := range []database.ProvisionerJobStatus{
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
database.ProvisionerJobStatusPending,
|
|
} {
|
|
require.Equal(t, status, allJobs[idx].JobStatus, "expected job %d to have status %s", idx, status)
|
|
}
|
|
|
|
var jobIDs []uuid.UUID
|
|
for _, job := range allJobs {
|
|
jobIDs = append(jobIDs, job.ID)
|
|
}
|
|
|
|
// When: we fetch the jobs by their IDs
|
|
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: jobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, actualJobs, len(allJobs), "should return all jobs")
|
|
|
|
// Then: the jobs should be returned in the correct order (sorted by createdAt)
|
|
sort.Slice(allJobs, func(i, j int) bool {
|
|
return allJobs[i].CreatedAt.Before(allJobs[j].CreatedAt)
|
|
})
|
|
for idx, job := range actualJobs {
|
|
assert.EqualValues(t, allJobs[idx], job.ProvisionerJob)
|
|
assert.EqualValues(t, allJobs[idx].CreatedAt, job.ProvisionerJob.CreatedAt)
|
|
}
|
|
|
|
// Then: the queue size should be set correctly
|
|
var queueSizes []int64
|
|
for _, job := range actualJobs {
|
|
queueSizes = append(queueSizes, job.QueueSize)
|
|
}
|
|
assert.EqualValues(t, []int64{6, 6, 6, 6, 6, 6}, queueSizes, "expected queue positions to be set correctly")
|
|
|
|
// Then: the queue position should be set correctly:
|
|
var queuePositions []int64
|
|
for _, job := range actualJobs {
|
|
queuePositions = append(queuePositions, job.QueuePosition)
|
|
}
|
|
assert.EqualValues(t, []int64{1, 2, 3, 4, 5, 6}, queuePositions, "expected queue positions to be set correctly")
|
|
}
|
|
|
|
func TestGetProvisionerJobsByIDsWithQueuePosition_DuplicateDaemons(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
now := dbtime.Now()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Create 3 pending jobs with the same tags.
|
|
jobs := make([]database.ProvisionerJob, 3)
|
|
for i := range jobs {
|
|
jobs[i] = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
CreatedAt: now.Add(-time.Duration(3-i) * time.Minute),
|
|
Tags: database.StringMap{"scope": "organization", "owner": ""},
|
|
})
|
|
}
|
|
|
|
// Create 50 daemons with identical tags (simulates scale).
|
|
for i := range 50 {
|
|
dbgen.ProvisionerDaemon(t, db, database.ProvisionerDaemon{
|
|
Name: fmt.Sprintf("daemon_%d", i),
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{"scope": "organization", "owner": ""},
|
|
})
|
|
}
|
|
|
|
jobIDs := make([]uuid.UUID, len(jobs))
|
|
for i, j := range jobs {
|
|
jobIDs[i] = j.ID
|
|
}
|
|
|
|
results, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx,
|
|
database.GetProvisionerJobsByIDsWithQueuePositionParams{
|
|
IDs: jobIDs,
|
|
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, results, 3)
|
|
|
|
// All daemons have identical tags, so queue should be same as
|
|
// if there were just one daemon.
|
|
for i, r := range results {
|
|
assert.Equal(t, int64(3), r.QueueSize, "job %d queue size", i)
|
|
assert.Equal(t, int64(i+1), r.QueuePosition, "job %d queue position", i)
|
|
}
|
|
}
|
|
|
|
func TestGroupRemovalTrigger(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbgen.Organization(t, db, database.Organization{})
|
|
_, err := db.InsertAllUsersGroup(context.Background(), orgA.ID)
|
|
require.NoError(t, err)
|
|
|
|
orgB := dbgen.Organization(t, db, database.Organization{})
|
|
_, err = db.InsertAllUsersGroup(context.Background(), orgB.ID)
|
|
require.NoError(t, err)
|
|
|
|
orgs := []database.Organization{orgA, orgB}
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
extra := dbgen.User(t, db, database.User{})
|
|
users := []database.User{user, extra}
|
|
|
|
groupA1 := dbgen.Group(t, db, database.Group{
|
|
OrganizationID: orgA.ID,
|
|
})
|
|
groupA2 := dbgen.Group(t, db, database.Group{
|
|
OrganizationID: orgA.ID,
|
|
})
|
|
|
|
groupB1 := dbgen.Group(t, db, database.Group{
|
|
OrganizationID: orgB.ID,
|
|
})
|
|
groupB2 := dbgen.Group(t, db, database.Group{
|
|
OrganizationID: orgB.ID,
|
|
})
|
|
|
|
groups := []database.Group{groupA1, groupA2, groupB1, groupB2}
|
|
|
|
// Add users to all organizations
|
|
for _, u := range users {
|
|
for _, o := range orgs {
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: o.ID,
|
|
UserID: u.ID,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Add users to all groups
|
|
for _, u := range users {
|
|
for _, g := range groups {
|
|
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
|
GroupID: g.ID,
|
|
UserID: u.ID,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Verify user is in all groups
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
onlyGroupIDs := func(row database.GetGroupsRow) uuid.UUID {
|
|
return row.Group.ID
|
|
}
|
|
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
|
|
HasMemberID: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.ElementsMatch(t, []uuid.UUID{
|
|
orgA.ID, orgB.ID, // Everyone groups
|
|
groupA1.ID, groupA2.ID, groupB1.ID, groupB2.ID, // Org groups
|
|
}, slice.List(userGroups, onlyGroupIDs))
|
|
|
|
// Remove the user from org A
|
|
err = db.DeleteOrganizationMember(ctx, database.DeleteOrganizationMemberParams{
|
|
OrganizationID: orgA.ID,
|
|
UserID: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Verify user is no longer in org A groups
|
|
userGroups, err = db.GetGroups(ctx, database.GetGroupsParams{
|
|
HasMemberID: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.ElementsMatch(t, []uuid.UUID{
|
|
orgB.ID, // Everyone group
|
|
groupB1.ID, groupB2.ID, // Org groups
|
|
}, slice.List(userGroups, onlyGroupIDs))
|
|
|
|
// Verify extra user is unchanged
|
|
extraUserGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
|
|
HasMemberID: extra.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.ElementsMatch(t, []uuid.UUID{
|
|
orgA.ID, orgB.ID, // Everyone groups
|
|
groupA1.ID, groupA2.ID, groupB1.ID, groupB2.ID, // Org groups
|
|
}, slice.List(extraUserGroups, onlyGroupIDs))
|
|
}
|
|
|
|
func TestGetUserStatusCounts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
type testCase struct {
|
|
timezone string
|
|
location *time.Location
|
|
reportFrom time.Time
|
|
reportUntil time.Time
|
|
}
|
|
testCases := []testCase{}
|
|
|
|
// GetUserStatusCounts is sensitive to DST transitions, because it generates timestamps exactly
|
|
// one day apart from one another, and specific days can have varying lengths depending on the timezone.
|
|
// Therefore, we test with a variety of timezones.
|
|
timezones := []string{
|
|
"America/St_Johns",
|
|
"Africa/Johannesburg",
|
|
"America/New_York",
|
|
"Europe/London",
|
|
"Asia/Tokyo",
|
|
"Australia/Sydney",
|
|
}
|
|
|
|
// assemble test cases
|
|
for _, tz := range timezones {
|
|
location, err := time.LoadLocation(tz)
|
|
if err != nil {
|
|
t.Fatalf("failed to load location: %v", err)
|
|
}
|
|
|
|
// Testing based on the current system date will flake due to DST transitions.
|
|
// Instead, we test with a fixed range of dates that is large enough to span multiple DST transitions.
|
|
startOfTestDateRange := time.Date(2025, 1, 1, 0, 0, 0, 0, location)
|
|
endOfTestDateRange := time.Date(2026, 1, 1, 0, 0, 0, 0, location)
|
|
// To keep the number of test cases manageable given the large date range,
|
|
// we test with a suitable large interval. This interval is also the length of each report.
|
|
// this ensures we have full coverage of the date range.
|
|
testDateRangeInterval := 60
|
|
|
|
for reportFrom := startOfTestDateRange; !reportFrom.After(endOfTestDateRange); reportFrom = reportFrom.AddDate(0, 0, testDateRangeInterval) {
|
|
testCases = append(testCases, testCase{
|
|
timezone: tz,
|
|
location: location,
|
|
reportFrom: dbtime.Time(reportFrom),
|
|
reportUntil: dbtime.Time(reportFrom.AddDate(0, 0, testDateRangeInterval)),
|
|
})
|
|
}
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(fmt.Sprintf("%s/%s", tc.timezone, tc.reportUntil.Format("2006-01-02T15:04:05Z")), func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
userCreatedAt := tc.reportUntil.AddDate(0, 0, -60)
|
|
firstStatusChange := userCreatedAt.AddDate(0, 0, 29)
|
|
secondStatusChange := firstStatusChange.AddDate(0, 0, 29)
|
|
|
|
t.Run("No Users", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
counts, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: tc.reportFrom,
|
|
EndTime: tc.reportUntil,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, counts, "should return no results when there are no users")
|
|
})
|
|
|
|
t.Run("One User/Creation Only", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
subTestCases := []struct {
|
|
name string
|
|
status database.UserStatus
|
|
}{
|
|
{
|
|
name: "Active Only",
|
|
status: database.UserStatusActive,
|
|
},
|
|
{
|
|
name: "Dormant Only",
|
|
status: database.UserStatusDormant,
|
|
},
|
|
{
|
|
name: "Suspended Only",
|
|
status: database.UserStatusSuspended,
|
|
},
|
|
}
|
|
|
|
for _, stc := range subTestCases {
|
|
t.Run(stc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
dbgen.User(t, db, database.User{
|
|
Status: stc.status,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
startTime := dbtime.StartOfDay(userCreatedAt)
|
|
endTime := dbtime.StartOfDay(tc.reportUntil)
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: startTime,
|
|
EndTime: endTime,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
numDays := 0
|
|
for d := startTime; !d.After(endTime); d = d.AddDate(0, 0, 1) {
|
|
numDays++
|
|
}
|
|
assert.Len(
|
|
t,
|
|
userStatusChanges,
|
|
numDays,
|
|
"should have 1 entry per day between the start and end time, including the end time",
|
|
)
|
|
|
|
for i, row := range userStatusChanges {
|
|
require.Equal(t, stc.status, row.Status, "should have the correct status")
|
|
|
|
rowDate := row.Date.In(tc.location)
|
|
expectedDate := dbtime.StartOfDay(userCreatedAt).AddDate(0, 0, i)
|
|
assert.True(
|
|
t,
|
|
rowDate.Equal(expectedDate),
|
|
"expected date %s, but got %s for row %n",
|
|
expectedDate.String(),
|
|
rowDate.String(),
|
|
i,
|
|
)
|
|
|
|
if row.Date.Before(userCreatedAt) {
|
|
assert.Equal(t, int64(0), row.Count, "should have 0 users before creation")
|
|
} else {
|
|
assert.Equal(t, int64(1), row.Count, "should have 1 user after creation")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("One User/One Transition", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
subTestCases := []struct {
|
|
name string
|
|
initialStatus database.UserStatus
|
|
targetStatus database.UserStatus
|
|
expectedCounts map[time.Time]map[database.UserStatus]int64
|
|
}{
|
|
{
|
|
name: "Active to Dormant",
|
|
initialStatus: database.UserStatusActive,
|
|
targetStatus: database.UserStatusDormant,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "Active to Suspended",
|
|
initialStatus: database.UserStatusActive,
|
|
targetStatus: database.UserStatusSuspended,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "Dormant to Active",
|
|
initialStatus: database.UserStatusDormant,
|
|
targetStatus: database.UserStatusActive,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "Dormant to Suspended",
|
|
initialStatus: database.UserStatusDormant,
|
|
targetStatus: database.UserStatusSuspended,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "Suspended to Active",
|
|
initialStatus: database.UserStatusSuspended,
|
|
targetStatus: database.UserStatusActive,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusActive: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusActive: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "Suspended to Dormant",
|
|
initialStatus: database.UserStatusSuspended,
|
|
targetStatus: database.UserStatusDormant,
|
|
expectedCounts: map[time.Time]map[database.UserStatus]int64{
|
|
userCreatedAt: {
|
|
database.UserStatusSuspended: 1,
|
|
database.UserStatusDormant: 0,
|
|
},
|
|
firstStatusChange: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
tc.reportUntil: {
|
|
database.UserStatusDormant: 1,
|
|
database.UserStatusSuspended: 0,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, stc := range subTestCases {
|
|
t.Run(stc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
user := dbgen.User(t, db, database.User{
|
|
Status: stc.initialStatus,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
user, err := db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{
|
|
ID: user.ID,
|
|
Status: stc.targetStatus,
|
|
UpdatedAt: firstStatusChange,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: dbtime.StartOfDay(userCreatedAt),
|
|
EndTime: dbtime.StartOfDay(tc.reportUntil),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
for i, row := range userStatusChanges {
|
|
rowDate := row.Date.In(tc.location)
|
|
expectedDate := dbtime.StartOfDay(userCreatedAt).AddDate(0, 0, i/2)
|
|
require.True(
|
|
t,
|
|
rowDate.Equal(expectedDate),
|
|
"expected date %s, but got %s for row %n",
|
|
expectedDate.String(),
|
|
rowDate.String(),
|
|
i,
|
|
)
|
|
switch {
|
|
case row.Date.Before(userCreatedAt):
|
|
require.Equal(t, int64(0), row.Count)
|
|
case row.Date.Before(firstStatusChange):
|
|
if row.Status == stc.initialStatus {
|
|
require.Equal(t, int64(1), row.Count)
|
|
} else if row.Status == stc.targetStatus {
|
|
require.Equal(t, int64(0), row.Count)
|
|
}
|
|
case !row.Date.After(tc.reportUntil):
|
|
if row.Status == stc.initialStatus {
|
|
require.Equal(t, int64(0), row.Count)
|
|
} else if row.Status == stc.targetStatus {
|
|
require.Equal(t, int64(1), row.Count)
|
|
}
|
|
default:
|
|
t.Errorf("date %q beyond expected range end %q", row.Date, tc.reportUntil)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("Two Users/One Transition", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
type transition struct {
|
|
from database.UserStatus
|
|
to database.UserStatus
|
|
}
|
|
|
|
type testCase struct {
|
|
name string
|
|
user1Transition transition
|
|
user2Transition transition
|
|
}
|
|
|
|
subTestCases := []testCase{
|
|
{
|
|
name: "Active->Dormant and Dormant->Suspended",
|
|
user1Transition: transition{
|
|
from: database.UserStatusActive,
|
|
to: database.UserStatusDormant,
|
|
},
|
|
user2Transition: transition{
|
|
from: database.UserStatusDormant,
|
|
to: database.UserStatusSuspended,
|
|
},
|
|
},
|
|
{
|
|
name: "Suspended->Active and Active->Dormant",
|
|
user1Transition: transition{
|
|
from: database.UserStatusSuspended,
|
|
to: database.UserStatusActive,
|
|
},
|
|
user2Transition: transition{
|
|
from: database.UserStatusActive,
|
|
to: database.UserStatusDormant,
|
|
},
|
|
},
|
|
{
|
|
name: "Dormant->Active and Suspended->Dormant",
|
|
user1Transition: transition{
|
|
from: database.UserStatusDormant,
|
|
to: database.UserStatusActive,
|
|
},
|
|
user2Transition: transition{
|
|
from: database.UserStatusSuspended,
|
|
to: database.UserStatusDormant,
|
|
},
|
|
},
|
|
{
|
|
name: "Active->Suspended and Suspended->Active",
|
|
user1Transition: transition{
|
|
from: database.UserStatusActive,
|
|
to: database.UserStatusSuspended,
|
|
},
|
|
user2Transition: transition{
|
|
from: database.UserStatusSuspended,
|
|
to: database.UserStatusActive,
|
|
},
|
|
},
|
|
{
|
|
name: "Dormant->Suspended and Dormant->Active",
|
|
user1Transition: transition{
|
|
from: database.UserStatusDormant,
|
|
to: database.UserStatusSuspended,
|
|
},
|
|
user2Transition: transition{
|
|
from: database.UserStatusDormant,
|
|
to: database.UserStatusActive,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, stc := range subTestCases {
|
|
t.Run(stc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
user1 := dbgen.User(t, db, database.User{
|
|
Status: stc.user1Transition.from,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
user2 := dbgen.User(t, db, database.User{
|
|
Status: stc.user2Transition.from,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
user1, err := db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{
|
|
ID: user1.ID,
|
|
Status: stc.user1Transition.to,
|
|
UpdatedAt: firstStatusChange,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
user2, err = db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{
|
|
ID: user2.ID,
|
|
Status: stc.user2Transition.to,
|
|
UpdatedAt: secondStatusChange,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: dbtime.StartOfDay(userCreatedAt),
|
|
EndTime: dbtime.StartOfDay(tc.reportUntil),
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, userStatusChanges)
|
|
gotCounts := map[time.Time]map[database.UserStatus]int64{}
|
|
for _, row := range userStatusChanges {
|
|
dateInLocation := row.Date.In(tc.location)
|
|
if gotCounts[dateInLocation] == nil {
|
|
gotCounts[dateInLocation] = map[database.UserStatus]int64{}
|
|
}
|
|
gotCounts[dateInLocation][row.Status] = row.Count
|
|
}
|
|
|
|
expectedCounts := map[time.Time]map[database.UserStatus]int64{}
|
|
for d := dbtime.StartOfDay(userCreatedAt); !d.After(dbtime.StartOfDay(tc.reportUntil)); d = d.AddDate(0, 0, 1) {
|
|
expectedCounts[d] = map[database.UserStatus]int64{}
|
|
|
|
// Default values
|
|
expectedCounts[d][stc.user1Transition.from] = 0
|
|
expectedCounts[d][stc.user1Transition.to] = 0
|
|
expectedCounts[d][stc.user2Transition.from] = 0
|
|
expectedCounts[d][stc.user2Transition.to] = 0
|
|
|
|
// Counted Values
|
|
switch {
|
|
case d.Before(userCreatedAt):
|
|
continue
|
|
case d.Before(firstStatusChange):
|
|
expectedCounts[d][stc.user1Transition.from]++
|
|
expectedCounts[d][stc.user2Transition.from]++
|
|
case d.Before(secondStatusChange):
|
|
expectedCounts[d][stc.user1Transition.to]++
|
|
expectedCounts[d][stc.user2Transition.from]++
|
|
case !d.After(tc.reportUntil):
|
|
expectedCounts[d][stc.user1Transition.to]++
|
|
expectedCounts[d][stc.user2Transition.to]++
|
|
default:
|
|
t.Fatalf("date %q beyond expected range end %q", d, tc.reportUntil)
|
|
}
|
|
}
|
|
|
|
require.Equal(t, expectedCounts, gotCounts)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("User precedes and survives query range", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
_ = dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusActive,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: dbtime.StartOfDay(userCreatedAt.Add(time.Hour * 24)),
|
|
EndTime: dbtime.StartOfDay(tc.reportUntil),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
for i, row := range userStatusChanges {
|
|
require.True(
|
|
t,
|
|
row.Date.In(tc.location).Equal(dbtime.StartOfDay(userCreatedAt).AddDate(0, 0, 1+i)),
|
|
"expected date %s, but got %s for row %n",
|
|
dbtime.StartOfDay(userCreatedAt).AddDate(0, 0, 1+i),
|
|
row.Date.In(tc.location).String(),
|
|
i,
|
|
)
|
|
require.Equal(t, database.UserStatusActive, row.Status)
|
|
require.Equal(t, int64(1), row.Count)
|
|
}
|
|
})
|
|
|
|
t.Run("User deleted before query range", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
user := dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusActive,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
err := db.UpdateUserDeletedByID(ctx, user.ID)
|
|
require.NoError(t, err)
|
|
|
|
_, err = sqlDB.ExecContext(ctx, "UPDATE user_deleted SET deleted_at = $1 WHERE user_id = $2", tc.reportUntil, user.ID)
|
|
require.NoError(t, err)
|
|
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: tc.reportUntil.Add(time.Hour * 24),
|
|
EndTime: tc.reportUntil.Add(time.Hour * 48),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, userStatusChanges)
|
|
})
|
|
|
|
t.Run("User deleted during query range", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
user := dbgen.User(t, db, database.User{
|
|
Status: database.UserStatusActive,
|
|
CreatedAt: userCreatedAt,
|
|
UpdatedAt: userCreatedAt,
|
|
})
|
|
|
|
err := db.UpdateUserDeletedByID(ctx, user.ID)
|
|
require.NoError(t, err)
|
|
|
|
_, err = sqlDB.ExecContext(ctx, "UPDATE user_deleted SET deleted_at = $1 WHERE user_id = $2", tc.reportUntil, user.ID)
|
|
require.NoError(t, err)
|
|
|
|
userStatusChanges, err := db.GetUserStatusCounts(ctx, database.GetUserStatusCountsParams{
|
|
Tz: tc.timezone,
|
|
StartTime: dbtime.StartOfDay(userCreatedAt),
|
|
EndTime: dbtime.StartOfDay(tc.reportUntil.Add(time.Hour * 24)),
|
|
})
|
|
require.NoError(t, err)
|
|
for i, row := range userStatusChanges {
|
|
row.Date = row.Date.In(tc.location)
|
|
userStatusChanges[i] = row
|
|
target := dbtime.StartOfDay(userCreatedAt).AddDate(0, 0, i)
|
|
assert.True(
|
|
t,
|
|
row.Date.Equal(target),
|
|
"expected date %s, but got %s for row %n",
|
|
target.String(),
|
|
row.Date.String(),
|
|
i,
|
|
)
|
|
require.Equal(t, database.UserStatusActive, row.Status)
|
|
switch {
|
|
case row.Date.Before(userCreatedAt):
|
|
require.Equal(t, int64(0), row.Count)
|
|
case !row.Date.Before(tc.reportUntil):
|
|
// On or after the deletion date, the user should not be counted.
|
|
require.Equal(t, int64(0), row.Count)
|
|
default:
|
|
require.Equal(t, int64(1), row.Count)
|
|
}
|
|
}
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOrganizationDeleteTrigger(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("WorkspaceExists", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
|
|
dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OrganizationID: orgA.Org.ID,
|
|
OwnerID: user.ID,
|
|
}).Do()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 1 workspaces and 1 templates that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "has 1 workspaces")
|
|
require.ErrorContains(t, err, "1 templates")
|
|
})
|
|
|
|
t.Run("TemplateExists", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
|
|
dbgen.Template(t, db, database.Template{
|
|
OrganizationID: orgA.Org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 0 workspaces and 1 templates that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "1 templates")
|
|
})
|
|
|
|
t.Run("ProvisionerKeyExists", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
dbgen.ProvisionerKey(t, db, database.ProvisionerKey{
|
|
OrganizationID: orgA.Org.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 1 provisioner keys that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "1 provisioner keys")
|
|
})
|
|
|
|
t.Run("GroupExists", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
dbgen.Group(t, db, database.Group{
|
|
OrganizationID: orgA.Org.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 1 groups that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "has 1 groups")
|
|
})
|
|
|
|
t.Run("MemberExists", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
userA := dbgen.User(t, db, database.User{})
|
|
userB := dbgen.User(t, db, database.User{})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: orgA.Org.ID,
|
|
UserID: userA.ID,
|
|
})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: orgA.Org.ID,
|
|
UserID: userB.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 1 members that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "has 1 members")
|
|
})
|
|
|
|
t.Run("UserDeletedButNotRemovedFromOrg", func(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
orgA := dbfake.Organization(t, db).Do()
|
|
|
|
userA := dbgen.User(t, db, database.User{})
|
|
userB := dbgen.User(t, db, database.User{})
|
|
userC := dbgen.User(t, db, database.User{})
|
|
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: orgA.Org.ID,
|
|
UserID: userA.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: orgA.Org.ID,
|
|
UserID: userB.ID,
|
|
})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: orgA.Org.ID,
|
|
UserID: userC.ID,
|
|
})
|
|
|
|
// Delete one of the users but don't remove them from the org
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
db.UpdateUserDeletedByID(ctx, userB.ID)
|
|
|
|
err := db.UpdateOrganizationDeletedByID(ctx, database.UpdateOrganizationDeletedByIDParams{
|
|
UpdatedAt: dbtime.Now(),
|
|
ID: orgA.Org.ID,
|
|
})
|
|
require.Error(t, err)
|
|
// cannot delete organization: organization has 1 members that must be deleted first
|
|
require.ErrorContains(t, err, "cannot delete organization")
|
|
require.ErrorContains(t, err, "has 1 members")
|
|
})
|
|
}
|
|
|
|
// templateVersionWithPreset bundles a template version with the single preset
// that createTmplVersionAndPreset creates for it, so tests can reference both
// the version (via the embedded fields) and the preset's ID.
type templateVersionWithPreset struct {
	database.TemplateVersion
	// preset is the one preset inserted for this template version.
	preset database.TemplateVersionPreset
}
|
|
|
|
func createTemplate(t *testing.T, db database.Store, orgID uuid.UUID, userID uuid.UUID) database.Template {
|
|
// create template
|
|
tmpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: orgID,
|
|
CreatedBy: userID,
|
|
ActiveVersionID: uuid.New(),
|
|
})
|
|
|
|
return tmpl
|
|
}
|
|
|
|
// tmplVersionOpts overrides defaults used by createTmplVersionAndPreset.
// Passing nil to createTmplVersionAndPreset keeps every default.
type tmplVersionOpts struct {
	// DesiredInstances is written to the preset's desired instances column.
	// It is used verbatim (including 0) whenever a non-nil opts is passed.
	DesiredInstances int32
}
|
|
|
|
func createTmplVersionAndPreset(
|
|
t *testing.T,
|
|
db database.Store,
|
|
tmpl database.Template,
|
|
versionID uuid.UUID,
|
|
now time.Time,
|
|
opts *tmplVersionOpts,
|
|
) templateVersionWithPreset {
|
|
// Create template version with corresponding preset and preset prebuild
|
|
tmplVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
ID: versionID,
|
|
TemplateID: uuid.NullUUID{
|
|
UUID: tmpl.ID,
|
|
Valid: true,
|
|
},
|
|
OrganizationID: tmpl.OrganizationID,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
CreatedBy: tmpl.CreatedBy,
|
|
})
|
|
desiredInstances := int32(1)
|
|
if opts != nil {
|
|
desiredInstances = opts.DesiredInstances
|
|
}
|
|
preset := dbgen.Preset(t, db, database.InsertPresetParams{
|
|
TemplateVersionID: tmplVersion.ID,
|
|
Name: "preset",
|
|
DesiredInstances: sql.NullInt32{
|
|
Int32: desiredInstances,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
return templateVersionWithPreset{
|
|
TemplateVersion: tmplVersion,
|
|
preset: preset,
|
|
}
|
|
}
|
|
|
|
// createPrebuiltWorkspaceOpts customizes createPrebuiltWorkspace.
//
// When createPrebuiltWorkspace receives nil it uses its own defaults. When it
// receives a non-nil value, every field below is used as-is, zero values
// included: notReadyAgents == 0 means no not-ready agents, and a zero
// createdAt is passed straight through to dbgen.WorkspaceBuild (which
// presumably substitutes its own default for a zero time — confirm).
type createPrebuiltWorkspaceOpts struct {
	// failedJob sets a non-empty Error on the build's provisioner job.
	failedJob bool
	// createdAt is the CreatedAt of the workspace build (not of the job).
	createdAt time.Time
	// readyAgents is how many agents are moved to the "ready" lifecycle state.
	readyAgents int
	// notReadyAgents is how many agents are moved to the "created" lifecycle state.
	notReadyAgents int
}
|
|
|
|
func createPrebuiltWorkspace(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
tmpl database.Template,
|
|
extTmplVersion templateVersionWithPreset,
|
|
orgID uuid.UUID,
|
|
now time.Time,
|
|
opts *createPrebuiltWorkspaceOpts,
|
|
) {
|
|
// Create job with corresponding resource and agent
|
|
jobError := sql.NullString{}
|
|
if opts != nil && opts.failedJob {
|
|
jobError = sql.NullString{String: "failed", Valid: true}
|
|
}
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
OrganizationID: orgID,
|
|
|
|
CreatedAt: now.Add(-1 * time.Minute),
|
|
Error: jobError,
|
|
})
|
|
|
|
// create ready agents
|
|
readyAgents := 0
|
|
if opts != nil {
|
|
readyAgents = opts.readyAgents
|
|
}
|
|
for i := 0; i < readyAgents; i++ {
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
|
ID: agent.ID,
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// create not ready agents
|
|
notReadyAgents := 1
|
|
if opts != nil {
|
|
notReadyAgents = opts.notReadyAgents
|
|
}
|
|
for i := 0; i < notReadyAgents; i++ {
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
|
ID: agent.ID,
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Create corresponding workspace and workspace build
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: uuid.MustParse("c42fdf75-3097-471c-8c33-fb52454d81c0"),
|
|
OrganizationID: tmpl.OrganizationID,
|
|
TemplateID: tmpl.ID,
|
|
})
|
|
createdAt := now
|
|
if opts != nil {
|
|
createdAt = opts.createdAt
|
|
}
|
|
dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
CreatedAt: createdAt,
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: extTmplVersion.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: tmpl.CreatedBy,
|
|
JobID: job.ID,
|
|
TemplateVersionPresetID: uuid.NullUUID{
|
|
UUID: extTmplVersion.preset.ID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
}
|
|
|
|
func TestWorkspacePrebuildsView(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
now := dbtime.Now()
|
|
orgID := uuid.New()
|
|
userID := uuid.New()
|
|
|
|
type workspacePrebuild struct {
|
|
ID uuid.UUID
|
|
Name string
|
|
CreatedAt time.Time
|
|
Ready bool
|
|
CurrentPresetID uuid.UUID
|
|
}
|
|
getWorkspacePrebuilds := func(sqlDB *sql.DB) []*workspacePrebuild {
|
|
rows, err := sqlDB.Query("SELECT id, name, created_at, ready, current_preset_id FROM workspace_prebuilds")
|
|
require.NoError(t, err)
|
|
defer rows.Close()
|
|
|
|
workspacePrebuilds := make([]*workspacePrebuild, 0)
|
|
for rows.Next() {
|
|
var wp workspacePrebuild
|
|
err := rows.Scan(&wp.ID, &wp.Name, &wp.CreatedAt, &wp.Ready, &wp.CurrentPresetID)
|
|
require.NoError(t, err)
|
|
|
|
workspacePrebuilds = append(workspacePrebuilds, &wp)
|
|
}
|
|
|
|
return workspacePrebuilds
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
readyAgents int
|
|
notReadyAgents int
|
|
expectReady bool
|
|
}{
|
|
{
|
|
name: "one ready agent",
|
|
readyAgents: 1,
|
|
notReadyAgents: 0,
|
|
expectReady: true,
|
|
},
|
|
{
|
|
name: "one not ready agent",
|
|
readyAgents: 0,
|
|
notReadyAgents: 1,
|
|
expectReady: false,
|
|
},
|
|
{
|
|
name: "one ready, one not ready",
|
|
readyAgents: 1,
|
|
notReadyAgents: 1,
|
|
expectReady: false,
|
|
},
|
|
{
|
|
name: "both ready",
|
|
readyAgents: 2,
|
|
notReadyAgents: 0,
|
|
expectReady: true,
|
|
},
|
|
{
|
|
name: "five ready, one not ready",
|
|
readyAgents: 5,
|
|
notReadyAgents: 1,
|
|
expectReady: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
readyAgents: tc.readyAgents,
|
|
notReadyAgents: tc.notReadyAgents,
|
|
})
|
|
|
|
workspacePrebuilds := getWorkspacePrebuilds(sqlDB)
|
|
require.Len(t, workspacePrebuilds, 1)
|
|
require.Equal(t, tc.expectReady, workspacePrebuilds[0].Ready)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetPresetsBackoff(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
now := dbtime.Now()
|
|
orgID := uuid.New()
|
|
userID := uuid.New()
|
|
|
|
findBackoffByTmplVersionID := func(backoffs []database.GetPresetsBackoffRow, tmplVersionID uuid.UUID) *database.GetPresetsBackoffRow {
|
|
for _, backoff := range backoffs {
|
|
if backoff.TemplateVersionID == tmplVersionID {
|
|
return &backoff
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
t.Run("Single Workspace Build", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmplV1.preset.ID)
|
|
require.Equal(t, int32(1), backoff.NumFailed)
|
|
})
|
|
|
|
t.Run("Multiple Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmplV1.preset.ID)
|
|
require.Equal(t, int32(3), backoff.NumFailed)
|
|
})
|
|
|
|
t.Run("Ignore Inactive Version", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, uuid.New(), now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
// Active Version
|
|
tmplV2 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmplV2.preset.ID)
|
|
require.Equal(t, int32(2), backoff.NumFailed)
|
|
})
|
|
|
|
t.Run("Multiple Templates", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl2 := createTemplate(t, db, orgID, userID)
|
|
tmpl2V1 := createTmplVersionAndPreset(t, db, tmpl2, tmpl2.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 2)
|
|
{
|
|
backoff := findBackoffByTmplVersionID(backoffs, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl1V1.preset.ID)
|
|
require.Equal(t, int32(1), backoff.NumFailed)
|
|
}
|
|
{
|
|
backoff := findBackoffByTmplVersionID(backoffs, tmpl2.ActiveVersionID)
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl2.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl2V1.preset.ID)
|
|
require.Equal(t, int32(1), backoff.NumFailed)
|
|
}
|
|
})
|
|
|
|
t.Run("Multiple Templates, Versions and Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl2 := createTemplate(t, db, orgID, userID)
|
|
tmpl2V1 := createTmplVersionAndPreset(t, db, tmpl2, tmpl2.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl3 := createTemplate(t, db, orgID, userID)
|
|
tmpl3V1 := createTmplVersionAndPreset(t, db, tmpl3, uuid.New(), now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl3V2 := createTmplVersionAndPreset(t, db, tmpl3, tmpl3.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 3)
|
|
{
|
|
backoff := findBackoffByTmplVersionID(backoffs, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl1V1.preset.ID)
|
|
require.Equal(t, int32(1), backoff.NumFailed)
|
|
}
|
|
{
|
|
backoff := findBackoffByTmplVersionID(backoffs, tmpl2.ActiveVersionID)
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl2.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl2V1.preset.ID)
|
|
require.Equal(t, int32(2), backoff.NumFailed)
|
|
}
|
|
{
|
|
backoff := findBackoffByTmplVersionID(backoffs, tmpl3.ActiveVersionID)
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl3.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl3V2.preset.ID)
|
|
require.Equal(t, int32(3), backoff.NumFailed)
|
|
}
|
|
})
|
|
|
|
t.Run("No Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
require.Nil(t, backoffs)
|
|
})
|
|
|
|
t.Run("No Failed Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
successfulJobOpts := createPrebuiltWorkspaceOpts{}
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
require.Nil(t, backoffs)
|
|
})
|
|
|
|
t.Run("Last job is successful - no backoff", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 1,
|
|
})
|
|
failedJobOpts := createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-2 * time.Minute),
|
|
}
|
|
successfulJobOpts := createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-1 * time.Minute),
|
|
}
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &failedJobOpts)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
require.Nil(t, backoffs)
|
|
})
|
|
|
|
t.Run("Last 3 jobs are successful - no backoff", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 3,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-4 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-3 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-2 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-1 * time.Minute),
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
require.Nil(t, backoffs)
|
|
})
|
|
|
|
t.Run("1 job failed out of 3 - backoff", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 3,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-3 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-2 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-1 * time.Minute),
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-time.Hour))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
{
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl1V1.preset.ID)
|
|
require.Equal(t, int32(1), backoff.NumFailed)
|
|
}
|
|
})
|
|
|
|
t.Run("3 job failed out of 5 - backoff", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
lookbackPeriod := time.Hour
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 3,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-lookbackPeriod - time.Minute), // earlier than lookback period - skipped
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-4 * time.Minute), // within lookback period - counted as failed job
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-3 * time.Minute), // within lookback period - counted as failed job
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-2 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: false,
|
|
createdAt: now.Add(-1 * time.Minute),
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-lookbackPeriod))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
{
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl1V1.preset.ID)
|
|
require.Equal(t, int32(2), backoff.NumFailed)
|
|
}
|
|
})
|
|
|
|
t.Run("check LastBuildAt timestamp", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
lookbackPeriod := time.Hour
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 6,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-lookbackPeriod - time.Minute), // earlier than lookback period - skipped
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-4 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-0 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-3 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-1 * time.Minute),
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-2 * time.Minute),
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-lookbackPeriod))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, backoffs, 1)
|
|
{
|
|
backoff := backoffs[0]
|
|
require.Equal(t, backoff.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, backoff.PresetID, tmpl1V1.preset.ID)
|
|
require.Equal(t, int32(5), backoff.NumFailed)
|
|
// make sure LastBuildAt is equal to latest failed build timestamp
|
|
require.Equal(t, 0, now.Compare(backoff.LastBuildAt))
|
|
}
|
|
})
|
|
|
|
t.Run("failed job outside lookback period", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
lookbackPeriod := time.Hour
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, &tmplVersionOpts{
|
|
DesiredInstances: 1,
|
|
})
|
|
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
createdAt: now.Add(-lookbackPeriod - time.Minute), // earlier than lookback period - skipped
|
|
})
|
|
|
|
backoffs, err := db.GetPresetsBackoff(ctx, now.Add(-lookbackPeriod))
|
|
require.NoError(t, err)
|
|
require.Len(t, backoffs, 0)
|
|
})
|
|
}
|
|
|
|
func TestGetPresetsAtFailureLimit(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
now := dbtime.Now()
|
|
hourBefore := now.Add(-time.Hour)
|
|
orgID := uuid.New()
|
|
userID := uuid.New()
|
|
|
|
findPresetByTmplVersionID := func(hardLimitedPresets []database.GetPresetsAtFailureLimitRow, tmplVersionID uuid.UUID) *database.GetPresetsAtFailureLimitRow {
|
|
for _, preset := range hardLimitedPresets {
|
|
if preset.TemplateVersionID == tmplVersionID {
|
|
return &preset
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
// true - build is successful
|
|
// false - build is unsuccessful
|
|
buildSuccesses []bool
|
|
hardLimit int64
|
|
expHitHardLimit bool
|
|
}{
|
|
{
|
|
name: "failed build",
|
|
buildSuccesses: []bool{false},
|
|
hardLimit: 1,
|
|
expHitHardLimit: true,
|
|
},
|
|
{
|
|
name: "2 failed builds",
|
|
buildSuccesses: []bool{false, false},
|
|
hardLimit: 1,
|
|
expHitHardLimit: true,
|
|
},
|
|
{
|
|
name: "successful build",
|
|
buildSuccesses: []bool{true},
|
|
hardLimit: 1,
|
|
expHitHardLimit: false,
|
|
},
|
|
{
|
|
name: "last build is failed",
|
|
buildSuccesses: []bool{true, true, false},
|
|
hardLimit: 1,
|
|
expHitHardLimit: true,
|
|
},
|
|
{
|
|
name: "last build is successful",
|
|
buildSuccesses: []bool{false, false, true},
|
|
hardLimit: 1,
|
|
expHitHardLimit: false,
|
|
},
|
|
{
|
|
name: "last 3 builds are failed - hard limit is reached",
|
|
buildSuccesses: []bool{true, true, false, false, false},
|
|
hardLimit: 3,
|
|
expHitHardLimit: true,
|
|
},
|
|
{
|
|
name: "1 out of 3 last build is successful - hard limit is NOT reached",
|
|
buildSuccesses: []bool{false, false, true, false, false},
|
|
hardLimit: 3,
|
|
expHitHardLimit: false,
|
|
},
|
|
// hardLimit set to zero, implicitly disables the hard limit.
|
|
{
|
|
name: "despite 5 failed builds, the hard limit is not reached because it's disabled.",
|
|
buildSuccesses: []bool{false, false, false, false, false},
|
|
hardLimit: 0,
|
|
expHitHardLimit: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
for idx, buildSuccess := range tc.buildSuccesses {
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: !buildSuccess,
|
|
createdAt: hourBefore.Add(time.Duration(idx) * time.Second),
|
|
})
|
|
}
|
|
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, tc.hardLimit)
|
|
require.NoError(t, err)
|
|
|
|
if !tc.expHitHardLimit {
|
|
require.Len(t, hardLimitedPresets, 0)
|
|
return
|
|
}
|
|
|
|
require.Len(t, hardLimitedPresets, 1)
|
|
hardLimitedPreset := hardLimitedPresets[0]
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmplV1.preset.ID)
|
|
})
|
|
}
|
|
|
|
t.Run("Ignore Inactive Version", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl := createTemplate(t, db, orgID, userID)
|
|
tmplV1 := createTmplVersionAndPreset(t, db, tmpl, uuid.New(), now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
// Active Version
|
|
tmplV2 := createTmplVersionAndPreset(t, db, tmpl, tmpl.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl, tmplV2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, 1)
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, hardLimitedPresets, 1)
|
|
hardLimitedPreset := hardLimitedPresets[0]
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmplV2.preset.ID)
|
|
})
|
|
|
|
t.Run("Multiple Templates", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl2 := createTemplate(t, db, orgID, userID)
|
|
tmpl2V1 := createTmplVersionAndPreset(t, db, tmpl2, tmpl2.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, 1)
|
|
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, hardLimitedPresets, 2)
|
|
{
|
|
hardLimitedPreset := findPresetByTmplVersionID(hardLimitedPresets, tmpl1.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmpl1V1.preset.ID)
|
|
}
|
|
{
|
|
hardLimitedPreset := findPresetByTmplVersionID(hardLimitedPresets, tmpl2.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl2.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmpl2V1.preset.ID)
|
|
}
|
|
})
|
|
|
|
t.Run("Multiple Templates, Versions and Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl2 := createTemplate(t, db, orgID, userID)
|
|
tmpl2V1 := createTmplVersionAndPreset(t, db, tmpl2, tmpl2.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl2, tmpl2V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl3 := createTemplate(t, db, orgID, userID)
|
|
tmpl3V1 := createTmplVersionAndPreset(t, db, tmpl3, uuid.New(), now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V1, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
tmpl3V2 := createTmplVersionAndPreset(t, db, tmpl3, tmpl3.ActiveVersionID, now, nil)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl3, tmpl3V2, orgID, now, &createPrebuiltWorkspaceOpts{
|
|
failedJob: true,
|
|
})
|
|
|
|
hardLimit := int64(2)
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, hardLimit)
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, hardLimitedPresets, 3)
|
|
{
|
|
hardLimitedPreset := findPresetByTmplVersionID(hardLimitedPresets, tmpl1.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl1.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmpl1V1.preset.ID)
|
|
}
|
|
{
|
|
hardLimitedPreset := findPresetByTmplVersionID(hardLimitedPresets, tmpl2.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl2.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmpl2V1.preset.ID)
|
|
}
|
|
{
|
|
hardLimitedPreset := findPresetByTmplVersionID(hardLimitedPresets, tmpl3.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.TemplateVersionID, tmpl3.ActiveVersionID)
|
|
require.Equal(t, hardLimitedPreset.PresetID, tmpl3V2.preset.ID)
|
|
}
|
|
})
|
|
|
|
t.Run("No Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, 1)
|
|
require.NoError(t, err)
|
|
require.Nil(t, hardLimitedPresets)
|
|
})
|
|
|
|
t.Run("No Failed Workspace Builds", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dbgen.Organization(t, db, database.Organization{
|
|
ID: orgID,
|
|
})
|
|
dbgen.User(t, db, database.User{
|
|
ID: userID,
|
|
})
|
|
|
|
tmpl1 := createTemplate(t, db, orgID, userID)
|
|
tmpl1V1 := createTmplVersionAndPreset(t, db, tmpl1, tmpl1.ActiveVersionID, now, nil)
|
|
successfulJobOpts := createPrebuiltWorkspaceOpts{}
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
createPrebuiltWorkspace(ctx, t, db, tmpl1, tmpl1V1, orgID, now, &successfulJobOpts)
|
|
|
|
hardLimitedPresets, err := db.GetPresetsAtFailureLimit(ctx, 1)
|
|
require.NoError(t, err)
|
|
require.Nil(t, hardLimitedPresets)
|
|
})
|
|
}
|
|
|
|
func TestWorkspaceAgentNameUniqueTrigger(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
createWorkspaceWithAgent := func(t *testing.T, db database.Store, org database.Organization, agentName string) (database.WorkspaceBuild, database.WorkspaceResource, database.WorkspaceAgent) {
|
|
t.Helper()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{Valid: true, UUID: template.ID},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
TemplateID: template.ID,
|
|
OwnerID: user.ID,
|
|
})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
OrganizationID: org.ID,
|
|
})
|
|
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
BuildNumber: 1,
|
|
JobID: job.ID,
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: build.JobID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
Name: agentName,
|
|
})
|
|
|
|
return build, resource, agent
|
|
}
|
|
|
|
t.Run("DuplicateNamesInSameWorkspaceResource", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A workspace with an agent
|
|
_, resource, _ := createWorkspaceWithAgent(t, db, org, "duplicate-agent")
|
|
|
|
// When: Another agent is created for that workspace with the same name.
|
|
_, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
|
ID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
Name: "duplicate-agent", // Same name as agent1
|
|
ResourceID: resource.ID,
|
|
AuthToken: uuid.New(),
|
|
Architecture: "amd64",
|
|
OperatingSystem: "linux",
|
|
APIKeyScope: database.AgentKeyScopeEnumAll,
|
|
})
|
|
|
|
// Then: We expect it to fail.
|
|
require.Error(t, err)
|
|
var pqErr *pq.Error
|
|
require.True(t, errors.As(err, &pqErr))
|
|
require.Equal(t, pq.ErrorCode("23505"), pqErr.Code) // unique_violation
|
|
require.Contains(t, pqErr.Message, `workspace agent name "duplicate-agent" already exists in this workspace build`)
|
|
})
|
|
|
|
t.Run("DuplicateNamesInSameProvisionerJob", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A workspace with an agent
|
|
_, resource, agent := createWorkspaceWithAgent(t, db, org, "duplicate-agent")
|
|
|
|
// When: A child agent is created for that workspace with the same name.
|
|
_, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
|
ID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
Name: agent.Name,
|
|
ResourceID: resource.ID,
|
|
AuthToken: uuid.New(),
|
|
Architecture: "amd64",
|
|
OperatingSystem: "linux",
|
|
APIKeyScope: database.AgentKeyScopeEnumAll,
|
|
})
|
|
|
|
// Then: We expect it to fail.
|
|
require.Error(t, err)
|
|
var pqErr *pq.Error
|
|
require.True(t, errors.As(err, &pqErr))
|
|
require.Equal(t, pq.ErrorCode("23505"), pqErr.Code) // unique_violation
|
|
require.Contains(t, pqErr.Message, `workspace agent name "duplicate-agent" already exists in this workspace build`)
|
|
})
|
|
|
|
t.Run("DuplicateChildNamesOverMultipleResources", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A workspace with two agents
|
|
_, resource1, agent1 := createWorkspaceWithAgent(t, db, org, "parent-agent-1")
|
|
|
|
resource2 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{JobID: resource1.JobID})
|
|
agent2 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource2.ID,
|
|
Name: "parent-agent-2",
|
|
})
|
|
|
|
// Given: One agent has a child agent
|
|
agent1Child := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ParentID: uuid.NullUUID{Valid: true, UUID: agent1.ID},
|
|
Name: "child-agent",
|
|
ResourceID: resource1.ID,
|
|
})
|
|
|
|
// When: A child agent is inserted for the other parent.
|
|
_, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
|
ID: uuid.New(),
|
|
ParentID: uuid.NullUUID{Valid: true, UUID: agent2.ID},
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
Name: agent1Child.Name,
|
|
ResourceID: resource2.ID,
|
|
AuthToken: uuid.New(),
|
|
Architecture: "amd64",
|
|
OperatingSystem: "linux",
|
|
APIKeyScope: database.AgentKeyScopeEnumAll,
|
|
})
|
|
|
|
// Then: We expect it to fail.
|
|
require.Error(t, err)
|
|
var pqErr *pq.Error
|
|
require.True(t, errors.As(err, &pqErr))
|
|
require.Equal(t, pq.ErrorCode("23505"), pqErr.Code) // unique_violation
|
|
require.Contains(t, pqErr.Message, `workspace agent name "child-agent" already exists in this workspace build`)
|
|
})
|
|
|
|
t.Run("SameNamesInDifferentWorkspaces", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
agentName := "same-name-different-workspace"
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
// Given: A workspace with an agent
|
|
_, _, agent1 := createWorkspaceWithAgent(t, db, org, agentName)
|
|
require.Equal(t, agentName, agent1.Name)
|
|
|
|
// When: A second workspace is created with an agent having the same name
|
|
_, _, agent2 := createWorkspaceWithAgent(t, db, org, agentName)
|
|
require.Equal(t, agentName, agent2.Name)
|
|
|
|
// Then: We expect there to be different agents with the same name.
|
|
require.NotEqual(t, agent1.ID, agent2.ID)
|
|
require.Equal(t, agent1.Name, agent2.Name)
|
|
})
|
|
|
|
t.Run("NullWorkspaceID", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: A resource that does not belong to a workspace build (simulating template import)
|
|
orphanJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
OrganizationID: org.ID,
|
|
})
|
|
orphanResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: orphanJob.ID,
|
|
})
|
|
|
|
// And this resource has a workspace agent.
|
|
agent1, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
|
ID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
Name: "orphan-agent",
|
|
ResourceID: orphanResource.ID,
|
|
AuthToken: uuid.New(),
|
|
Architecture: "amd64",
|
|
OperatingSystem: "linux",
|
|
APIKeyScope: database.AgentKeyScopeEnumAll,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "orphan-agent", agent1.Name)
|
|
|
|
// When: We created another resource that does not belong to a workspace build.
|
|
orphanJob2 := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
OrganizationID: org.ID,
|
|
})
|
|
orphanResource2 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: orphanJob2.ID,
|
|
})
|
|
|
|
// Then: We expect to be able to create an agent in this new resource that has the same name.
|
|
agent2, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
|
ID: uuid.New(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
Name: "orphan-agent", // Same name as agent1
|
|
ResourceID: orphanResource2.ID,
|
|
AuthToken: uuid.New(),
|
|
Architecture: "amd64",
|
|
OperatingSystem: "linux",
|
|
APIKeyScope: database.AgentKeyScopeEnumAll,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "orphan-agent", agent2.Name)
|
|
require.NotEqual(t, agent1.ID, agent2.ID)
|
|
})
|
|
}
|
|
|
|
func TestGetWorkspaceAgentsByParentID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("NilParentDoesNotReturnAllParentAgents", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Given: A workspace agent
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// When: We attempt to select agents with a null parent id
|
|
agents, err := db.GetWorkspaceAgentsByParentID(ctx, uuid.Nil)
|
|
require.NoError(t, err)
|
|
|
|
// Then: We expect to see no agents.
|
|
require.Len(t, agents, 0)
|
|
})
|
|
}
|
|
|
|
func setupWorkspaceAgentQueryResources(t *testing.T, db database.Store, count int) []database.WorkspaceResource {
|
|
t.Helper()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
OrganizationID: org.ID,
|
|
})
|
|
|
|
resources := make([]database.WorkspaceResource, 0, count)
|
|
for i := 0; i < count; i++ {
|
|
resources = append(resources, dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
}))
|
|
}
|
|
|
|
return resources
|
|
}
|
|
|
|
func markWorkspaceAgentDeleted(ctx context.Context, t *testing.T, sqlDB *sql.DB, agentID uuid.UUID) {
|
|
t.Helper()
|
|
|
|
_, err := sqlDB.ExecContext(ctx, "UPDATE workspace_agents SET deleted = TRUE WHERE id = $1", agentID)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestGetWorkspaceAgentsByInstanceID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("ReturnsAllMatchingRootAgents", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
resources := setupWorkspaceAgentQueryResources(t, db, 2)
|
|
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
|
|
olderCreatedAt := dbtime.Now().Add(-time.Hour)
|
|
newerCreatedAt := dbtime.Now()
|
|
|
|
olderAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[0].ID,
|
|
CreatedAt: olderCreatedAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
newerAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[1].ID,
|
|
CreatedAt: newerCreatedAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 2)
|
|
assert.Equal(t, []uuid.UUID{newerAgent.ID, olderAgent.ID}, []uuid.UUID{agents[0].ID, agents[1].ID})
|
|
})
|
|
|
|
t.Run("ExcludesDeletedAndSubAgents", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
resources := setupWorkspaceAgentQueryResources(t, db, 2)
|
|
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
|
|
baseCreatedAt := dbtime.Now()
|
|
|
|
rootAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[0].ID,
|
|
CreatedAt: baseCreatedAt.Add(-time.Hour),
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ParentID: uuid.NullUUID{UUID: rootAgent.ID, Valid: true},
|
|
ResourceID: resources[0].ID,
|
|
CreatedAt: baseCreatedAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
deletedRootAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[1].ID,
|
|
CreatedAt: baseCreatedAt.Add(time.Minute),
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
markWorkspaceAgentDeleted(ctx, t, sqlDB, deletedRootAgent.ID)
|
|
|
|
agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 1)
|
|
assert.Equal(t, rootAgent.ID, agents[0].ID)
|
|
assert.False(t, agents[0].ParentID.Valid)
|
|
})
|
|
|
|
t.Run("OrdersNewestFirst", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
resources := setupWorkspaceAgentQueryResources(t, db, 2)
|
|
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
|
|
olderCreatedAt := dbtime.Now().Add(-time.Hour)
|
|
newerCreatedAt := dbtime.Now()
|
|
|
|
olderAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[0].ID,
|
|
CreatedAt: olderCreatedAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
newerAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resources[1].ID,
|
|
CreatedAt: newerCreatedAt,
|
|
AuthInstanceID: sql.NullString{
|
|
String: authInstanceID,
|
|
Valid: true,
|
|
},
|
|
})
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 2)
|
|
assert.Equal(t, newerAgent.ID, agents[0].ID)
|
|
assert.Equal(t, olderAgent.ID, agents[1].ID)
|
|
})
|
|
}
|
|
|
|
func requireUsersMatch(t testing.TB, expected []database.User, found []database.GetUsersRow, msg string) {
|
|
t.Helper()
|
|
require.ElementsMatch(t, expected, database.ConvertUserRows(found), msg)
|
|
}
|
|
|
|
// TestGetRunningPrebuiltWorkspaces ensures the correct behavior of the
|
|
// GetRunningPrebuiltWorkspaces query.
|
|
func TestGetRunningPrebuiltWorkspaces(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, _ := dbtestutil.NewDB(t)
|
|
now := dbtime.Now()
|
|
|
|
// Given: a prebuilt workspace with a successful start build and a stop build.
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
preset := dbgen.Preset(t, db, database.InsertPresetParams{
|
|
TemplateVersionID: templateVersion.ID,
|
|
DesiredInstances: sql.NullInt32{Int32: 1, Valid: true},
|
|
})
|
|
|
|
setupFixture := func(t *testing.T, db database.Store, name string, deleted bool, transition database.WorkspaceTransition, jobStatus database.ProvisionerJobStatus) database.WorkspaceTable {
|
|
t.Helper()
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: database.PrebuildsSystemUserID,
|
|
TemplateID: template.ID,
|
|
Name: name,
|
|
Deleted: deleted,
|
|
})
|
|
var canceledAt sql.NullTime
|
|
var jobError sql.NullString
|
|
switch jobStatus {
|
|
case database.ProvisionerJobStatusFailed:
|
|
jobError = sql.NullString{String: assert.AnError.Error(), Valid: true}
|
|
case database.ProvisionerJobStatusCanceled:
|
|
canceledAt = sql.NullTime{Time: now, Valid: true}
|
|
}
|
|
pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
InitiatorID: database.PrebuildsSystemUserID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
StartedAt: sql.NullTime{Time: now.Add(-time.Minute), Valid: true},
|
|
CanceledAt: canceledAt,
|
|
CompletedAt: sql.NullTime{Time: now, Valid: true},
|
|
Error: jobError,
|
|
})
|
|
wb := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: ws.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
TemplateVersionPresetID: uuid.NullUUID{UUID: preset.ID, Valid: true},
|
|
JobID: pj.ID,
|
|
BuildNumber: 1,
|
|
Transition: transition,
|
|
InitiatorID: database.PrebuildsSystemUserID,
|
|
Reason: database.BuildReasonInitiator,
|
|
})
|
|
// Ensure things are set up as expectd
|
|
require.Equal(t, transition, wb.Transition)
|
|
require.Equal(t, int32(1), wb.BuildNumber)
|
|
require.Equal(t, jobStatus, pj.JobStatus)
|
|
require.Equal(t, deleted, ws.Deleted)
|
|
|
|
return ws
|
|
}
|
|
|
|
// Given: a number of prebuild workspaces with different states exist.
|
|
runningPrebuild := setupFixture(t, db, "running-prebuild", false, database.WorkspaceTransitionStart, database.ProvisionerJobStatusSucceeded)
|
|
_ = setupFixture(t, db, "stopped-prebuild", false, database.WorkspaceTransitionStop, database.ProvisionerJobStatusSucceeded)
|
|
_ = setupFixture(t, db, "failed-prebuild", false, database.WorkspaceTransitionStart, database.ProvisionerJobStatusFailed)
|
|
_ = setupFixture(t, db, "canceled-prebuild", false, database.WorkspaceTransitionStart, database.ProvisionerJobStatusCanceled)
|
|
_ = setupFixture(t, db, "deleted-prebuild", true, database.WorkspaceTransitionStart, database.ProvisionerJobStatusSucceeded)
|
|
|
|
// Given: a regular workspace also exists.
|
|
_ = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
Name: "test-running-regular-workspace",
|
|
Deleted: false,
|
|
})
|
|
|
|
// When: we query for running prebuild workspaces
|
|
runningPrebuilds, err := db.GetRunningPrebuiltWorkspaces(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// Then: only the running prebuild workspace should be returned.
|
|
require.Len(t, runningPrebuilds, 1, "expected only one running prebuilt workspace")
|
|
require.Equal(t, runningPrebuild.ID, runningPrebuilds[0].ID, "expected the running prebuilt workspace to be returned")
|
|
}
|
|
|
|
func TestUserSecretsCRUDOperations(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Use raw database without dbauthz wrapper for this test
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
t.Run("FullCRUDWorkflow", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// Create a new user for this test
|
|
testUser := dbgen.User(t, db, database.User{})
|
|
|
|
// 1. CREATE
|
|
secretID := uuid.New()
|
|
createParams := database.CreateUserSecretParams{
|
|
ID: secretID,
|
|
UserID: testUser.ID,
|
|
Name: "workflow-secret",
|
|
Description: "Secret for full CRUD workflow",
|
|
Value: "workflow-value",
|
|
EnvName: "WORKFLOW_ENV",
|
|
FilePath: "/workflow/path",
|
|
}
|
|
|
|
createdSecret, err := db.CreateUserSecret(ctx, createParams)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, secretID, createdSecret.ID)
|
|
|
|
// 2. READ by UserID and Name
|
|
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
|
|
UserID: testUser.ID,
|
|
Name: "workflow-secret",
|
|
}
|
|
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
|
|
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
|
|
|
|
// 3. LIST (metadata only)
|
|
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, secrets, 1)
|
|
assert.Equal(t, createdSecret.ID, secrets[0].ID)
|
|
|
|
// 4. LIST with values
|
|
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, secretsWithValues, 1)
|
|
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
|
|
|
|
// 5. UPDATE (partial - only description)
|
|
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
|
|
UserID: testUser.ID,
|
|
Name: "workflow-secret",
|
|
UpdateDescription: true,
|
|
Description: "Updated workflow description",
|
|
}
|
|
|
|
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
|
|
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
|
|
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
|
|
|
// 6. DELETE
|
|
_, err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
|
UserID: testUser.ID,
|
|
Name: "workflow-secret",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Verify deletion
|
|
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "no rows in result set")
|
|
|
|
// Verify list is empty
|
|
secrets, err = db.ListUserSecrets(ctx, testUser.ID)
|
|
require.NoError(t, err)
|
|
assert.Len(t, secrets, 0)
|
|
})
|
|
|
|
t.Run("UniqueConstraints", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// Create a new user for this test
|
|
testUser := dbgen.User(t, db, database.User{})
|
|
|
|
// Create first secret
|
|
secret1 := dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: testUser.ID,
|
|
Name: "unique-test",
|
|
Description: "First secret",
|
|
Value: "value1",
|
|
EnvName: "UNIQUE_ENV",
|
|
FilePath: "/unique/path",
|
|
})
|
|
|
|
// Try to create another secret with the same name (should fail)
|
|
_, err := db.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
|
UserID: testUser.ID,
|
|
Name: "unique-test", // Same name
|
|
Description: "Second secret",
|
|
Value: "value2",
|
|
})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "duplicate key value")
|
|
|
|
// Try to create another secret with the same env_name (should fail)
|
|
_, err = db.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
|
UserID: testUser.ID,
|
|
Name: "unique-test-2",
|
|
Description: "Second secret",
|
|
Value: "value2",
|
|
EnvName: "UNIQUE_ENV", // Same env_name
|
|
})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "duplicate key value")
|
|
|
|
// Try to create another secret with the same file_path (should fail)
|
|
_, err = db.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
|
UserID: testUser.ID,
|
|
Name: "unique-test-3",
|
|
Description: "Second secret",
|
|
Value: "value2",
|
|
FilePath: "/unique/path", // Same file_path
|
|
})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "duplicate key value")
|
|
|
|
// Create secret with empty env_name and file_path (should succeed)
|
|
secret2 := dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: testUser.ID,
|
|
Name: "unique-test-4",
|
|
Description: "Second secret",
|
|
Value: "value2",
|
|
EnvName: "", // Empty env_name
|
|
FilePath: "", // Empty file_path
|
|
})
|
|
|
|
// Verify both secrets exist
|
|
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
|
UserID: testUser.ID, Name: secret1.Name,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
|
UserID: testUser.ID, Name: secret2.Name,
|
|
})
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestUserSecretsAuthorization(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Use raw database and wrap with dbauthz for authorization testing
|
|
db, _ := dbtestutil.NewDB(t)
|
|
authorizer := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
|
authDB := dbauthz.New(db, authorizer, slogtest.Make(t, &slogtest.Options{}), coderdtest.AccessControlStorePointer())
|
|
|
|
// Create test users
|
|
user1 := dbgen.User(t, db, database.User{})
|
|
user2 := dbgen.User(t, db, database.User{})
|
|
owner := dbgen.User(t, db, database.User{})
|
|
orgAdmin := dbgen.User(t, db, database.User{})
|
|
|
|
// Create organization for org-scoped roles
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
|
|
// Create secrets for users
|
|
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: user1.ID,
|
|
Name: "user1-secret",
|
|
Description: "User 1's secret",
|
|
Value: "user1-value",
|
|
})
|
|
|
|
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
|
UserID: user2.ID,
|
|
Name: "user2-secret",
|
|
Description: "User 2's secret",
|
|
Value: "user2-value",
|
|
})
|
|
|
|
testCases := []struct {
|
|
name string
|
|
subject rbac.Subject
|
|
lookupUserID uuid.UUID
|
|
lookupName string
|
|
expectedAccess bool
|
|
}{
|
|
{
|
|
name: "UserCanAccessOwnSecrets",
|
|
subject: rbac.Subject{
|
|
ID: user1.ID.String(),
|
|
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
|
Scope: rbac.ScopeAll,
|
|
},
|
|
lookupUserID: user1.ID,
|
|
lookupName: "user1-secret",
|
|
expectedAccess: true,
|
|
},
|
|
{
|
|
name: "UserCannotAccessOtherUserSecrets",
|
|
subject: rbac.Subject{
|
|
ID: user1.ID.String(),
|
|
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
|
Scope: rbac.ScopeAll,
|
|
},
|
|
lookupUserID: user2.ID,
|
|
lookupName: "user2-secret",
|
|
expectedAccess: false,
|
|
},
|
|
{
|
|
name: "OwnerCannotAccessUserSecrets",
|
|
subject: rbac.Subject{
|
|
ID: owner.ID.String(),
|
|
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
|
|
Scope: rbac.ScopeAll,
|
|
},
|
|
lookupUserID: user1.ID,
|
|
lookupName: "user1-secret",
|
|
expectedAccess: false,
|
|
},
|
|
{
|
|
name: "OrgAdminCannotAccessUserSecrets",
|
|
subject: rbac.Subject{
|
|
ID: orgAdmin.ID.String(),
|
|
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
|
|
Scope: rbac.ScopeAll,
|
|
},
|
|
lookupUserID: user1.ID,
|
|
lookupName: "user1-secret",
|
|
expectedAccess: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
authCtx := dbauthz.As(ctx, tc.subject)
|
|
|
|
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
|
|
UserID: tc.lookupUserID,
|
|
Name: tc.lookupName,
|
|
})
|
|
|
|
if tc.expectedAccess {
|
|
require.NoError(t, err, "expected access to be granted")
|
|
} else {
|
|
require.Error(t, err, "expected access to be denied")
|
|
assert.True(t, dbauthz.IsNotAuthorizedError(err), "expected authorization error")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWorkspaceBuildDeadlineConstraint(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
Name: "test-workspace",
|
|
Deleted: false,
|
|
})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
InitiatorID: database.PrebuildsSystemUserID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
StartedAt: sql.NullTime{Time: time.Now().Add(-time.Minute), Valid: true},
|
|
CompletedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
|
})
|
|
workspaceBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
JobID: job.ID,
|
|
BuildNumber: 1,
|
|
})
|
|
|
|
cases := []struct {
|
|
name string
|
|
deadline time.Time
|
|
maxDeadline time.Time
|
|
expectOK bool
|
|
}{
|
|
{
|
|
name: "no deadline or max_deadline",
|
|
deadline: time.Time{},
|
|
maxDeadline: time.Time{},
|
|
expectOK: true,
|
|
},
|
|
{
|
|
name: "deadline set when max_deadline is not set",
|
|
deadline: time.Now().Add(time.Hour),
|
|
maxDeadline: time.Time{},
|
|
expectOK: true,
|
|
},
|
|
{
|
|
name: "deadline before max_deadline",
|
|
deadline: time.Now().Add(-time.Hour),
|
|
maxDeadline: time.Now().Add(time.Hour),
|
|
expectOK: true,
|
|
},
|
|
{
|
|
name: "deadline is max_deadline",
|
|
deadline: time.Now().Add(time.Hour),
|
|
maxDeadline: time.Now().Add(time.Hour),
|
|
expectOK: true,
|
|
},
|
|
|
|
{
|
|
name: "deadline after max_deadline",
|
|
deadline: time.Now().Add(time.Hour),
|
|
maxDeadline: time.Now().Add(-time.Hour),
|
|
expectOK: false,
|
|
},
|
|
{
|
|
name: "deadline is not set when max_deadline is set",
|
|
deadline: time.Time{},
|
|
maxDeadline: time.Now().Add(time.Hour),
|
|
expectOK: false,
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
err := db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{
|
|
ID: workspaceBuild.ID,
|
|
Deadline: c.deadline,
|
|
MaxDeadline: c.maxDeadline,
|
|
UpdatedAt: time.Now(),
|
|
})
|
|
if c.expectOK {
|
|
require.NoError(t, err)
|
|
} else {
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckWorkspaceBuildsDeadlineBelowMaxDeadline))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWorkspaceACLObjectConstraint(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user.ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
Deleted: false,
|
|
})
|
|
|
|
t.Run("GroupACLNull", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var nilACL database.WorkspaceACL
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
err := db.UpdateWorkspaceACLByID(ctx, database.UpdateWorkspaceACLByIDParams{
|
|
ID: workspace.ID,
|
|
GroupACL: nilACL,
|
|
UserACL: database.WorkspaceACL{},
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckGroupAclIsObject))
|
|
})
|
|
|
|
t.Run("UserACLNull", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var nilACL database.WorkspaceACL
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
err := db.UpdateWorkspaceACLByID(ctx, database.UpdateWorkspaceACLByIDParams{
|
|
ID: workspace.ID,
|
|
GroupACL: database.WorkspaceACL{},
|
|
UserACL: nilACL,
|
|
})
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckUserAclIsObject))
|
|
})
|
|
|
|
t.Run("ValidEmptyObjects", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
err := db.UpdateWorkspaceACLByID(ctx, database.UpdateWorkspaceACLByIDParams{
|
|
ID: workspace.ID,
|
|
GroupACL: database.WorkspaceACL{},
|
|
UserACL: database.WorkspaceACL{},
|
|
})
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
// TestGetLatestWorkspaceBuildsByWorkspaceIDs populates the database with
|
|
// workspaces and builds. It then tests that
|
|
// GetLatestWorkspaceBuildsByWorkspaceIDs returns the latest build for some
|
|
// subset of the workspaces.
|
|
func TestGetLatestWorkspaceBuildsByWorkspaceIDs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
admin := dbgen.User(t, db, database.User{})
|
|
|
|
tv := dbfake.TemplateVersion(t, db).
|
|
Seed(database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: admin.ID,
|
|
}).
|
|
Do()
|
|
|
|
users := make([]database.User, 5)
|
|
wrks := make([][]database.WorkspaceTable, len(users))
|
|
exp := make(map[uuid.UUID]database.WorkspaceBuild)
|
|
for i := range users {
|
|
users[i] = dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
UserID: users[i].ID,
|
|
OrganizationID: org.ID,
|
|
})
|
|
|
|
// Each user gets 2 workspaces.
|
|
wrks[i] = make([]database.WorkspaceTable, 2)
|
|
for wi := range wrks[i] {
|
|
wrks[i][wi] = dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
TemplateID: tv.Template.ID,
|
|
OwnerID: users[i].ID,
|
|
})
|
|
|
|
// Choose a deterministic number of builds per workspace
|
|
// No more than 5 builds though, that would be excessive.
|
|
for j := int32(1); int(j) <= (i+wi)%5; j++ {
|
|
wb := dbfake.WorkspaceBuild(t, db, wrks[i][wi]).
|
|
Seed(database.WorkspaceBuild{
|
|
WorkspaceID: wrks[i][wi].ID,
|
|
BuildNumber: j + 1,
|
|
}).
|
|
Do()
|
|
|
|
exp[wrks[i][wi].ID] = wb.Build // Save the final workspace build
|
|
}
|
|
}
|
|
}
|
|
|
|
// Only take half the users. And only take 1 workspace per user for the test.
|
|
// The others are just noice. This just queries a subset of workspaces and builds
|
|
// to make sure the noise doesn't interfere with the results.
|
|
assertWrks := wrks[:len(users)/2]
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ids := slice.Convert[[]database.WorkspaceTable, uuid.UUID](assertWrks, func(pair []database.WorkspaceTable) uuid.UUID {
|
|
return pair[0].ID
|
|
})
|
|
|
|
require.Greater(t, len(ids), 0, "expected some workspace ids for test")
|
|
builds, err := db.GetLatestWorkspaceBuildsByWorkspaceIDs(ctx, ids)
|
|
require.NoError(t, err)
|
|
for _, b := range builds {
|
|
expB, ok := exp[b.WorkspaceID]
|
|
require.Truef(t, ok, "unexpected workspace build for workspace id %s", b.WorkspaceID)
|
|
require.Equalf(t, expB.ID, b.ID, "unexpected workspace build id for workspace id %s", b.WorkspaceID)
|
|
require.Equal(t, expB.BuildNumber, b.BuildNumber, "unexpected build number")
|
|
}
|
|
}
|
|
|
|
func TestTasksWithStatusView(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
createProvisionerJob := func(t *testing.T, db database.Store, org database.Organization, user database.User, buildStatus database.ProvisionerJobStatus) database.ProvisionerJob {
|
|
t.Helper()
|
|
|
|
var jobParams database.ProvisionerJob
|
|
|
|
switch buildStatus {
|
|
case database.ProvisionerJobStatusPending:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
}
|
|
case database.ProvisionerJobStatusRunning:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
}
|
|
case database.ProvisionerJobStatusFailed:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
CompletedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
Error: sql.NullString{Valid: true, String: "job failed"},
|
|
}
|
|
case database.ProvisionerJobStatusSucceeded:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
CompletedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
}
|
|
case database.ProvisionerJobStatusCanceling:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
CanceledAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
}
|
|
case database.ProvisionerJobStatusCanceled:
|
|
jobParams = database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: user.ID,
|
|
StartedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
CompletedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
CanceledAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
|
|
}
|
|
default:
|
|
t.Errorf("invalid build status: %v", buildStatus)
|
|
}
|
|
|
|
return dbgen.ProvisionerJob(t, db, nil, jobParams)
|
|
}
|
|
|
|
createTask := func(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
db database.Store,
|
|
org database.Organization,
|
|
user database.User,
|
|
buildStatus database.ProvisionerJobStatus,
|
|
buildTransition database.WorkspaceTransition,
|
|
agentState database.WorkspaceAgentLifecycleState,
|
|
appHealths []database.WorkspaceAppHealth,
|
|
) database.Task {
|
|
t.Helper()
|
|
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
|
|
if buildStatus == "" {
|
|
return dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: "test-task",
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
}
|
|
|
|
job := createProvisionerJob(t, db, org, user, buildStatus)
|
|
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
TemplateID: template.ID,
|
|
OwnerID: user.ID,
|
|
})
|
|
workspaceID := uuid.NullUUID{Valid: true, UUID: workspace.ID}
|
|
|
|
task := dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: "test-task",
|
|
WorkspaceID: workspaceID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
|
|
workspaceBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
BuildNumber: 1,
|
|
Transition: buildTransition,
|
|
InitiatorID: user.ID,
|
|
JobID: job.ID,
|
|
})
|
|
workspaceBuildNumber := workspaceBuild.BuildNumber
|
|
|
|
_, err := db.UpsertTaskWorkspaceApp(ctx, database.UpsertTaskWorkspaceAppParams{
|
|
TaskID: task.ID,
|
|
WorkspaceBuildNumber: workspaceBuildNumber,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
|
|
if agentState != "" {
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
})
|
|
workspaceAgentID := agent.ID
|
|
|
|
_, err := db.UpsertTaskWorkspaceApp(ctx, database.UpsertTaskWorkspaceAppParams{
|
|
TaskID: task.ID,
|
|
WorkspaceBuildNumber: workspaceBuildNumber,
|
|
WorkspaceAgentID: uuid.NullUUID{UUID: workspaceAgentID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
|
ID: agent.ID,
|
|
LifecycleState: agentState,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
for i, health := range appHealths {
|
|
app := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{
|
|
AgentID: workspaceAgentID,
|
|
Slug: fmt.Sprintf("test-app-%d", i),
|
|
DisplayName: fmt.Sprintf("Test App %d", i+1),
|
|
Health: health,
|
|
})
|
|
if i == 0 {
|
|
// Assume the first app is the tasks app.
|
|
_, err := db.UpsertTaskWorkspaceApp(ctx, database.UpsertTaskWorkspaceAppParams{
|
|
TaskID: task.ID,
|
|
WorkspaceBuildNumber: workspaceBuildNumber,
|
|
WorkspaceAgentID: uuid.NullUUID{UUID: workspaceAgentID, Valid: true},
|
|
WorkspaceAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return task
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
buildStatus database.ProvisionerJobStatus
|
|
buildTransition database.WorkspaceTransition
|
|
agentState database.WorkspaceAgentLifecycleState
|
|
appHealths []database.WorkspaceAppHealth
|
|
expectedStatus database.TaskStatus
|
|
description string
|
|
expectBuildNumberValid bool
|
|
expectBuildNumber int32
|
|
expectWorkspaceAgentValid bool
|
|
expectWorkspaceAppValid bool
|
|
}{
|
|
{
|
|
name: "NoWorkspace",
|
|
expectedStatus: "pending",
|
|
description: "Task with no workspace assigned",
|
|
expectBuildNumberValid: false,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "FailedBuild",
|
|
buildStatus: database.ProvisionerJobStatusFailed,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
expectedStatus: database.TaskStatusError,
|
|
description: "Latest workspace build failed",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "CancelingBuild",
|
|
buildStatus: database.ProvisionerJobStatusCanceling,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
expectedStatus: database.TaskStatusError,
|
|
description: "Latest workspace build is canceling",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "CanceledBuild",
|
|
buildStatus: database.ProvisionerJobStatusCanceled,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
expectedStatus: database.TaskStatusError,
|
|
description: "Latest workspace build was canceled",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "StoppedWorkspace",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStop,
|
|
expectedStatus: database.TaskStatusPaused,
|
|
description: "Workspace is stopped",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "DeletedWorkspace",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionDelete,
|
|
expectedStatus: database.TaskStatusPaused,
|
|
description: "Workspace is deleted",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "PendingStart",
|
|
buildStatus: database.ProvisionerJobStatusPending,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
expectedStatus: database.TaskStatusPending,
|
|
description: "Workspace build pending (not yet picked up by provisioner)",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "RunningStart",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Workspace build is starting (running)",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: false,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "StartingAgent",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateStarting,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthInitializing},
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Workspace is running but agent is starting",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "CreatedAgent",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateCreated,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthInitializing},
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Workspace is running but agent is created",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "ReadyAgentInitializingApp",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthInitializing},
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Agent is ready but app is initializing",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "ReadyAgentHealthyApp",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthHealthy},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Agent is ready and app is healthy",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "ReadyAgentDisabledApp",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthDisabled},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Agent is ready and app health checking is disabled",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "ReadyAgentUnhealthyApp",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthUnhealthy},
|
|
expectedStatus: database.TaskStatusError,
|
|
description: "Agent is ready but app is unhealthy",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "AgentStartTimeout",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateStartTimeout,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthHealthy},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Agent start timed out but app is healthy, defer to app",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "AgentStartError",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateStartError,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthHealthy},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Agent start failed but app is healthy, defer to app",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "AgentShuttingDown",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateShuttingDown,
|
|
expectedStatus: database.TaskStatusUnknown,
|
|
description: "Agent is shutting down",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "AgentOff",
|
|
buildStatus: database.ProvisionerJobStatusSucceeded,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateOff,
|
|
expectedStatus: database.TaskStatusUnknown,
|
|
description: "Agent is off",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: false,
|
|
},
|
|
{
|
|
name: "RunningJobReadyAgentHealthyApp",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthHealthy},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Running job with ready agent and healthy app should be active",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "RunningJobReadyAgentInitializingApp",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthInitializing},
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Running job with ready agent but initializing app should be initializing",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "RunningJobReadyAgentUnhealthyApp",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthUnhealthy},
|
|
expectedStatus: database.TaskStatusError,
|
|
description: "Running job with ready agent but unhealthy app should be error",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "RunningJobConnectingAgent",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateStarting,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthInitializing},
|
|
expectedStatus: database.TaskStatusInitializing,
|
|
description: "Running job with connecting agent should be initializing",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "RunningJobReadyAgentDisabledApp",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthDisabled},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Running job with ready agent and disabled app health checking should be active",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
{
|
|
name: "RunningJobReadyAgentHealthyTaskAppUnhealthyOtherAppIsOK",
|
|
buildStatus: database.ProvisionerJobStatusRunning,
|
|
buildTransition: database.WorkspaceTransitionStart,
|
|
agentState: database.WorkspaceAgentLifecycleStateReady,
|
|
appHealths: []database.WorkspaceAppHealth{database.WorkspaceAppHealthHealthy, database.WorkspaceAppHealthUnhealthy},
|
|
expectedStatus: database.TaskStatusActive,
|
|
description: "Running job with ready agent and multiple healthy apps should be active",
|
|
expectBuildNumberValid: true,
|
|
expectBuildNumber: 1,
|
|
expectWorkspaceAgentValid: true,
|
|
expectWorkspaceAppValid: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
|
|
task := createTask(ctx, t, db, org, user, tt.buildStatus, tt.buildTransition, tt.agentState, tt.appHealths)
|
|
|
|
got, err := db.GetTaskByID(ctx, task.ID)
|
|
require.NoError(t, err)
|
|
|
|
t.Logf("Task status debug: %s", got.StatusDebug)
|
|
|
|
require.Equal(t, tt.expectedStatus, got.Status)
|
|
|
|
require.Equal(t, tt.expectBuildNumberValid, got.WorkspaceBuildNumber.Valid)
|
|
if tt.expectBuildNumberValid {
|
|
require.Equal(t, tt.expectBuildNumber, got.WorkspaceBuildNumber.Int32)
|
|
}
|
|
|
|
require.Equal(t, tt.expectWorkspaceAgentValid, got.WorkspaceAgentID.Valid)
|
|
if tt.expectWorkspaceAgentValid {
|
|
require.NotEqual(t, uuid.Nil, got.WorkspaceAgentID.UUID)
|
|
}
|
|
|
|
require.Equal(t, tt.expectWorkspaceAppValid, got.WorkspaceAppID.Valid)
|
|
if tt.expectWorkspaceAppValid {
|
|
require.NotEqual(t, uuid.Nil, got.WorkspaceAppID.UUID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetTaskByWorkspaceID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
setupTask func(t *testing.T, db database.Store, org database.Organization, user database.User, templateVersion database.TemplateVersion, workspace database.WorkspaceTable)
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "task doesn't exist",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "task with no workspace id",
|
|
setupTask: func(t *testing.T, db database.Store, org database.Organization, user database.User, templateVersion database.TemplateVersion, workspace database.WorkspaceTable) {
|
|
dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: "test-task",
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "task with workspace id",
|
|
setupTask: func(t *testing.T, db database.Store, org database.Organization, user database.User, templateVersion database.TemplateVersion, workspace database.WorkspaceTable) {
|
|
workspaceID := uuid.NullUUID{Valid: true, UUID: workspace.ID}
|
|
dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: "test-task",
|
|
WorkspaceID: workspaceID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
TemplateID: uuid.NullUUID{Valid: true, UUID: template.ID},
|
|
CreatedBy: user.ID,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
|
|
if tt.setupTask != nil {
|
|
tt.setupTask(t, db, org, user, templateVersion, workspace)
|
|
}
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
task, err := db.GetTaskByWorkspaceID(ctx, workspace.ID)
|
|
if tt.wantErr {
|
|
require.Error(t, err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.False(t, task.WorkspaceBuildNumber.Valid)
|
|
require.False(t, task.WorkspaceAgentID.Valid)
|
|
require.False(t, task.WorkspaceAppID.Valid)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDeleteTaskDeletesTaskSnapshot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
task := dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
|
|
err := db.UpsertTaskSnapshot(ctx, database.UpsertTaskSnapshotParams{
|
|
TaskID: task.ID,
|
|
LogSnapshot: json.RawMessage(`{"messages":[]}`),
|
|
LogSnapshotCreatedAt: dbtime.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.DeleteTask(ctx, database.DeleteTaskParams{
|
|
ID: task.ID,
|
|
DeletedAt: dbtime.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.GetTaskSnapshot(ctx, task.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
}
|
|
|
|
func TestTaskNameUniqueness(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user1 := dbgen.User(t, db, database.User{})
|
|
user2 := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user1.ID,
|
|
})
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{UUID: template.ID, Valid: true},
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user1.ID,
|
|
})
|
|
|
|
taskName := "my-task"
|
|
|
|
// Create initial task for user1.
|
|
task1 := dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user1.ID,
|
|
Name: taskName,
|
|
TemplateVersionID: tv.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
require.NotEqual(t, uuid.Nil, task1.ID)
|
|
|
|
tests := []struct {
|
|
name string
|
|
ownerID uuid.UUID
|
|
taskName string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "duplicate task name same user",
|
|
ownerID: user1.ID,
|
|
taskName: taskName,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "duplicate task name different case same user",
|
|
ownerID: user1.ID,
|
|
taskName: "MY-TASK",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "same task name different user",
|
|
ownerID: user2.ID,
|
|
taskName: taskName,
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
taskID := uuid.New()
|
|
task, err := db.InsertTask(ctx, database.InsertTaskParams{
|
|
ID: taskID,
|
|
OrganizationID: org.ID,
|
|
OwnerID: tt.ownerID,
|
|
Name: tt.taskName,
|
|
TemplateVersionID: tv.ID,
|
|
TemplateParameters: json.RawMessage("{}"),
|
|
Prompt: "Test prompt",
|
|
CreatedAt: dbtime.Now(),
|
|
})
|
|
if tt.wantErr {
|
|
require.Error(t, err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.NotEqual(t, uuid.Nil, task.ID)
|
|
require.NotEqual(t, task1.ID, task.ID)
|
|
require.Equal(t, taskID, task.ID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUsageEventsTrigger(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// This is not exposed in the querier interface intentionally.
|
|
getDailyRows := func(ctx context.Context, sqlDB *sql.DB) []database.UsageEventsDaily {
|
|
t.Helper()
|
|
rows, err := sqlDB.QueryContext(ctx, "SELECT day, event_type, usage_data FROM usage_events_daily ORDER BY day ASC")
|
|
require.NoError(t, err, "perform query")
|
|
defer rows.Close()
|
|
|
|
var out []database.UsageEventsDaily
|
|
for rows.Next() {
|
|
var row database.UsageEventsDaily
|
|
err := rows.Scan(&row.Day, &row.EventType, &row.UsageData)
|
|
require.NoError(t, err, "scan row")
|
|
out = append(out, row)
|
|
}
|
|
return out
|
|
}
|
|
|
|
t.Run("OK", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
|
|
// Assert there are no daily rows.
|
|
rows := getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 0)
|
|
|
|
// Insert a usage event.
|
|
err := db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "1",
|
|
EventType: "dc_managed_agents_v1",
|
|
EventData: []byte(`{"count": 41}`),
|
|
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Assert there is one daily row that contains the correct data.
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 1)
|
|
require.Equal(t, "dc_managed_agents_v1", rows[0].EventType)
|
|
require.JSONEq(t, `{"count": 41}`, string(rows[0].UsageData))
|
|
// The read row might be `+0000` rather than `UTC` specifically, so just
|
|
// ensure it's within 1 second of the expected time.
|
|
require.WithinDuration(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), rows[0].Day, time.Second)
|
|
|
|
// Insert a new usage event on the same UTC day, should increment the count.
|
|
locSydney, err := time.LoadLocation("Australia/Sydney")
|
|
require.NoError(t, err)
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "2",
|
|
EventType: "dc_managed_agents_v1",
|
|
EventData: []byte(`{"count": 1}`),
|
|
// Insert it at a random point during the same day. Sydney is +1000 or
|
|
// +1100, so 8am in Sydney is the previous day in UTC.
|
|
CreatedAt: time.Date(2025, 1, 2, 8, 38, 57, 0, locSydney),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// There should still be only one daily row with the incremented count.
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 1)
|
|
require.Equal(t, "dc_managed_agents_v1", rows[0].EventType)
|
|
require.JSONEq(t, `{"count": 42}`, string(rows[0].UsageData))
|
|
require.WithinDuration(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), rows[0].Day, time.Second)
|
|
|
|
// TODO: when we have a new event type, we should test that adding an
|
|
// event with a different event type on the same day creates a new daily
|
|
// row.
|
|
|
|
// Insert a new usage event on a different day, should create a new daily
|
|
// row.
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "3",
|
|
EventType: "dc_managed_agents_v1",
|
|
EventData: []byte(`{"count": 1}`),
|
|
CreatedAt: time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// There should now be two daily rows.
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 2)
|
|
// Output is sorted by day ascending, so the first row should be the
|
|
// previous day's row.
|
|
require.Equal(t, "dc_managed_agents_v1", rows[0].EventType)
|
|
require.JSONEq(t, `{"count": 42}`, string(rows[0].UsageData))
|
|
require.WithinDuration(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), rows[0].Day, time.Second)
|
|
require.Equal(t, "dc_managed_agents_v1", rows[1].EventType)
|
|
require.JSONEq(t, `{"count": 1}`, string(rows[1].UsageData))
|
|
require.WithinDuration(t, time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), rows[1].Day, time.Second)
|
|
})
|
|
|
|
t.Run("HeartbeatAISeats", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
|
|
// Insert a heartbeat event.
|
|
err := db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "hb-1",
|
|
EventType: "hb_ai_seats_v1",
|
|
EventData: []byte(`{"count": 10}`),
|
|
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows := getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 1)
|
|
require.Equal(t, "hb_ai_seats_v1", rows[0].EventType)
|
|
require.JSONEq(t, `{"count": 10}`, string(rows[0].UsageData))
|
|
|
|
// Insert a higher count on the same day — should take the max.
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "hb-2",
|
|
EventType: "hb_ai_seats_v1",
|
|
EventData: []byte(`{"count": 50}`),
|
|
CreatedAt: time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 1)
|
|
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
|
|
|
|
// Insert a lower count on the same day — should keep the max (50).
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "hb-3",
|
|
EventType: "hb_ai_seats_v1",
|
|
EventData: []byte(`{"count": 25}`),
|
|
CreatedAt: time.Date(2025, 1, 1, 18, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 1)
|
|
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
|
|
|
|
// Insert on a different day.
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "hb-4",
|
|
EventType: "hb_ai_seats_v1",
|
|
EventData: []byte(`{"count": 5}`),
|
|
CreatedAt: time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 2)
|
|
require.JSONEq(t, `{"count": 50}`, string(rows[0].UsageData))
|
|
require.JSONEq(t, `{"count": 5}`, string(rows[1].UsageData))
|
|
|
|
// Also insert a dc_managed_agents_v1 on the same first day to
|
|
// verify different event types get separate daily rows.
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "dc-1",
|
|
EventType: "dc_managed_agents_v1",
|
|
EventData: []byte(`{"count": 7}`),
|
|
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rows = getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 3)
|
|
})
|
|
|
|
t.Run("UnknownEventType", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
|
|
|
// Relax the usage_events.event_type check constraint to see what
|
|
// happens when we insert a usage event that the trigger doesn't know
|
|
// about.
|
|
_, err := sqlDB.ExecContext(ctx, "ALTER TABLE usage_events DROP CONSTRAINT usage_event_type_check")
|
|
require.NoError(t, err)
|
|
|
|
// Insert a usage event with an unknown event type.
|
|
err = db.InsertUsageEvent(ctx, database.InsertUsageEventParams{
|
|
ID: "broken",
|
|
EventType: "dean's cool event",
|
|
EventData: []byte(`{"my": "cool json"}`),
|
|
CreatedAt: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
})
|
|
require.ErrorContains(t, err, "Unhandled usage event type in aggregate_usage_event")
|
|
|
|
// The event should've been blocked.
|
|
var count int
|
|
err = sqlDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_events WHERE id = 'broken'").Scan(&count)
|
|
require.NoError(t, err)
|
|
require.Equal(t, 0, count)
|
|
|
|
// We should not have any daily rows.
|
|
rows := getDailyRows(ctx, sqlDB)
|
|
require.Len(t, rows, 0)
|
|
})
|
|
}
|
|
|
|
func TestListTasks(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, ps := dbtestutil.NewDB(t)
|
|
|
|
// Given: two organizations and two users, one of which is a member of both
|
|
org1 := dbgen.Organization(t, db, database.Organization{})
|
|
org2 := dbgen.Organization(t, db, database.Organization{})
|
|
user1 := dbgen.User(t, db, database.User{})
|
|
user2 := dbgen.User(t, db, database.User{})
|
|
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org1.ID,
|
|
UserID: user1.ID,
|
|
})
|
|
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
|
OrganizationID: org2.ID,
|
|
UserID: user2.ID,
|
|
})
|
|
|
|
// Given: a template with an active version
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
CreatedBy: user1.ID,
|
|
OrganizationID: org1.ID,
|
|
})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
CreatedBy: user1.ID,
|
|
OrganizationID: org1.ID,
|
|
ActiveVersionID: tv.ID,
|
|
})
|
|
|
|
// Helper function to create a task
|
|
createTask := func(orgID, ownerID uuid.UUID) database.Task {
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: orgID,
|
|
OwnerID: ownerID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
pj := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{})
|
|
sidebarAppID := uuid.New()
|
|
wb := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
JobID: pj.ID,
|
|
TemplateVersionID: tv.ID,
|
|
WorkspaceID: ws.ID,
|
|
})
|
|
wr := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: pj.ID,
|
|
})
|
|
agt := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: wr.ID,
|
|
})
|
|
wa := dbgen.WorkspaceApp(t, db, database.WorkspaceApp{
|
|
ID: sidebarAppID,
|
|
AgentID: agt.ID,
|
|
})
|
|
tsk := dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: orgID,
|
|
OwnerID: ownerID,
|
|
Prompt: testutil.GetRandomName(t),
|
|
TemplateVersionID: tv.ID,
|
|
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
|
})
|
|
_ = dbgen.TaskWorkspaceApp(t, db, database.TaskWorkspaceApp{
|
|
TaskID: tsk.ID,
|
|
WorkspaceBuildNumber: wb.BuildNumber,
|
|
WorkspaceAgentID: uuid.NullUUID{Valid: true, UUID: agt.ID},
|
|
WorkspaceAppID: uuid.NullUUID{Valid: true, UUID: wa.ID},
|
|
})
|
|
t.Logf("task_id:%s owner_id:%s org_id:%s", tsk.ID, ownerID, orgID)
|
|
return tsk
|
|
}
|
|
|
|
// Given: user1 has one task, user2 has one task, user3 has two tasks (one in each org)
|
|
task1 := createTask(org1.ID, user1.ID)
|
|
task2 := createTask(org1.ID, user2.ID)
|
|
task3 := createTask(org2.ID, user2.ID)
|
|
|
|
// Then: run various filters and assert expected results
|
|
for _, tc := range []struct {
|
|
name string
|
|
filter database.ListTasksParams
|
|
expectIDs []uuid.UUID
|
|
}{
|
|
{
|
|
name: "no filter",
|
|
filter: database.ListTasksParams{
|
|
OwnerID: uuid.Nil,
|
|
OrganizationID: uuid.Nil,
|
|
},
|
|
expectIDs: []uuid.UUID{task3.ID, task2.ID, task1.ID},
|
|
},
|
|
{
|
|
name: "filter by user ID",
|
|
filter: database.ListTasksParams{
|
|
OwnerID: user1.ID,
|
|
OrganizationID: uuid.Nil,
|
|
},
|
|
expectIDs: []uuid.UUID{task1.ID},
|
|
},
|
|
{
|
|
name: "filter by organization ID",
|
|
filter: database.ListTasksParams{
|
|
OwnerID: uuid.Nil,
|
|
OrganizationID: org1.ID,
|
|
},
|
|
expectIDs: []uuid.UUID{task2.ID, task1.ID},
|
|
},
|
|
{
|
|
name: "filter by user and organization ID",
|
|
filter: database.ListTasksParams{
|
|
OwnerID: user2.ID,
|
|
OrganizationID: org2.ID,
|
|
},
|
|
expectIDs: []uuid.UUID{task3.ID},
|
|
},
|
|
{
|
|
name: "no results",
|
|
filter: database.ListTasksParams{
|
|
OwnerID: user1.ID,
|
|
OrganizationID: org2.ID,
|
|
},
|
|
expectIDs: nil,
|
|
},
|
|
} {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
tasks, err := db.ListTasks(ctx, tc.filter)
|
|
require.NoError(t, err)
|
|
require.Len(t, tasks, len(tc.expectIDs))
|
|
|
|
for idx, eid := range tc.expectIDs {
|
|
task := tasks[idx]
|
|
assert.Equal(t, eid, task.ID, "task ID mismatch at index %d", idx)
|
|
|
|
require.True(t, task.WorkspaceBuildNumber.Valid)
|
|
require.Greater(t, task.WorkspaceBuildNumber.Int32, int32(0))
|
|
require.True(t, task.WorkspaceAgentID.Valid)
|
|
require.NotEqual(t, uuid.Nil, task.WorkspaceAgentID.UUID)
|
|
require.True(t, task.WorkspaceAppID.Valid)
|
|
require.NotEqual(t, uuid.Nil, task.WorkspaceAppID.UUID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdateTaskWorkspaceID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
// Create organization, users, template, and template version.
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
TemplateID: uuid.NullUUID{Valid: true, UUID: template.ID},
|
|
CreatedBy: user.ID,
|
|
})
|
|
|
|
// Create another template for mismatch test.
|
|
template2 := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
setupTask func(t *testing.T) database.Task
|
|
setupWS func(t *testing.T) database.WorkspaceTable
|
|
wantErr bool
|
|
wantNoRow bool
|
|
}{
|
|
{
|
|
name: "successful update with matching template",
|
|
setupTask: func(t *testing.T) database.Task {
|
|
return dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: testutil.GetRandomName(t),
|
|
WorkspaceID: uuid.NullUUID{},
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
setupWS: func(t *testing.T) database.WorkspaceTable {
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
},
|
|
wantErr: false,
|
|
wantNoRow: false,
|
|
},
|
|
{
|
|
name: "task already has workspace_id",
|
|
setupTask: func(t *testing.T) database.Task {
|
|
existingWS := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
return dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: testutil.GetRandomName(t),
|
|
WorkspaceID: uuid.NullUUID{Valid: true, UUID: existingWS.ID},
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
setupWS: func(t *testing.T) database.WorkspaceTable {
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
},
|
|
wantErr: false,
|
|
wantNoRow: true, // No row should be returned because WHERE condition fails.
|
|
},
|
|
{
|
|
name: "template mismatch between task and workspace",
|
|
setupTask: func(t *testing.T) database.Task {
|
|
return dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: testutil.GetRandomName(t),
|
|
WorkspaceID: uuid.NullUUID{}, // NULL workspace_id
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
setupWS: func(t *testing.T) database.WorkspaceTable {
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template2.ID, // Different template, JOIN will fail.
|
|
})
|
|
},
|
|
wantErr: false,
|
|
wantNoRow: true, // No row should be returned because JOIN condition fails.
|
|
},
|
|
{
|
|
name: "task does not exist",
|
|
setupTask: func(t *testing.T) database.Task {
|
|
return database.Task{
|
|
ID: uuid.New(), // Non-existent task ID.
|
|
}
|
|
},
|
|
setupWS: func(t *testing.T) database.WorkspaceTable {
|
|
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
TemplateID: template.ID,
|
|
})
|
|
},
|
|
wantErr: false,
|
|
wantNoRow: true,
|
|
},
|
|
{
|
|
name: "workspace does not exist",
|
|
setupTask: func(t *testing.T) database.Task {
|
|
return dbgen.Task(t, db, database.TaskTable{
|
|
OrganizationID: org.ID,
|
|
OwnerID: user.ID,
|
|
Name: testutil.GetRandomName(t),
|
|
WorkspaceID: uuid.NullUUID{},
|
|
TemplateVersionID: templateVersion.ID,
|
|
Prompt: "Test prompt",
|
|
})
|
|
},
|
|
setupWS: func(t *testing.T) database.WorkspaceTable {
|
|
return database.WorkspaceTable{
|
|
ID: uuid.New(), // Non-existent workspace ID.
|
|
}
|
|
},
|
|
wantErr: false,
|
|
wantNoRow: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
task := tt.setupTask(t)
|
|
workspace := tt.setupWS(t)
|
|
|
|
updatedTask, err := db.UpdateTaskWorkspaceID(ctx, database.UpdateTaskWorkspaceIDParams{
|
|
ID: task.ID,
|
|
WorkspaceID: uuid.NullUUID{Valid: true, UUID: workspace.ID},
|
|
})
|
|
|
|
if tt.wantErr {
|
|
require.Error(t, err)
|
|
return
|
|
}
|
|
|
|
if tt.wantNoRow {
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
return
|
|
}
|
|
|
|
require.NoError(t, err)
|
|
require.Equal(t, task.ID, updatedTask.ID)
|
|
require.True(t, updatedTask.WorkspaceID.Valid)
|
|
require.Equal(t, workspace.ID, updatedTask.WorkspaceID.UUID)
|
|
require.Equal(t, task.OrganizationID, updatedTask.OrganizationID)
|
|
require.Equal(t, task.OwnerID, updatedTask.OwnerID)
|
|
require.Equal(t, task.Name, updatedTask.Name)
|
|
require.Equal(t, task.TemplateVersionID, updatedTask.TemplateVersionID)
|
|
|
|
// Verify the update persisted by fetching the task again.
|
|
fetchedTask, err := db.GetTaskByID(ctx, task.ID)
|
|
require.NoError(t, err)
|
|
require.True(t, fetchedTask.WorkspaceID.Valid)
|
|
require.Equal(t, workspace.ID, fetchedTask.WorkspaceID.UUID)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
t.Run("NonExistingInterception", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
got, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
|
ID: uuid.New(),
|
|
EndedAt: time.Now(),
|
|
})
|
|
require.ErrorContains(t, err, "no rows in result set")
|
|
require.EqualValues(t, database.AIBridgeInterception{}, got)
|
|
})
|
|
|
|
t.Run("OK", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
interceptions := []database.AIBridgeInterception{}
|
|
|
|
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
|
|
insertParams := database.InsertAIBridgeInterceptionParams{
|
|
ID: uid,
|
|
InitiatorID: user.ID,
|
|
Metadata: json.RawMessage("{}"),
|
|
Client: sql.NullString{String: "client", Valid: true},
|
|
CredentialKind: database.CredentialKindCentralized,
|
|
}
|
|
|
|
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
|
|
require.NoError(t, err)
|
|
require.Equal(t, uid, intc.ID)
|
|
require.False(t, intc.EndedAt.Valid)
|
|
require.True(t, intc.Client.Valid)
|
|
require.Equal(t, "client", intc.Client.String)
|
|
interceptions = append(interceptions, intc)
|
|
}
|
|
|
|
intc0 := interceptions[0]
|
|
endedAt := time.Now()
|
|
// Mark first interception as done
|
|
updated, err := db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
|
ID: intc0.ID,
|
|
EndedAt: endedAt,
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, updated.ID, intc0.ID)
|
|
require.True(t, updated.EndedAt.Valid)
|
|
require.WithinDuration(t, endedAt, updated.EndedAt.Time, 5*time.Second)
|
|
|
|
// Updating first interception again should fail
|
|
updated, err = db.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
|
ID: intc0.ID,
|
|
EndedAt: endedAt.Add(time.Hour),
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
// Other interceptions should not have ended_at set
|
|
for _, intc := range interceptions[1:] {
|
|
got, err := db.GetAIBridgeInterceptionByID(ctx, intc.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, got.EndedAt.Valid)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDeleteExpiredAPIKeys(t *testing.T) {
|
|
t.Parallel()
|
|
db, _ := dbtestutil.NewDB(t)
|
|
|
|
// Constant time for testing
|
|
now := time.Date(2025, 11, 20, 12, 0, 0, 0, time.UTC)
|
|
expiredBefore := now.Add(-time.Hour) // Anything before this is expired
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
|
|
expiredTimes := []time.Time{
|
|
expiredBefore.Add(-time.Hour * 24 * 365),
|
|
expiredBefore.Add(-time.Hour * 24),
|
|
expiredBefore.Add(-time.Hour),
|
|
expiredBefore.Add(-time.Minute),
|
|
expiredBefore.Add(-time.Second),
|
|
}
|
|
for _, exp := range expiredTimes {
|
|
// Expired api keys
|
|
dbgen.APIKey(t, db, database.APIKey{UserID: user.ID, ExpiresAt: exp})
|
|
}
|
|
|
|
unexpiredTimes := []time.Time{
|
|
expiredBefore.Add(time.Hour * 24 * 365),
|
|
expiredBefore.Add(time.Hour * 24),
|
|
expiredBefore.Add(time.Hour),
|
|
expiredBefore.Add(time.Minute),
|
|
expiredBefore.Add(time.Second),
|
|
}
|
|
for _, unexp := range unexpiredTimes {
|
|
// Unexpired api keys
|
|
dbgen.APIKey(t, db, database.APIKey{UserID: user.ID, ExpiresAt: unexp})
|
|
}
|
|
|
|
// All keys are present before deletion
|
|
keys, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
|
LoginType: user.LoginType,
|
|
UserID: user.ID,
|
|
IncludeExpired: true,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, keys, len(expiredTimes)+len(unexpiredTimes))
|
|
|
|
// Delete expired keys
|
|
// First verify the limit works by deleting one at a time
|
|
deletedCount, err := db.DeleteExpiredAPIKeys(ctx, database.DeleteExpiredAPIKeysParams{
|
|
Before: expiredBefore,
|
|
LimitCount: 1,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(1), deletedCount)
|
|
|
|
// Ensure it was deleted
|
|
remaining, err := db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
|
LoginType: user.LoginType,
|
|
UserID: user.ID,
|
|
IncludeExpired: true,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, remaining, len(expiredTimes)+len(unexpiredTimes)-1)
|
|
|
|
// Delete the rest of the expired keys
|
|
deletedCount, err = db.DeleteExpiredAPIKeys(ctx, database.DeleteExpiredAPIKeysParams{
|
|
Before: expiredBefore,
|
|
LimitCount: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(len(expiredTimes)-1), deletedCount)
|
|
|
|
// Ensure only unexpired keys remain
|
|
remaining, err = db.GetAPIKeysByUserID(ctx, database.GetAPIKeysByUserIDParams{
|
|
LoginType: user.LoginType,
|
|
UserID: user.ID,
|
|
IncludeExpired: true,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, remaining, len(unexpiredTimes))
|
|
}
|
|
|
|
func TestGetAuthenticatedWorkspaceAgentAndBuildByAuthToken_ShutdownScripts(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
owner := dbgen.User(t, db, database.User{})
|
|
tpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: owner.ID,
|
|
})
|
|
ver := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
TemplateID: uuid.NullUUID{
|
|
UUID: tpl.ID,
|
|
Valid: true,
|
|
},
|
|
OrganizationID: tpl.OrganizationID,
|
|
CreatedBy: owner.ID,
|
|
})
|
|
|
|
t.Run("DuringStopBuild", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Create start build with succeeded job (already completed).
|
|
startJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob)
|
|
startJob = dbgen.ProvisionerJob(t, db, nil, startJob)
|
|
startResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
startBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource.ID,
|
|
})
|
|
|
|
// Create stop build (becomes latest).
|
|
stopJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
JobStatus: database.ProvisionerJobStatusRunning,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob.ID,
|
|
})
|
|
|
|
// Agent should still authenticate during stop build execution.
|
|
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
|
require.NoError(t, err, "agent should authenticate during stop build execution")
|
|
require.Equal(t, agent.ID, row.WorkspaceAgent.ID)
|
|
require.Equal(t, startBuild.ID, row.WorkspaceBuild.ID, "should return start build, not stop build")
|
|
})
|
|
|
|
t.Run("AfterStopJobCompletes", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Create start build with completed job.
|
|
startJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob)
|
|
startJob = dbgen.ProvisionerJob(t, db, nil, startJob)
|
|
|
|
startResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource.ID,
|
|
})
|
|
|
|
// Create stop build (becomes latest) with completed job.
|
|
stopJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &stopJob)
|
|
stopJob = dbgen.ProvisionerJob(t, db, nil, stopJob)
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob.ID,
|
|
})
|
|
|
|
// Agent should NOT authenticate after stop job completes.
|
|
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
|
require.ErrorIs(t, err, sql.ErrNoRows, "agent should not authenticate after stop job completes")
|
|
})
|
|
|
|
t.Run("FailedStartBuild", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Create START build with FAILED job.
|
|
startJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusFailed, &startJob)
|
|
startJob = dbgen.ProvisionerJob(t, db, nil, startJob)
|
|
startResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource.ID,
|
|
})
|
|
|
|
// Create STOP build with running job.
|
|
stopJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
JobStatus: database.ProvisionerJobStatusRunning,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob.ID,
|
|
})
|
|
|
|
// Agent should NOT authenticate (start build failed).
|
|
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
|
require.ErrorIs(t, err, sql.ErrNoRows, "agent from failed start build should not authenticate")
|
|
})
|
|
|
|
t.Run("PendingStopBuild", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Create start build with succeeded job.
|
|
startJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob)
|
|
startJob = dbgen.ProvisionerJob(t, db, nil, startJob)
|
|
startResource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
startBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob.ID,
|
|
})
|
|
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource.ID,
|
|
})
|
|
|
|
// Create stop build with pending job (not started yet).
|
|
stopJob := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusPending, &stopJob)
|
|
stopJob = dbgen.ProvisionerJob(t, db, nil, stopJob)
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob.ID,
|
|
})
|
|
|
|
// Agent should authenticate during pending stop build.
|
|
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent.AuthToken)
|
|
require.NoError(t, err, "agent should authenticate during pending stop build")
|
|
require.Equal(t, agent.ID, row.WorkspaceAgent.ID)
|
|
require.Equal(t, startBuild.ID, row.WorkspaceBuild.ID, "should return start build")
|
|
})
|
|
|
|
t.Run("MultipleStartStopCycles", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Build 1: START (succeeded).
|
|
startJob1 := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob1)
|
|
startJob1 = dbgen.ProvisionerJob(t, db, nil, startJob1)
|
|
startResource1 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob1.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob1.ID,
|
|
})
|
|
agent1 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource1.ID,
|
|
})
|
|
|
|
// Build 2: STOP (succeeded).
|
|
stopJob1 := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &stopJob1)
|
|
stopJob1 = dbgen.ProvisionerJob(t, db, nil, stopJob1)
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob1.ID,
|
|
})
|
|
|
|
// Build 3: START (succeeded).
|
|
startJob2 := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob2)
|
|
startJob2 = dbgen.ProvisionerJob(t, db, nil, startJob2)
|
|
startResource2 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob2.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
startBuild2 := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 3,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob2.ID,
|
|
})
|
|
agent2 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource2.ID,
|
|
})
|
|
|
|
// Build 4: STOP (running).
|
|
stopJob2 := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
JobStatus: database.ProvisionerJobStatusRunning,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 4,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
InitiatorID: owner.ID,
|
|
JobID: stopJob2.ID,
|
|
})
|
|
|
|
// Agent from build 3 should authenticate.
|
|
row, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent2.AuthToken)
|
|
require.NoError(t, err, "agent from most recent start should authenticate during stop")
|
|
require.Equal(t, agent2.ID, row.WorkspaceAgent.ID)
|
|
require.Equal(t, startBuild2.ID, row.WorkspaceBuild.ID)
|
|
|
|
// Agent from build 1 should NOT authenticate.
|
|
_, err = db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent1.AuthToken)
|
|
require.ErrorIs(t, err, sql.ErrNoRows, "agent from old cycle should not authenticate")
|
|
})
|
|
|
|
t.Run("WrongTransitionType", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OwnerID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
TemplateID: tpl.ID,
|
|
})
|
|
|
|
// Create first start build.
|
|
startJob1 := database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
}
|
|
setJobStatus(t, database.ProvisionerJobStatusSucceeded, &startJob1)
|
|
startJob1 = dbgen.ProvisionerJob(t, db, nil, startJob1)
|
|
startResource1 := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: startJob1.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 1,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob1.ID,
|
|
})
|
|
agent1 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: startResource1.ID,
|
|
})
|
|
|
|
// Create another START build as latest (not STOP).
|
|
startJob2 := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
InitiatorID: owner.ID,
|
|
OrganizationID: org.ID,
|
|
JobStatus: database.ProvisionerJobStatusRunning,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: ver.ID,
|
|
BuildNumber: 2,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
InitiatorID: owner.ID,
|
|
JobID: startJob2.ID,
|
|
})
|
|
|
|
// Agent from build 1 should NOT authenticate (latest is not STOP).
|
|
_, err := db.GetAuthenticatedWorkspaceAgentAndBuildByAuthToken(ctx, agent1.AuthToken)
|
|
require.ErrorIs(t, err, sql.ErrNoRows, "agent should not authenticate when latest build is not STOP")
|
|
})
|
|
}
|
|
|
|
// Our `InsertWorkspaceAgentDevcontainers` query should ideally be `[]uuid.NullUUID` but unfortunately
|
|
// sqlc infers it as `[]uuid.UUID`. To ensure we don't insert a `uuid.Nil`, the query inserts NULL when
|
|
// passed with `uuid.Nil`. This test ensures we keep this behavior without regression.
|
|
func TestInsertWorkspaceAgentDevcontainers(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
validSubagent []bool
|
|
}{
|
|
{"BothValid", []bool{true, true}},
|
|
{"FirstValidSecondInvalid", []bool{true, false}},
|
|
{"FirstInvalidSecondValid", []bool{false, true}},
|
|
{"BothInvalid", []bool{false, false}},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
db, _ = dbtestutil.NewDB(t)
|
|
org = dbgen.Organization(t, db, database.Organization{})
|
|
job = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
OrganizationID: org.ID,
|
|
})
|
|
resource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{JobID: job.ID})
|
|
agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ResourceID: resource.ID})
|
|
)
|
|
|
|
ids := make([]uuid.UUID, len(tc.validSubagent))
|
|
names := make([]string, len(tc.validSubagent))
|
|
workspaceFolders := make([]string, len(tc.validSubagent))
|
|
configPaths := make([]string, len(tc.validSubagent))
|
|
subagentIDs := make([]uuid.UUID, len(tc.validSubagent))
|
|
|
|
for i, valid := range tc.validSubagent {
|
|
ids[i] = uuid.New()
|
|
names[i] = fmt.Sprintf("test-devcontainer-%d", i)
|
|
workspaceFolders[i] = fmt.Sprintf("/workspace%d", i)
|
|
configPaths[i] = fmt.Sprintf("/workspace%d/.devcontainer/devcontainer.json", i)
|
|
|
|
if valid {
|
|
subagentIDs[i] = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
ParentID: uuid.NullUUID{UUID: agent.ID, Valid: true},
|
|
}).ID
|
|
} else {
|
|
subagentIDs[i] = uuid.Nil
|
|
}
|
|
}
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
// Given: We insert multiple devcontainer records.
|
|
devcontainers, err := db.InsertWorkspaceAgentDevcontainers(ctx, database.InsertWorkspaceAgentDevcontainersParams{
|
|
WorkspaceAgentID: agent.ID,
|
|
CreatedAt: dbtime.Now(),
|
|
ID: ids,
|
|
Name: names,
|
|
WorkspaceFolder: workspaceFolders,
|
|
ConfigPath: configPaths,
|
|
SubagentID: subagentIDs,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, devcontainers, len(tc.validSubagent))
|
|
|
|
// Then: Verify each devcontainer has the correct SubagentID validity.
|
|
// - When we pass `uuid.Nil`, we get a `uuid.NullUUID{Valid: false}`
|
|
// - When we pass a valid UUID, we get a `uuid.NullUUID{Valid: true}`
|
|
for i, valid := range tc.validSubagent {
|
|
require.Equal(t, valid, devcontainers[i].SubagentID.Valid, "devcontainer %d: subagent_id validity mismatch", i)
|
|
if valid {
|
|
require.Equal(t, subagentIDs[i], devcontainers[i].SubagentID.UUID, "devcontainer %d: subagent_id UUID mismatch", i)
|
|
}
|
|
}
|
|
|
|
// Perform the same check on data returned by
|
|
// `GetWorkspaceAgentDevcontainersByAgentID` to ensure the fix is at
|
|
// the data storage layer, instead of just at a query level.
|
|
fetched, err := db.GetWorkspaceAgentDevcontainersByAgentID(ctx, agent.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, fetched, len(tc.validSubagent))
|
|
|
|
// Sort fetched by name to ensure consistent ordering for comparison.
|
|
slices.SortFunc(fetched, func(a, b database.WorkspaceAgentDevcontainer) int {
|
|
return strings.Compare(a.Name, b.Name)
|
|
})
|
|
|
|
for i, valid := range tc.validSubagent {
|
|
require.Equal(t, valid, fetched[i].SubagentID.Valid, "fetched devcontainer %d: subagent_id validity mismatch", i)
|
|
if valid {
|
|
require.Equal(t, subagentIDs[i], fetched[i].SubagentID.UUID, "fetched devcontainer %d: subagent_id UUID mismatch", i)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestInsertChatMessages(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
insertModelConfig := func(
|
|
t *testing.T,
|
|
store database.Store,
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
provider string,
|
|
model string,
|
|
displayName string,
|
|
isDefault bool,
|
|
) database.ChatModelConfig {
|
|
t.Helper()
|
|
|
|
modelConfig, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: provider,
|
|
Model: model,
|
|
DisplayName: displayName,
|
|
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: isDefault,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
return modelConfig
|
|
}
|
|
|
|
setupChat := func(t *testing.T) (database.Store, context.Context, database.User, database.Chat, string, database.ChatModelConfig) {
|
|
t.Helper()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
|
provider := "openai"
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: provider,
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelConfigA := insertModelConfig(
|
|
t,
|
|
store,
|
|
ctx,
|
|
user.ID,
|
|
provider,
|
|
"test-model-a-"+uuid.NewString(),
|
|
"Test Model A",
|
|
true,
|
|
)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelConfigA.ID,
|
|
Title: "test-chat-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
return store, ctx, user, chat, provider, modelConfigA
|
|
}
|
|
|
|
insertMessage := func(t *testing.T, store database.Store, ctx context.Context, chatID, userID, modelConfigID uuid.UUID, content string) {
|
|
t.Helper()
|
|
|
|
_, err := store.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chatID,
|
|
CreatedBy: []uuid.UUID{userID},
|
|
ModelConfigID: []uuid.UUID{modelConfigID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
Content: []string{fmt.Sprintf("%q", content)},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
t.Run("ModelSwitchUpdatesLastModelConfigID", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ctx, user, chat, provider, modelConfigA := setupChat(t)
|
|
modelConfigB := insertModelConfig(
|
|
t,
|
|
store,
|
|
ctx,
|
|
user.ID,
|
|
provider,
|
|
"test-model-b-"+uuid.NewString(),
|
|
"Test Model B",
|
|
false,
|
|
)
|
|
|
|
insertMessage(t, store, ctx, chat.ID, user.ID, modelConfigB.ID, "switch models")
|
|
|
|
gotChat, err := store.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, modelConfigA.ID, chat.LastModelConfigID)
|
|
require.Equal(t, modelConfigB.ID, gotChat.LastModelConfigID)
|
|
})
|
|
|
|
t.Run("SameModelDoesNotBreakAnything", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ctx, user, chat, _, modelConfigA := setupChat(t)
|
|
|
|
insertMessage(t, store, ctx, chat.ID, user.ID, modelConfigA.ID, "same model")
|
|
|
|
gotChat, err := store.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, modelConfigA.ID, gotChat.LastModelConfigID)
|
|
})
|
|
|
|
t.Run("BatchInsertMultipleMessages", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ctx, user, chat, _, modelConfigA := setupChat(t)
|
|
|
|
msgs, err := store.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chat.ID,
|
|
CreatedBy: []uuid.UUID{user.ID, uuid.Nil, uuid.Nil},
|
|
ModelConfigID: []uuid.UUID{modelConfigA.ID, modelConfigA.ID, modelConfigA.ID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleTool},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
|
|
Content: []string{`"hello"`, `"response"`, `"tool result"`},
|
|
InputTokens: []int64{10, 0, 0},
|
|
OutputTokens: []int64{0, 20, 0},
|
|
TotalTokens: []int64{10, 20, 0},
|
|
ReasoningTokens: []int64{0, 5, 0},
|
|
CacheCreationTokens: []int64{0, 0, 0},
|
|
CacheReadTokens: []int64{0, 0, 0},
|
|
ContextLimit: []int64{0, 0, 0},
|
|
Compressed: []bool{false, false, false},
|
|
TotalCostMicros: []int64{0, 100, 0},
|
|
RuntimeMs: []int64{0, 500, 0},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, msgs, 3)
|
|
|
|
// Verify ordering and roles.
|
|
require.Equal(t, database.ChatMessageRoleUser, msgs[0].Role)
|
|
require.Equal(t, database.ChatMessageRoleAssistant, msgs[1].Role)
|
|
require.Equal(t, database.ChatMessageRoleTool, msgs[2].Role)
|
|
|
|
// Verify IDs are sequential.
|
|
require.Less(t, msgs[0].ID, msgs[1].ID)
|
|
require.Less(t, msgs[1].ID, msgs[2].ID)
|
|
|
|
// Verify nullable fields: user message has CreatedBy set.
|
|
require.True(t, msgs[0].CreatedBy.Valid)
|
|
require.Equal(t, user.ID, msgs[0].CreatedBy.UUID)
|
|
// Assistant and tool messages have NULL CreatedBy.
|
|
require.False(t, msgs[1].CreatedBy.Valid)
|
|
require.False(t, msgs[2].CreatedBy.Valid)
|
|
|
|
// Verify token fields stored as NULL when zero.
|
|
require.True(t, msgs[0].InputTokens.Valid)
|
|
require.Equal(t, int64(10), msgs[0].InputTokens.Int64)
|
|
require.False(t, msgs[0].OutputTokens.Valid) // 0 → NULL
|
|
require.True(t, msgs[1].OutputTokens.Valid)
|
|
require.Equal(t, int64(20), msgs[1].OutputTokens.Int64)
|
|
|
|
// Verify cost: assistant has cost, others NULL.
|
|
require.True(t, msgs[1].TotalCostMicros.Valid)
|
|
require.Equal(t, int64(100), msgs[1].TotalCostMicros.Int64)
|
|
require.False(t, msgs[0].TotalCostMicros.Valid)
|
|
require.False(t, msgs[2].TotalCostMicros.Valid)
|
|
|
|
// Verify runtime_ms on assistant message.
|
|
require.True(t, msgs[1].RuntimeMs.Valid)
|
|
require.Equal(t, int64(500), msgs[1].RuntimeMs.Int64)
|
|
require.False(t, msgs[0].RuntimeMs.Valid)
|
|
})
|
|
}
|
|
|
|
func TestGetChatMessagesForPromptByChatID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// This test exercises a complex CTE query for prompt
|
|
// reconstruction after compaction. It requires Postgres.
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
|
|
// Helper: create a chat model config (required FK for chats).
|
|
user := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
|
|
|
// A chat_providers row is required as a FK for model configs.
|
|
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
newChat := func(t *testing.T) database.Chat {
|
|
t.Helper()
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "test-chat-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
return chat
|
|
}
|
|
|
|
insertMsg := func(
|
|
t *testing.T,
|
|
chatID uuid.UUID,
|
|
role database.ChatMessageRole,
|
|
vis database.ChatMessageVisibility,
|
|
compressed bool,
|
|
content string,
|
|
) database.ChatMessage {
|
|
t.Helper()
|
|
results, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chatID,
|
|
CreatedBy: []uuid.UUID{uuid.Nil},
|
|
ModelConfigID: []uuid.UUID{uuid.Nil},
|
|
Role: []database.ChatMessageRole{role},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
|
Visibility: []database.ChatMessageVisibility{vis},
|
|
Compressed: []bool{compressed},
|
|
Content: []string{`"` + content + `"`},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
})
|
|
require.NoError(t, err)
|
|
return results[0]
|
|
}
|
|
|
|
msgIDs := func(msgs []database.ChatMessage) []int64 {
|
|
ids := make([]int64, len(msgs))
|
|
for i, m := range msgs {
|
|
ids[i] = m.ID
|
|
}
|
|
return ids
|
|
}
|
|
|
|
t.Run("NoCompaction", func(t *testing.T) {
|
|
t.Parallel()
|
|
chat := newChat(t)
|
|
|
|
sys := insertMsg(t, chat.ID, database.ChatMessageRoleSystem, database.ChatMessageVisibilityModel, false, "system prompt")
|
|
usr := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "hello")
|
|
ast := insertMsg(t, chat.ID, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth, false, "hi there")
|
|
|
|
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []int64{sys.ID, usr.ID, ast.ID}, msgIDs(got))
|
|
})
|
|
|
|
t.Run("UserOnlyVisibilityExcluded", func(t *testing.T) {
|
|
t.Parallel()
|
|
chat := newChat(t)
|
|
|
|
// Messages with visibility=user should NOT appear in the
|
|
// prompt (they are only for the UI).
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleSystem, database.ChatMessageVisibilityModel, false, "system prompt")
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityUser, false, "user-only msg")
|
|
usr := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "hello")
|
|
|
|
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
for _, m := range got {
|
|
require.NotEqual(t, database.ChatMessageVisibilityUser, m.Visibility,
|
|
"visibility=user messages should not appear in the prompt")
|
|
}
|
|
require.Contains(t, msgIDs(got), usr.ID)
|
|
})
|
|
|
|
t.Run("AfterCompaction", func(t *testing.T) {
|
|
t.Parallel()
|
|
chat := newChat(t)
|
|
|
|
// Pre-compaction conversation.
|
|
sys := insertMsg(t, chat.ID, database.ChatMessageRoleSystem, database.ChatMessageVisibilityModel, false, "system prompt")
|
|
preUser := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "old question")
|
|
preAsst := insertMsg(t, chat.ID, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth, false, "old answer")
|
|
|
|
// Compaction messages:
|
|
// 1. Summary (role=user, visibility=model, compressed=true).
|
|
summary := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityModel, true, "compaction summary")
|
|
// 2. Compressed assistant tool-call (visibility=user).
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityUser, true, "tool call")
|
|
// 3. Compressed tool result (visibility=both).
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleTool, database.ChatMessageVisibilityBoth, true, "tool result")
|
|
|
|
// Post-compaction messages.
|
|
postUser := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "new question")
|
|
postAsst := insertMsg(t, chat.ID, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth, false, "new answer")
|
|
|
|
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
gotIDs := msgIDs(got)
|
|
|
|
// Must include: system prompt, summary, post-compaction.
|
|
require.Contains(t, gotIDs, sys.ID, "system prompt must be included")
|
|
require.Contains(t, gotIDs, summary.ID, "compaction summary must be included")
|
|
require.Contains(t, gotIDs, postUser.ID, "post-compaction user msg must be included")
|
|
require.Contains(t, gotIDs, postAsst.ID, "post-compaction assistant msg must be included")
|
|
|
|
// Must exclude: pre-compaction non-system messages.
|
|
require.NotContains(t, gotIDs, preUser.ID, "pre-compaction user msg must be excluded")
|
|
require.NotContains(t, gotIDs, preAsst.ID, "pre-compaction assistant msg must be excluded")
|
|
|
|
// Verify ordering.
|
|
require.Equal(t, []int64{sys.ID, summary.ID, postUser.ID, postAsst.ID}, gotIDs)
|
|
})
|
|
|
|
t.Run("AfterCompactionSummaryIsUserRole", func(t *testing.T) {
|
|
t.Parallel()
|
|
chat := newChat(t)
|
|
|
|
// After compaction the summary must appear as role=user so
|
|
// that LLM APIs (e.g. Anthropic) see at least one
|
|
// non-system message in the prompt.
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleSystem, database.ChatMessageVisibilityModel, false, "system prompt")
|
|
summary := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityModel, true, "summary text")
|
|
newUsr := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "new question")
|
|
|
|
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
hasNonSystem := false
|
|
for _, m := range got {
|
|
if m.Role != "system" {
|
|
hasNonSystem = true
|
|
break
|
|
}
|
|
}
|
|
require.True(t, hasNonSystem,
|
|
"prompt must contain at least one non-system message after compaction")
|
|
require.Contains(t, msgIDs(got), summary.ID)
|
|
require.Contains(t, msgIDs(got), newUsr.ID)
|
|
})
|
|
|
|
t.Run("CompressedToolResultNotPickedAsSummary", func(t *testing.T) {
|
|
t.Parallel()
|
|
chat := newChat(t)
|
|
|
|
// The CTE uses visibility='model' (exact match). If it
|
|
// used IN ('model','both'), the compressed tool result
|
|
// (visibility=both) would be picked as the "summary"
|
|
// instead of the actual summary.
|
|
insertMsg(t, chat.ID, database.ChatMessageRoleSystem, database.ChatMessageVisibilityModel, false, "system prompt")
|
|
summary := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityModel, true, "real summary")
|
|
compressedTool := insertMsg(t, chat.ID, database.ChatMessageRoleTool, database.ChatMessageVisibilityBoth, true, "tool result")
|
|
postUser := insertMsg(t, chat.ID, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, false, "follow-up")
|
|
|
|
got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
gotIDs := msgIDs(got)
|
|
require.Contains(t, gotIDs, summary.ID, "real summary must be included")
|
|
require.NotContains(t, gotIDs, compressedTool.ID,
|
|
"compressed tool result must not be included")
|
|
require.Contains(t, gotIDs, postUser.ID)
|
|
})
|
|
}
|
|
|
|
func TestGetWorkspaceBuildMetricsByResourceID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("OK", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
tmpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
TemplateID: uuid.NullUUID{UUID: tmpl.ID, Valid: true},
|
|
CreatedBy: user.ID,
|
|
})
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
TemplateID: tmpl.ID,
|
|
OwnerID: user.ID,
|
|
AutomaticUpdates: database.AutomaticUpdatesNever,
|
|
})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: ws.ID,
|
|
TemplateVersionID: tv.ID,
|
|
JobID: job.ID,
|
|
InitiatorID: user.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
|
|
parentReadyAt := dbtime.Now()
|
|
parentStartedAt := parentReadyAt.Add(-time.Second)
|
|
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
StartedAt: sql.NullTime{Time: parentStartedAt, Valid: true},
|
|
ReadyAt: sql.NullTime{Time: parentReadyAt, Valid: true},
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
|
})
|
|
|
|
row, err := db.GetWorkspaceBuildMetricsByResourceID(ctx, resource.ID)
|
|
require.NoError(t, err)
|
|
require.True(t, row.AllAgentsReady)
|
|
require.True(t, parentReadyAt.Equal(row.LastAgentReadyAt))
|
|
require.Equal(t, "success", row.WorstStatus)
|
|
})
|
|
|
|
t.Run("SubAgentExcluded", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
user := dbgen.User(t, db, database.User{})
|
|
tmpl := dbgen.Template(t, db, database.Template{
|
|
OrganizationID: org.ID,
|
|
CreatedBy: user.ID,
|
|
})
|
|
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: org.ID,
|
|
TemplateID: uuid.NullUUID{UUID: tmpl.ID, Valid: true},
|
|
CreatedBy: user.ID,
|
|
})
|
|
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
|
OrganizationID: org.ID,
|
|
TemplateID: tmpl.ID,
|
|
OwnerID: user.ID,
|
|
AutomaticUpdates: database.AutomaticUpdatesNever,
|
|
})
|
|
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
|
OrganizationID: org.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
})
|
|
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: ws.ID,
|
|
TemplateVersionID: tv.ID,
|
|
JobID: job.ID,
|
|
InitiatorID: user.ID,
|
|
})
|
|
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
|
JobID: job.ID,
|
|
})
|
|
|
|
parentReadyAt := dbtime.Now()
|
|
parentStartedAt := parentReadyAt.Add(-time.Second)
|
|
parentAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
|
ResourceID: resource.ID,
|
|
StartedAt: sql.NullTime{Time: parentStartedAt, Valid: true},
|
|
ReadyAt: sql.NullTime{Time: parentReadyAt, Valid: true},
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
|
})
|
|
|
|
// Sub-agent with ready_at 1 hour later should be excluded.
|
|
subAgentReadyAt := parentReadyAt.Add(time.Hour)
|
|
subAgentStartedAt := subAgentReadyAt.Add(-time.Second)
|
|
_ = dbgen.WorkspaceSubAgent(t, db, parentAgent, database.WorkspaceAgent{
|
|
StartedAt: sql.NullTime{Time: subAgentStartedAt, Valid: true},
|
|
ReadyAt: sql.NullTime{Time: subAgentReadyAt, Valid: true},
|
|
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
|
})
|
|
|
|
row, err := db.GetWorkspaceBuildMetricsByResourceID(ctx, resource.ID)
|
|
require.NoError(t, err)
|
|
require.True(t, row.AllAgentsReady)
|
|
// LastAgentReadyAt should be the parent's, not the sub-agent's.
|
|
require.True(t, parentReadyAt.Equal(row.LastAgentReadyAt))
|
|
require.Equal(t, "success", row.WorstStatus)
|
|
})
|
|
}
|
|
|
|
// TestUpsertAISeats verifies 'UpsertAISeatState' only returns true when a new
|
|
// row is inserted.
|
|
func TestUpsertAISeats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
now := dbtime.Now()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
newRow, err := db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
|
|
UserID: user.ID,
|
|
FirstUsedAt: now.Add(time.Hour * -24),
|
|
LastEventType: database.AiSeatUsageReasonTask,
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, newRow)
|
|
|
|
alreadyExists, err := db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
|
|
UserID: user.ID,
|
|
FirstUsedAt: now.Add(time.Hour * -23),
|
|
LastEventType: database.AiSeatUsageReasonTask,
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, alreadyExists)
|
|
|
|
alreadyExists, err = db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
|
|
UserID: user.ID,
|
|
FirstUsedAt: now,
|
|
LastEventType: database.AiSeatUsageReasonTask,
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, alreadyExists)
|
|
}
|
|
|
|
func TestGetPRInsights(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
// setupChatInfra creates a fresh database with a user, chat provider,
|
|
// and model config. Returns the store, user ID, model config ID,
|
|
// and org ID.
|
|
setupChatInfra := func(t *testing.T) (database.Store, uuid.UUID, uuid.UUID, uuid.UUID) {
|
|
t.Helper()
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "anthropic",
|
|
DisplayName: "Anthropic",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
mc, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "anthropic",
|
|
Model: "claude-4",
|
|
DisplayName: "Claude 4",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
return store, user.ID, mc.ID, org.ID
|
|
}
|
|
|
|
type chatParams struct {
|
|
Store database.Store
|
|
UserID uuid.UUID
|
|
ModelConfigID uuid.UUID
|
|
OrgID uuid.UUID
|
|
}
|
|
|
|
createChat := func(t *testing.T, p chatParams, title string) database.Chat {
|
|
t.Helper()
|
|
chat, err := p.Store.InsertChat(context.Background(), database.InsertChatParams{
|
|
OrganizationID: p.OrgID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: p.UserID,
|
|
LastModelConfigID: p.ModelConfigID,
|
|
Title: title,
|
|
})
|
|
require.NoError(t, err)
|
|
return chat
|
|
}
|
|
|
|
// insertCostMessage inserts a single assistant message with the
|
|
// given total_cost_micros value.
|
|
insertCostMessage := func(t *testing.T, store database.Store, chatID, userID, mcID uuid.UUID, costMicros int64) {
|
|
t.Helper()
|
|
_, err := store.InsertChatMessages(context.Background(), database.InsertChatMessagesParams{
|
|
ChatID: chatID,
|
|
CreatedBy: []uuid.UUID{userID},
|
|
ModelConfigID: []uuid.UUID{mcID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
|
Content: []string{`[{"type":"text","text":"hello"}]`},
|
|
ContentVersion: []int16{1},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{costMicros},
|
|
RuntimeMs: []int64{0},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// linkPR associates a chat with a pull request via
|
|
// UpsertChatDiffStatus.
|
|
linkPR := func(t *testing.T, store database.Store, chatID uuid.UUID, prURL, state, title string, additions, deletions, changed int32) {
|
|
t.Helper()
|
|
now := time.Now()
|
|
_, err := store.UpsertChatDiffStatus(context.Background(), database.UpsertChatDiffStatusParams{
|
|
ChatID: chatID,
|
|
Url: sql.NullString{String: prURL, Valid: true},
|
|
PullRequestState: sql.NullString{String: state, Valid: true},
|
|
PullRequestTitle: title,
|
|
Additions: additions,
|
|
Deletions: deletions,
|
|
ChangedFiles: changed,
|
|
RefreshedAt: now,
|
|
StaleAt: now.Add(time.Hour),
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
startDate := time.Now().Add(-24 * time.Hour)
|
|
endDate := time.Now().Add(time.Hour)
|
|
noOwner := uuid.NullUUID{}
|
|
|
|
t.Run("MultipleChatsSamePR_CostSummed", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
chatA := createChat(t, p, "chat-A")
|
|
insertCostMessage(t, store, chatA.ID, userID, mcID, 5_000_000) // $5
|
|
|
|
chatB := createChat(t, p, "chat-B")
|
|
insertCostMessage(t, store, chatB.ID, userID, mcID, 3_000_000) // $3
|
|
|
|
prURL := "https://github.com/org/repo/pull/123"
|
|
linkPR(t, store, chatA.ID, prURL, "merged", "fix: something", 100, 20, 5)
|
|
linkPR(t, store, chatB.ID, prURL, "merged", "fix: something", 100, 20, 5)
|
|
|
|
// Both chats reference the same PR. The pr_costs CTE sums
|
|
// cost across all chats for the same PR URL, so the total
|
|
// should be $5 + $3 = $8. The PR itself is counted once.
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
assert.Equal(t, int64(8_000_000), recent[0].CostMicros)
|
|
})
|
|
|
|
t.Run("DifferentPRs_NoDuplication", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
chatA := createChat(t, p, "chat-A")
|
|
insertCostMessage(t, store, chatA.ID, userID, mcID, 5_000_000)
|
|
linkPR(t, store, chatA.ID, "https://github.com/org/repo/pull/1", "merged", "feat: A", 50, 10, 2)
|
|
|
|
chatB := createChat(t, p, "chat-B")
|
|
insertCostMessage(t, store, chatB.ID, userID, mcID, 3_000_000)
|
|
linkPR(t, store, chatB.ID, "https://github.com/org/repo/pull/2", "open", "feat: B", 80, 30, 4)
|
|
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros) // $5 + $3
|
|
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
|
|
|
// RecentPRs ordered by created_at DESC: chatB is newer.
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 2)
|
|
// Costs must not be mixed across different PRs.
|
|
assert.Equal(t, int64(3_000_000), recent[0].CostMicros) // PR 2 (newer)
|
|
assert.Equal(t, int64(5_000_000), recent[1].CostMicros) // PR 1 (older)
|
|
})
|
|
|
|
// createChildChat creates a chat with ParentChatID and RootChatID
|
|
// set, simulating a subagent/child chat in a tree.
|
|
createChildChat := func(t *testing.T, p chatParams, parentID, rootID uuid.UUID, title string) database.Chat {
|
|
t.Helper()
|
|
chat, err := p.Store.InsertChat(context.Background(), database.InsertChatParams{
|
|
OrganizationID: p.OrgID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: p.UserID,
|
|
LastModelConfigID: p.ModelConfigID,
|
|
Title: title,
|
|
ParentChatID: uuid.NullUUID{UUID: parentID, Valid: true},
|
|
RootChatID: uuid.NullUUID{UUID: rootID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
return chat
|
|
}
|
|
|
|
t.Run("DuplicatePRUrl_CountedOnce", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
prURL := "https://github.com/org/repo/pull/99"
|
|
for i := range 3 {
|
|
chat := createChat(t, p, fmt.Sprintf("chat-%d", i))
|
|
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
|
|
linkPR(t, store, chat.ID, prURL, "merged", "fix: same PR", 40, 10, 3)
|
|
}
|
|
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
})
|
|
|
|
t.Run("ChildChatCostsIncluded", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Parent chat with a $5 cost.
|
|
parent := createChat(t, p, "parent-chat")
|
|
insertCostMessage(t, store, parent.ID, userID, mcID, 5_000_000)
|
|
|
|
// Two child chats (subagents) with $2 each. Only the parent
|
|
// has a chat_diff_statuses entry, but the children's costs
|
|
// should be included via the tree join.
|
|
child1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
|
|
insertCostMessage(t, store, child1.ID, userID, mcID, 2_000_000)
|
|
|
|
child2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
|
|
insertCostMessage(t, store, child2.ID, userID, mcID, 2_000_000)
|
|
|
|
prURL := "https://github.com/org/repo/pull/42"
|
|
linkPR(t, store, parent.ID, prURL, "merged", "feat: tree cost", 60, 15, 3)
|
|
|
|
// Summary should reflect $5 + $2 + $2 = $9 total.
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
|
assert.Equal(t, int64(9_000_000), summary.TotalCostMicros)
|
|
|
|
// RecentPRs should return 1 row with the full tree cost.
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
assert.Equal(t, int64(9_000_000), recent[0].CostMicros)
|
|
})
|
|
|
|
t.Run("SiblingPRs_NoCrossContamination", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Parent chat with $10 orchestration cost.
|
|
parent := createChat(t, p, "parent")
|
|
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
|
|
|
|
// Child C1 ($5) creates PR1.
|
|
c1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
|
|
insertCostMessage(t, store, c1.ID, userID, mcID, 5_000_000)
|
|
linkPR(t, store, c1.ID, "https://github.com/org/repo/pull/10", "merged", "feat: PR1", 50, 10, 2)
|
|
|
|
// Child C2 ($3) creates PR2.
|
|
c2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
|
|
insertCostMessage(t, store, c2.ID, userID, mcID, 3_000_000)
|
|
linkPR(t, store, c2.ID, "https://github.com/org/repo/pull/11", "open", "feat: PR2", 30, 5, 1)
|
|
|
|
// With direct-branch attribution:
|
|
// PR1 cost = C1's own cost = $5 (parent NOT included — only children of C1)
|
|
// PR2 cost = C2's own cost = $3
|
|
// Total = $8 (no double-counting of parent or siblings)
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 2)
|
|
// PR2 (newer) = $3, PR1 (older) = $5.
|
|
assert.Equal(t, int64(3_000_000), recent[0].CostMicros)
|
|
assert.Equal(t, int64(5_000_000), recent[1].CostMicros)
|
|
})
|
|
|
|
t.Run("ParentAndChildDifferentPRs_NoCrossContamination", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Parent P ($10) creates PR1.
|
|
parent := createChat(t, p, "parent")
|
|
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
|
|
linkPR(t, store, parent.ID, "https://github.com/org/repo/pull/20", "merged", "feat: parent PR", 80, 20, 4)
|
|
|
|
// Child C1 ($5) has its own PR2. Because C1 has its own
|
|
// chat_diff_statuses entry, its cost should NOT be included
|
|
// under PR1 — it belongs to PR2 only.
|
|
c1 := createChildChat(t, p, parent.ID, parent.ID, "child-1")
|
|
insertCostMessage(t, store, c1.ID, userID, mcID, 5_000_000)
|
|
linkPR(t, store, c1.ID, "https://github.com/org/repo/pull/21", "open", "feat: child PR", 30, 5, 1)
|
|
|
|
// Child C2 ($2) has NO cds entry — pure subagent.
|
|
// Its cost should be included under PR1 (the parent's PR).
|
|
c2 := createChildChat(t, p, parent.ID, parent.ID, "child-2")
|
|
insertCostMessage(t, store, c2.ID, userID, mcID, 2_000_000)
|
|
|
|
// PR1 cost = parent ($10) + C2 ($2) = $12 (C1 excluded)
|
|
// PR2 cost = C1 ($5)
|
|
// Total = $17 (actual spend: $10 + $5 + $2 = $17)
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(17_000_000), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 2)
|
|
// PR2/C1 (newer) = $5, PR1/parent (older) = $12.
|
|
assert.Equal(t, int64(5_000_000), recent[0].CostMicros)
|
|
assert.Equal(t, int64(12_000_000), recent[1].CostMicros)
|
|
})
|
|
|
|
t.Run("EmptyURLNotCollapsed", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Two chats with empty-string URLs should be treated as
|
|
// separate PRs (NULLIF converts '' to NULL, falling back
|
|
// to c.id::text).
|
|
chatX := createChat(t, p, "chat-X")
|
|
insertCostMessage(t, store, chatX.ID, userID, mcID, 4_000_000)
|
|
linkPR(t, store, chatX.ID, "", "open", "draft: X", 10, 2, 1)
|
|
|
|
chatY := createChat(t, p, "chat-Y")
|
|
insertCostMessage(t, store, chatY.ID, userID, mcID, 6_000_000)
|
|
linkPR(t, store, chatY.ID, "", "merged", "draft: Y", 20, 5, 2)
|
|
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(10_000_000), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 2)
|
|
})
|
|
|
|
t.Run("ParentAndChildSameURL_DedupedWithCombinedCost", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Parent P ($10) links to a PR.
|
|
parent := createChat(t, p, "parent")
|
|
insertCostMessage(t, store, parent.ID, userID, mcID, 10_000_000)
|
|
|
|
// Child C ($5) also links to the same PR URL.
|
|
child := createChildChat(t, p, parent.ID, parent.ID, "child")
|
|
insertCostMessage(t, store, child.ID, userID, mcID, 5_000_000)
|
|
|
|
prURL := "https://github.com/org/repo/pull/50"
|
|
linkPR(t, store, parent.ID, prURL, "merged", "feat: shared PR", 70, 15, 3)
|
|
linkPR(t, store, child.ID, prURL, "merged", "feat: shared PR", 70, 15, 3)
|
|
|
|
// Both parent and child have cds entries for the same URL.
|
|
// The PR should be counted once with combined cost $10 + $5 = $15.
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(15_000_000), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
assert.Equal(t, int64(15_000_000), recent[0].CostMicros)
|
|
})
|
|
|
|
t.Run("ZeroCostChat_StillCounted", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// A chat linked to a PR but with NO chat_messages at all.
|
|
// The PR should still appear with zero cost.
|
|
chat := createChat(t, p, "zero-cost-chat")
|
|
linkPR(t, store, chat.ID, "https://github.com/org/repo/pull/60", "open", "feat: no messages", 25, 5, 2)
|
|
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
|
assert.Equal(t, int64(0), summary.TotalCostMicros)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
assert.Equal(t, int64(0), recent[0].CostMicros)
|
|
})
|
|
|
|
t.Run("BlankDisplayNameFallsBackToModel", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, _, orgID := setupChatInfra(t)
|
|
|
|
const modelName = "claude-4.1"
|
|
emptyDisplayModel, err := store.InsertChatModelConfig(context.Background(), database.InsertChatModelConfigParams{
|
|
Provider: "anthropic",
|
|
Model: modelName,
|
|
DisplayName: "",
|
|
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: userID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: false,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: emptyDisplayModel.ID, OrgID: orgID}
|
|
chat := createChat(t, p, "chat-empty-display-name")
|
|
insertCostMessage(t, store, chat.ID, userID, emptyDisplayModel.ID, 1_000_000)
|
|
linkPR(t, store, chat.ID, "https://github.com/org/repo/pull/72", "merged", "fix: blank display name", 10, 2, 1)
|
|
|
|
byModel, err := store.GetPRInsightsPerModel(context.Background(), database.GetPRInsightsPerModelParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, byModel, 1)
|
|
assert.Equal(t, modelName, byModel[0].DisplayName)
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, recent, 1)
|
|
assert.Equal(t, modelName, recent[0].ModelDisplayName)
|
|
})
|
|
|
|
t.Run("MergedCostMicros_OnlyCountsMerged", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Merged PR with $5 cost.
|
|
chatMerged := createChat(t, p, "chat-merged")
|
|
insertCostMessage(t, store, chatMerged.ID, userID, mcID, 5_000_000)
|
|
linkPR(t, store, chatMerged.ID, "https://github.com/org/repo/pull/70", "merged", "fix: merged", 40, 10, 2)
|
|
|
|
// Open PR with $3 cost.
|
|
chatOpen := createChat(t, p, "chat-open")
|
|
insertCostMessage(t, store, chatOpen.ID, userID, mcID, 3_000_000)
|
|
linkPR(t, store, chatOpen.ID, "https://github.com/org/repo/pull/71", "open", "feat: open", 20, 5, 1)
|
|
|
|
// TotalCostMicros includes both ($5 + $3 = $8), but
|
|
// MergedCostMicros only includes the merged PR ($5).
|
|
summary, err := store.GetPRInsightsSummary(context.Background(), database.GetPRInsightsSummaryParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
|
assert.Equal(t, int64(5_000_000), summary.MergedCostMicros)
|
|
})
|
|
|
|
t.Run("AllPRsReturnedWithSafetyCap", func(t *testing.T) {
|
|
t.Parallel()
|
|
store, userID, mcID, orgID := setupChatInfra(t)
|
|
p := chatParams{Store: store, UserID: userID, ModelConfigID: mcID, OrgID: orgID}
|
|
|
|
// Create 25 distinct PRs — more than the old LIMIT 20 — and
|
|
// verify all are returned.
|
|
const prCount = 25
|
|
for i := range prCount {
|
|
chat := createChat(t, p, fmt.Sprintf("chat-%d", i))
|
|
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
|
|
linkPR(t, store, chat.ID,
|
|
fmt.Sprintf("https://github.com/org/repo/pull/%d", 100+i),
|
|
"merged", fmt.Sprintf("fix: pr-%d", i), 10, 2, 1)
|
|
}
|
|
|
|
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
|
StartDate: startDate,
|
|
EndDate: endDate,
|
|
OwnerID: noOwner,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Len(t, recent, prCount, "all PRs within the date range should be returned")
|
|
})
|
|
}
|
|
|
|
func TestChatPinOrderQueries(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
setup := func(t *testing.T) (context.Context, database.Store, uuid.UUID, uuid.UUID, uuid.UUID) {
|
|
t.Helper()
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
owner := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
|
|
|
// Use background context for fixture setup so the
|
|
// timed test context doesn't tick during DB init.
|
|
bg := context.Background()
|
|
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
return ctx, db, owner.ID, modelCfg.ID, org.ID
|
|
}
|
|
|
|
createChat := func(t *testing.T, ctx context.Context, db database.Store, ownerID, modelCfgID, orgID uuid.UUID, title string) database.Chat {
|
|
t.Helper()
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: orgID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: ownerID,
|
|
LastModelConfigID: modelCfgID,
|
|
Title: title,
|
|
})
|
|
require.NoError(t, err)
|
|
return chat
|
|
}
|
|
|
|
requirePinOrders := func(t *testing.T, ctx context.Context, db database.Store, want map[uuid.UUID]int32) {
|
|
t.Helper()
|
|
|
|
for chatID, wantPinOrder := range want {
|
|
chat, err := db.GetChatByID(ctx, chatID)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, wantPinOrder, chat.PinOrder)
|
|
}
|
|
}
|
|
|
|
t.Run("PinChatByIDAppendsWithinOwner", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, db, ownerID, modelCfgID, orgID := setup(t)
|
|
first := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "first")
|
|
second := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "second")
|
|
third := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "third")
|
|
|
|
otherOwner := dbgen.User(t, db, database.User{})
|
|
other := createChat(t, ctx, db, otherOwner.ID, modelCfgID, orgID, "other-owner")
|
|
|
|
require.NoError(t, db.PinChatByID(ctx, other.ID))
|
|
require.NoError(t, db.PinChatByID(ctx, first.ID))
|
|
require.NoError(t, db.PinChatByID(ctx, second.ID))
|
|
require.NoError(t, db.PinChatByID(ctx, third.ID))
|
|
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 1,
|
|
second.ID: 2,
|
|
third.ID: 3,
|
|
other.ID: 1,
|
|
})
|
|
})
|
|
|
|
t.Run("UpdateChatPinOrderShiftsNeighborsAndClamps", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, db, ownerID, modelCfgID, orgID := setup(t)
|
|
first := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "first")
|
|
second := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "second")
|
|
third := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "third")
|
|
|
|
for _, chat := range []database.Chat{first, second, third} {
|
|
require.NoError(t, db.PinChatByID(ctx, chat.ID))
|
|
}
|
|
|
|
require.NoError(t, db.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
|
ID: third.ID,
|
|
PinOrder: 1,
|
|
}))
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 2,
|
|
second.ID: 3,
|
|
third.ID: 1,
|
|
})
|
|
|
|
require.NoError(t, db.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
|
ID: third.ID,
|
|
PinOrder: 99,
|
|
}))
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 1,
|
|
second.ID: 2,
|
|
third.ID: 3,
|
|
})
|
|
})
|
|
|
|
t.Run("UnpinChatByIDCompactsPinnedChats", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, db, ownerID, modelCfgID, orgID := setup(t)
|
|
first := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "first")
|
|
second := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "second")
|
|
third := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "third")
|
|
|
|
for _, chat := range []database.Chat{first, second, third} {
|
|
require.NoError(t, db.PinChatByID(ctx, chat.ID))
|
|
}
|
|
|
|
require.NoError(t, db.UnpinChatByID(ctx, second.ID))
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 1,
|
|
second.ID: 0,
|
|
third.ID: 2,
|
|
})
|
|
})
|
|
|
|
t.Run("ArchiveClearsPinAndExcludesFromRanking", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, db, ownerID, modelCfgID, orgID := setup(t)
|
|
first := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "first")
|
|
second := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "second")
|
|
third := createChat(t, ctx, db, ownerID, modelCfgID, orgID, "third")
|
|
|
|
for _, chat := range []database.Chat{first, second, third} {
|
|
require.NoError(t, db.PinChatByID(ctx, chat.ID))
|
|
}
|
|
|
|
// Archive the middle pin.
|
|
_, err := db.ArchiveChatByID(ctx, second.ID)
|
|
require.NoError(t, err)
|
|
|
|
// Archived chat should have pin_order cleared. Remaining
|
|
// pins keep their original positions; the next mutation
|
|
// compacts via ROW_NUMBER().
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 1,
|
|
second.ID: 0,
|
|
third.ID: 3,
|
|
})
|
|
|
|
// Reorder among remaining active pins — archived chat
|
|
// should not interfere with position calculation.
|
|
require.NoError(t, db.UpdateChatPinOrder(ctx, database.UpdateChatPinOrderParams{
|
|
ID: third.ID,
|
|
PinOrder: 1,
|
|
}))
|
|
// After reorder, ROW_NUMBER() compacts the sequence.
|
|
requirePinOrders(t, ctx, db, map[uuid.UUID]int32{
|
|
first.ID: 2,
|
|
second.ID: 0,
|
|
third.ID: 1,
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestChatPinOrderConstraints(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
db, _ := dbtestutil.NewDB(t)
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
owner := dbgen.User(t, db, database.User{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
|
|
|
bg := context.Background()
|
|
_, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
t.Run("ChildChatCannotBePinned", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
parent, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusCompleted,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "parent",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
child, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusCompleted,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "child",
|
|
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
|
RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
err = db.PinChatByID(ctx, child.ID)
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckChatsPinOrderParentCheck))
|
|
})
|
|
|
|
t.Run("ArchivedChatCannotBePinned", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusCompleted,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "will be archived",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
|
|
err = db.PinChatByID(ctx, chat.ID)
|
|
require.Error(t, err)
|
|
require.True(t, database.IsCheckViolation(err, database.CheckChatsPinOrderArchivedCheck))
|
|
})
|
|
}
|
|
|
|
func TestChatLabels(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
owner := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
|
|
|
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
t.Run("CreateWithLabels", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
labels := database.StringMap{"github.repo": "coder/coder", "env": "prod"}
|
|
labelsJSON, err := json.Marshal(labels)
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "labeled-chat",
|
|
Labels: pqtype.NullRawMessage{
|
|
RawMessage: labelsJSON,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.StringMap{"github.repo": "coder/coder", "env": "prod"}, chat.Labels)
|
|
|
|
// Read back and verify.
|
|
fetched, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, chat.Labels, fetched.Labels)
|
|
})
|
|
|
|
t.Run("CreateWithoutLabels", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "no-labels-chat",
|
|
})
|
|
require.NoError(t, err)
|
|
// Default should be an empty map, not nil.
|
|
require.NotNil(t, chat.Labels)
|
|
require.Empty(t, chat.Labels)
|
|
})
|
|
|
|
t.Run("UpdateLabels", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "update-labels-chat",
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, chat.Labels)
|
|
|
|
// Set labels.
|
|
newLabels, err := json.Marshal(database.StringMap{"team": "backend"})
|
|
require.NoError(t, err)
|
|
updated, err := db.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{
|
|
ID: chat.ID,
|
|
Labels: newLabels,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, database.StringMap{"team": "backend"}, updated.Labels)
|
|
|
|
// Title should be unchanged.
|
|
require.Equal(t, "update-labels-chat", updated.Title)
|
|
|
|
// Clear labels by setting empty object.
|
|
emptyLabels, err := json.Marshal(database.StringMap{})
|
|
require.NoError(t, err)
|
|
cleared, err := db.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{
|
|
ID: chat.ID,
|
|
Labels: emptyLabels,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, cleared.Labels)
|
|
})
|
|
|
|
t.Run("UpdateTitleDoesNotAffectLabels", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
labels := database.StringMap{"pr": "1234"}
|
|
labelsJSON, err := json.Marshal(labels)
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "original-title",
|
|
Labels: pqtype.NullRawMessage{
|
|
RawMessage: labelsJSON,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Update title only — labels must survive.
|
|
updated, err := db.UpdateChatByID(ctx, database.UpdateChatByIDParams{
|
|
ID: chat.ID,
|
|
Title: "new-title",
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, "new-title", updated.Title)
|
|
require.Equal(t, database.StringMap{"pr": "1234"}, updated.Labels)
|
|
})
|
|
|
|
t.Run("FilterByLabels", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
// Create three chats with different labels.
|
|
for _, tc := range []struct {
|
|
title string
|
|
labels database.StringMap
|
|
}{
|
|
{"filter-a", database.StringMap{"env": "prod", "team": "backend"}},
|
|
{"filter-b", database.StringMap{"env": "prod", "team": "frontend"}},
|
|
{"filter-c", database.StringMap{"env": "staging"}},
|
|
} {
|
|
labelsJSON, err := json.Marshal(tc.labels)
|
|
require.NoError(t, err)
|
|
_, err = db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID, Title: tc.title,
|
|
Labels: pqtype.NullRawMessage{
|
|
RawMessage: labelsJSON,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Filter by env=prod — should match filter-a and filter-b.
|
|
filterJSON, err := json.Marshal(database.StringMap{"env": "prod"})
|
|
require.NoError(t, err)
|
|
results, err := db.GetChats(ctx, database.GetChatsParams{
|
|
OwnerID: owner.ID,
|
|
LabelFilter: pqtype.NullRawMessage{
|
|
RawMessage: filterJSON,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
titles := make([]string, 0, len(results))
|
|
for _, c := range results {
|
|
titles = append(titles, c.Chat.Title)
|
|
}
|
|
require.Contains(t, titles, "filter-a")
|
|
require.Contains(t, titles, "filter-b")
|
|
require.NotContains(t, titles, "filter-c")
|
|
|
|
// Filter by env=prod AND team=backend — should match only filter-a.
|
|
filterJSON, err = json.Marshal(database.StringMap{"env": "prod", "team": "backend"})
|
|
require.NoError(t, err)
|
|
results, err = db.GetChats(ctx, database.GetChatsParams{
|
|
OwnerID: owner.ID,
|
|
LabelFilter: pqtype.NullRawMessage{
|
|
RawMessage: filterJSON,
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, results, 1)
|
|
require.Equal(t, "filter-a", results[0].Chat.Title)
|
|
// No filter — should return all chats for this owner.
|
|
allChats, err := db.GetChats(ctx, database.GetChatsParams{
|
|
OwnerID: owner.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.GreaterOrEqual(t, len(allChats), 3)
|
|
})
|
|
}
|
|
|
|
func TestUpdateChatLastTurnSummary(t *testing.T) {
|
|
t.Parallel()
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
sqlDB := testSQLDB(t)
|
|
err := migrations.Up(sqlDB)
|
|
require.NoError(t, err)
|
|
db := database.New(sqlDB)
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
owner := dbgen.User(t, db, database.User{})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID})
|
|
|
|
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model",
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: owner.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "summary-chat",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
affected, err := db.UpdateChatLastTurnSummary(ctx, database.UpdateChatLastTurnSummaryParams{
|
|
ID: chat.ID,
|
|
ExpectedUpdatedAt: chat.UpdatedAt,
|
|
LastTurnSummary: sql.NullString{String: "resolved the issue", Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, affected)
|
|
|
|
fetched, err := db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, sql.NullString{String: "resolved the issue", Valid: true}, fetched.LastTurnSummary)
|
|
require.Equal(t, chat.UpdatedAt, fetched.UpdatedAt)
|
|
|
|
affected, err = db.UpdateChatLastTurnSummary(ctx, database.UpdateChatLastTurnSummaryParams{
|
|
ID: chat.ID,
|
|
ExpectedUpdatedAt: chat.UpdatedAt,
|
|
LastTurnSummary: sql.NullString{String: " \n\t ", Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, affected)
|
|
|
|
fetched, err = db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, fetched.LastTurnSummary.Valid)
|
|
require.Equal(t, chat.UpdatedAt, fetched.UpdatedAt)
|
|
|
|
affected, err = db.UpdateChatLastTurnSummary(ctx, database.UpdateChatLastTurnSummaryParams{
|
|
ID: chat.ID,
|
|
ExpectedUpdatedAt: chat.UpdatedAt,
|
|
LastTurnSummary: sql.NullString{String: "fresh summary", Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, affected)
|
|
|
|
advancedUpdatedAt := chat.UpdatedAt.Add(time.Second)
|
|
_, err = db.UpdateChatStatusPreserveUpdatedAt(ctx, database.UpdateChatStatusPreserveUpdatedAtParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusRunning,
|
|
UpdatedAt: advancedUpdatedAt,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
affected, err = db.UpdateChatLastTurnSummary(ctx, database.UpdateChatLastTurnSummaryParams{
|
|
ID: chat.ID,
|
|
ExpectedUpdatedAt: chat.UpdatedAt,
|
|
LastTurnSummary: sql.NullString{String: "stale summary", Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Zero(t, affected)
|
|
|
|
fetched, err = db.GetChatByID(ctx, chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, sql.NullString{String: "fresh summary", Valid: true}, fetched.LastTurnSummary)
|
|
require.Equal(t, advancedUpdatedAt, fetched.UpdatedAt)
|
|
}
|
|
|
|
func TestDeleteChatDebugDataAfterMessageIDIncludesTriggeredRuns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-rollback-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
const cutoff int64 = 50
|
|
|
|
affectedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff + 10, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 5, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: affectedRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
affectedByStepHistoryTipRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff - 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: affectedByStepHistoryTipRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "interrupted",
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff + 7, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// affectedByStepAssistantMsgRun: run-level fields are at/below
|
|
// the cutoff, but its step has assistant_message_id above the
|
|
// cutoff. This exercises the step.assistant_message_id > cutoff
|
|
// branch of the UNION independently of history_tip_message_id.
|
|
affectedByStepAssistantMsgRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff - 2, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 2, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: affectedByStepAssistantMsgRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
AssistantMessageID: sql.NullInt64{Int64: cutoff + 3, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
unaffectedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
unaffectedStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: unaffectedRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
AssistantMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
|
|
ChatID: chat.ID,
|
|
MessageID: cutoff,
|
|
StartedBefore: time.Now().Add(time.Minute),
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 3, deletedRows)
|
|
|
|
_, err = store.GetChatDebugRunByID(ctx, affectedRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
affectedSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedRun.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, affectedSteps)
|
|
|
|
_, err = store.GetChatDebugRunByID(ctx, affectedByStepHistoryTipRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
affectedByStepHistoryTipSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedByStepHistoryTipRun.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, affectedByStepHistoryTipSteps)
|
|
|
|
// Verify the run caught by step-level assistant_message_id is
|
|
// also deleted. This would survive if the
|
|
// step.assistant_message_id > @message_id clause were removed.
|
|
_, err = store.GetChatDebugRunByID(ctx, affectedByStepAssistantMsgRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
affectedByStepAssistantMsgSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedByStepAssistantMsgRun.ID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, affectedByStepAssistantMsgSteps)
|
|
|
|
remainingRuns, err := store.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
|
|
ChatID: chat.ID,
|
|
LimitVal: 100,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, remainingRuns, 1)
|
|
require.Equal(t, unaffectedRun.ID, remainingRuns[0].ID)
|
|
|
|
remainingRun, err := store.GetChatDebugRunByID(ctx, unaffectedRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, unaffectedRun.ID, remainingRun.ID)
|
|
|
|
remainingSteps, err := store.GetChatDebugStepsByRunID(ctx, unaffectedRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, remainingSteps, 1)
|
|
require.Equal(t, unaffectedStep.ID, remainingSteps[0].ID)
|
|
}
|
|
|
|
func TestFinalizeStaleChatDebugRows(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-finalize-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-finalize-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// staleTime is well before the threshold so rows stamped with it
|
|
// are considered stale. The threshold sits between staleTime and
|
|
// NOW(), letting us create rows that are stale-by-age and rows
|
|
// that are fresh-by-age in the same test.
|
|
staleTime := time.Now().Add(-2 * time.Hour)
|
|
staleThreshold := time.Now().Add(-1 * time.Hour)
|
|
|
|
// --- staleRun: in_progress run with no finished_at --- should be
|
|
// finalized.
|
|
staleRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// staleStep: in_progress step attached to staleRun.
|
|
staleStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: staleRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- orphanStep: in_progress step whose run is already completed ---
|
|
// Its own updated_at is old, so it should be finalized directly.
|
|
// The step must be inserted while the run is still open because
|
|
// InsertChatDebugStep requires finished_at IS NULL on the parent
|
|
// run (atomic guard against appending steps to finalized runs).
|
|
completedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 2, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 2, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "completed",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert the step while the run is still open (finished_at IS NULL).
|
|
orphanStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: completedRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Now mark the run as completed with a finished_at timestamp,
|
|
// leaving the step orphaned in in_progress state.
|
|
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
|
ID: completedRun.ID,
|
|
ChatID: completedRun.ChatID,
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- cascadeRun: stale in_progress run with a FRESH step ---
|
|
// The run's updated_at is old so the run itself is finalized by
|
|
// age. The step's updated_at is recent (default NOW()), so it is
|
|
// NOT caught by the age predicate. It must be finalized solely
|
|
// via the cascade CTE clause: run_id IN (SELECT id FROM
|
|
// finalized_runs). Removing that clause would leave this step
|
|
// stuck in 'in_progress'.
|
|
cascadeRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 10, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 10, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// cascadeStep: recent updated_at (default NOW()), so only the
|
|
// cascade path can finalize it.
|
|
cascadeStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: cascadeRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// The InsertChatDebugStep CTE atomically bumps the parent run's
|
|
// updated_at to NOW(). Reset it back to staleTime so the run is
|
|
// still caught by the age predicate in FinalizeStaleChatDebugRows.
|
|
err = store.TouchChatDebugRunUpdatedAt(ctx, database.TouchChatDebugRunUpdatedAtParams{
|
|
ID: cascadeRun.ID,
|
|
ChatID: chat.ID,
|
|
Now: staleTime,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- alreadyDone: completed run/step --- should NOT be touched.
|
|
doneRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 3, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 3, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "completed",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert step while run is still open.
|
|
doneStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: doneRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "completed",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Now finalize both run and step.
|
|
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
|
ID: doneRun.ID,
|
|
ChatID: doneRun.ChatID,
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
|
ID: doneStep.ID,
|
|
ChatID: chat.ID,
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- errorRun: error run/step --- should NOT be touched either,
|
|
// exercising the 'error' branch of the NOT IN clause.
|
|
errorRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 4, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 4, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "error",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert step while run is still open.
|
|
errorStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: errorRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "error",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Now finalize both run and step.
|
|
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
|
ID: errorRun.ID,
|
|
ChatID: errorRun.ChatID,
|
|
Status: sql.NullString{String: "error", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
|
ID: errorStep.ID,
|
|
ChatID: chat.ID,
|
|
Status: sql.NullString{String: "error", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- freshRun: recent in_progress run with current timestamp ---
|
|
// should NOT be finalized because its updated_at is after the
|
|
// threshold, exercising the age predicate (not just terminal
|
|
// status) as the survival reason.
|
|
freshRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 20, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 20, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
// UpdatedAt defaults to NOW(), which is after staleThreshold.
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
freshStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: freshRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
// UpdatedAt defaults to NOW().
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// --- Execute the finalization sweep. ---
|
|
result, err := store.FinalizeStaleChatDebugRows(ctx, database.FinalizeStaleChatDebugRowsParams{
|
|
Now: time.Now(),
|
|
UpdatedBefore: staleThreshold,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// staleRun + cascadeRun were finalized; completedRun and doneRun
|
|
// were already terminal, and freshRun survives because its
|
|
// updated_at is after the threshold — so only 2 runs are expected.
|
|
assert.EqualValues(t, 2, result.RunsFinalized,
|
|
"stale + cascade in_progress runs should be finalized")
|
|
// staleStep (age), orphanStep (age), cascadeStep (cascade only)
|
|
// should all be finalized.
|
|
assert.EqualValues(t, 3, result.StepsFinalized,
|
|
"stale step + orphan step + cascade step should all be finalized")
|
|
|
|
// Verify the stale run was set to interrupted.
|
|
updatedStaleRun, err := store.GetChatDebugRunByID(ctx, staleRun.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "interrupted", updatedStaleRun.Status)
|
|
assert.True(t, updatedStaleRun.FinishedAt.Valid,
|
|
"finalized run should have a finished_at timestamp")
|
|
|
|
// Verify the stale step was set to interrupted.
|
|
staleSteps, err := store.GetChatDebugStepsByRunID(ctx, staleRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, staleSteps, 1)
|
|
assert.Equal(t, staleStep.ID, staleSteps[0].ID)
|
|
assert.Equal(t, "interrupted", staleSteps[0].Status)
|
|
assert.True(t, staleSteps[0].FinishedAt.Valid,
|
|
"finalized step should have a finished_at timestamp")
|
|
|
|
// Verify the orphan step was also finalized.
|
|
orphanSteps, err := store.GetChatDebugStepsByRunID(ctx, completedRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, orphanSteps, 1)
|
|
assert.Equal(t, orphanStep.ID, orphanSteps[0].ID)
|
|
assert.Equal(t, "interrupted", orphanSteps[0].Status)
|
|
|
|
// Verify the cascade run was finalized.
|
|
updatedCascadeRun, err := store.GetChatDebugRunByID(ctx, cascadeRun.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "interrupted", updatedCascadeRun.Status)
|
|
assert.True(t, updatedCascadeRun.FinishedAt.Valid,
|
|
"cascade run should have a finished_at timestamp")
|
|
|
|
// Verify the cascade step was finalized despite its recent
|
|
// updated_at, proving the cascade CTE clause is required.
|
|
cascadeSteps, err := store.GetChatDebugStepsByRunID(ctx, cascadeRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, cascadeSteps, 1)
|
|
assert.Equal(t, cascadeStep.ID, cascadeSteps[0].ID)
|
|
assert.Equal(t, "interrupted", cascadeSteps[0].Status,
|
|
"fresh step should be finalized via cascade, not age")
|
|
assert.True(t, cascadeSteps[0].FinishedAt.Valid,
|
|
"cascade step should have a finished_at timestamp")
|
|
|
|
// Verify the completed run/step are untouched.
|
|
unchangedRun, err := store.GetChatDebugRunByID(ctx, doneRun.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "completed", unchangedRun.Status)
|
|
|
|
doneSteps, err := store.GetChatDebugStepsByRunID(ctx, doneRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, doneSteps, 1)
|
|
assert.Equal(t, "completed", doneSteps[0].Status)
|
|
|
|
// Verify the error run/step are untouched.
|
|
unchangedErrorRun, err := store.GetChatDebugRunByID(ctx, errorRun.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "error", unchangedErrorRun.Status)
|
|
|
|
errorSteps, err := store.GetChatDebugStepsByRunID(ctx, errorRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, errorSteps, 1)
|
|
assert.Equal(t, "error", errorSteps[0].Status)
|
|
|
|
// Verify the fresh in_progress run survived due to recency,
|
|
// not terminal status — its updated_at is after the threshold.
|
|
unchangedFreshRun, err := store.GetChatDebugRunByID(ctx, freshRun.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "in_progress", unchangedFreshRun.Status,
|
|
"fresh in_progress run must survive due to recency")
|
|
assert.False(t, unchangedFreshRun.FinishedAt.Valid,
|
|
"fresh run should not have a finished_at timestamp")
|
|
|
|
freshSteps, err := store.GetChatDebugStepsByRunID(ctx, freshRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, freshSteps, 1)
|
|
assert.Equal(t, freshStep.ID, freshSteps[0].ID)
|
|
assert.Equal(t, "in_progress", freshSteps[0].Status,
|
|
"fresh in_progress step must survive due to recency")
|
|
assert.False(t, freshSteps[0].FinishedAt.Valid,
|
|
"fresh step should not have a finished_at timestamp")
|
|
|
|
// A second sweep should be a no-op.
|
|
result2, err := store.FinalizeStaleChatDebugRows(ctx, database.FinalizeStaleChatDebugRowsParams{
|
|
Now: time.Now(),
|
|
UpdatedBefore: staleThreshold,
|
|
})
|
|
require.NoError(t, err)
|
|
assert.EqualValues(t, 0, result2.RunsFinalized,
|
|
"second sweep should find nothing to finalize")
|
|
assert.EqualValues(t, 0, result2.StepsFinalized,
|
|
"second sweep should find nothing to finalize")
|
|
}
|
|
|
|
func TestChatDebugSQLGuards(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-guards-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatA, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-guard-A-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chatB, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-guard-B-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
runA, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chatA.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
stepA, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: runA.ID,
|
|
ChatID: chatA.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// InsertChatDebugStep: valid run_id but chat_id belongs to a
|
|
// different chat. The INSERT...SELECT guard should produce zero
|
|
// rows, surfacing as sql.ErrNoRows.
|
|
t.Run("InsertChatDebugStep_MismatchedChatID", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
_, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: runA.ID,
|
|
ChatID: chatB.ID, // wrong chat
|
|
StepNumber: 2,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"InsertChatDebugStep should fail when chat_id does not match the run's chat_id")
|
|
})
|
|
|
|
// UpdateChatDebugRun: valid run ID but wrong chat_id.
|
|
t.Run("UpdateChatDebugRun_MismatchedChatID", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
_, err := store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
|
ID: runA.ID,
|
|
ChatID: chatB.ID, // wrong chat
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"UpdateChatDebugRun should fail when chat_id does not match")
|
|
})
|
|
|
|
// UpdateChatDebugStep: valid step ID but wrong chat_id.
|
|
t.Run("UpdateChatDebugStep_MismatchedChatID", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
_, err := store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
|
ID: stepA.ID,
|
|
ChatID: chatB.ID, // wrong chat
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
Now: time.Now(),
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows,
|
|
"UpdateChatDebugStep should fail when chat_id does not match")
|
|
})
|
|
}
|
|
|
|
// TestChatDebugRunCOALESCEPreservation verifies that the COALESCE
|
|
// pattern in UpdateChatDebugRun preserves every field that was not
|
|
// explicitly supplied in the update. If COALESCE were removed from
|
|
// any column, the corresponding field would silently null out.
|
|
func TestChatDebugRunCOALESCEPreservation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-coalesce-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-coalesce-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
rootChatID := uuid.New()
|
|
parentChatID := uuid.New()
|
|
|
|
// Insert a fully-populated run so every nullable field has a value.
|
|
original, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true},
|
|
ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true},
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: 42, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 41, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
Summary: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"key":"val"}`), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Update only Status and FinishedAt. Every other nullable param
|
|
// is left as its Go zero value (Valid: false → SQL NULL), which
|
|
// the COALESCE pattern should interpret as "keep existing."
|
|
now := time.Now()
|
|
updated, err := store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
|
ID: original.ID,
|
|
ChatID: chat.ID,
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: now,
|
|
Valid: true,
|
|
},
|
|
Now: now,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Status and FinishedAt should be updated.
|
|
require.Equal(t, "completed", updated.Status)
|
|
require.True(t, updated.FinishedAt.Valid)
|
|
|
|
// UpdatedAt should be set to the @now value we passed in.
|
|
require.WithinDuration(t, now, updated.UpdatedAt, time.Millisecond,
|
|
"updated_at should equal the @now parameter")
|
|
|
|
// Every field not in the update call must be preserved exactly.
|
|
require.Equal(t, original.RootChatID, updated.RootChatID,
|
|
"RootChatID should survive a partial update")
|
|
require.Equal(t, original.ParentChatID, updated.ParentChatID,
|
|
"ParentChatID should survive a partial update")
|
|
require.Equal(t, original.ModelConfigID, updated.ModelConfigID,
|
|
"ModelConfigID should survive a partial update")
|
|
require.Equal(t, original.TriggerMessageID, updated.TriggerMessageID,
|
|
"TriggerMessageID should survive a partial update")
|
|
require.Equal(t, original.HistoryTipMessageID, updated.HistoryTipMessageID,
|
|
"HistoryTipMessageID should survive a partial update")
|
|
require.Equal(t, original.Provider, updated.Provider,
|
|
"Provider should survive a partial update")
|
|
require.Equal(t, original.Model, updated.Model,
|
|
"Model should survive a partial update")
|
|
require.JSONEq(t, string(original.Summary), string(updated.Summary),
|
|
"Summary should survive a partial update")
|
|
require.Equal(t, original.Kind, updated.Kind,
|
|
"Kind should survive a partial update")
|
|
require.Equal(t, original.StartedAt.UTC(), updated.StartedAt.UTC(),
|
|
"StartedAt should survive a partial update")
|
|
}
|
|
|
|
// TestChatDebugStepCOALESCEPreservation verifies that the COALESCE
|
|
// pattern in UpdateChatDebugStep preserves every field that was not
|
|
// explicitly supplied in the update. If COALESCE were removed from
|
|
// any column, the corresponding field would silently null out.
|
|
func TestChatDebugStepCOALESCEPreservation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-step-coalesce-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-step-coalesce-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
run, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert a fully-populated step so every nullable field has a value.
|
|
original, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: run.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "llm_call",
|
|
Status: "in_progress",
|
|
HistoryTipMessageID: sql.NullInt64{Int64: 10, Valid: true},
|
|
AssistantMessageID: sql.NullInt64{Int64: 11, Valid: true},
|
|
NormalizedRequest: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"prompt":"hello"}`), Valid: true},
|
|
NormalizedResponse: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"text":"world"}`), Valid: true},
|
|
Usage: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"tokens":42}`), Valid: true},
|
|
Attempts: pqtype.NullRawMessage{RawMessage: json.RawMessage(`[{"n":1}]`), Valid: true},
|
|
Error: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"code":"transient"}`), Valid: true},
|
|
Metadata: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"trace_id":"abc"}`), Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Update only Status and FinishedAt. Every other nullable param
|
|
// is left as its Go zero value (Valid: false -> SQL NULL), which
|
|
// the COALESCE pattern should interpret as "keep existing."
|
|
now := time.Now()
|
|
updated, err := store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
|
ID: original.ID,
|
|
ChatID: chat.ID,
|
|
Status: sql.NullString{String: "completed", Valid: true},
|
|
FinishedAt: sql.NullTime{
|
|
Time: now,
|
|
Valid: true,
|
|
},
|
|
Now: now,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Status and FinishedAt should be updated.
|
|
require.Equal(t, "completed", updated.Status)
|
|
require.True(t, updated.FinishedAt.Valid)
|
|
|
|
// UpdatedAt should be set to the @now value we passed in.
|
|
require.WithinDuration(t, now, updated.UpdatedAt, time.Millisecond,
|
|
"updated_at should equal the @now parameter")
|
|
|
|
// Every field not in the update call must be preserved exactly.
|
|
require.Equal(t, original.HistoryTipMessageID, updated.HistoryTipMessageID,
|
|
"HistoryTipMessageID should survive a partial update")
|
|
require.Equal(t, original.AssistantMessageID, updated.AssistantMessageID,
|
|
"AssistantMessageID should survive a partial update")
|
|
require.JSONEq(t, string(original.NormalizedRequest), string(updated.NormalizedRequest),
|
|
"NormalizedRequest should survive a partial update")
|
|
require.JSONEq(t, string(original.NormalizedResponse.RawMessage), string(updated.NormalizedResponse.RawMessage),
|
|
"NormalizedResponse should survive a partial update")
|
|
require.JSONEq(t, string(original.Usage.RawMessage), string(updated.Usage.RawMessage),
|
|
"Usage should survive a partial update")
|
|
require.JSONEq(t, string(original.Attempts), string(updated.Attempts),
|
|
"Attempts should survive a partial update")
|
|
require.JSONEq(t, string(original.Error.RawMessage), string(updated.Error.RawMessage),
|
|
"Error should survive a partial update")
|
|
require.JSONEq(t, string(original.Metadata), string(updated.Metadata),
|
|
"Metadata should survive a partial update")
|
|
require.Equal(t, original.Operation, updated.Operation,
|
|
"Operation should survive a partial update")
|
|
require.Equal(t, original.StepNumber, updated.StepNumber,
|
|
"StepNumber should survive a partial update")
|
|
require.Equal(t, original.StartedAt.UTC(), updated.StartedAt.UTC(),
|
|
"StartedAt should survive a partial update")
|
|
}
|
|
|
|
// TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive verifies
|
|
// that runs whose message ID columns are all NULL are never matched
|
|
// by DeleteChatDebugDataAfterMessageID. SQL's three-valued logic
|
|
// means NULL > N evaluates to NULL (not TRUE), so these rows must
|
|
// survive. Without this test a future change could break the
|
|
// invariant with no test failure.
|
|
func TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-null-msg-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-null-msg-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Insert a run with all message ID columns left as NULL (Valid: false).
|
|
nullMsgRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
// TriggerMessageID and HistoryTipMessageID intentionally
|
|
// omitted (zero-value → SQL NULL).
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Attach a step with NULL message IDs too.
|
|
nullMsgStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
|
RunID: nullMsgRun.ID,
|
|
ChatID: chat.ID,
|
|
StepNumber: 1,
|
|
Operation: "stream",
|
|
Status: "in_progress",
|
|
// HistoryTipMessageID and AssistantMessageID intentionally
|
|
// omitted (zero-value → SQL NULL).
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Delete with an arbitrary cutoff. The run and its step should
|
|
// survive because NULL > cutoff evaluates to NULL, not TRUE.
|
|
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
|
|
ChatID: chat.ID,
|
|
MessageID: 1,
|
|
StartedBefore: time.Now().Add(time.Minute),
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 0, deletedRows, "rows with NULL message IDs must not be deleted")
|
|
|
|
// Verify run still exists.
|
|
remaining, err := store.GetChatDebugRunByID(ctx, nullMsgRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, nullMsgRun.ID, remaining.ID)
|
|
|
|
// Verify step still exists.
|
|
remainingSteps, err := store.GetChatDebugStepsByRunID(ctx, nullMsgRun.ID)
|
|
require.NoError(t, err)
|
|
require.Len(t, remainingSteps, 1)
|
|
require.Equal(t, nullMsgStep.ID, remainingSteps[0].ID)
|
|
}
|
|
|
|
// TestDeleteChatDebugDataAfterMessageIDStartedBeforeFiltersNewerRuns
|
|
// verifies the started_before bound on DeleteChatDebugDataAfterMessageID.
|
|
// The bound exists so that retried cleanup (e.g. after edit or archive)
|
|
// cannot delete runs started by a replacement turn that races ahead of
|
|
// the retry window. Without this filter, a stale cleanup would wipe
|
|
// fresh debug rows.
|
|
func TestDeleteChatDebugDataAfterMessageIDStartedBeforeFiltersNewerRuns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-started-before-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-started-before-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
const cutoff int64 = 50
|
|
|
|
// oldRun started an hour ago: must be deleted because it started
|
|
// before the bound.
|
|
oldStartedAt := time.Now().Add(-1 * time.Hour).UTC().
|
|
Truncate(time.Microsecond)
|
|
oldRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff + 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff + 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
StartedAt: sql.NullTime{Time: oldStartedAt, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: oldStartedAt, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Bound sits between the two runs. Any run whose started_at is at
|
|
// or after this instant must survive.
|
|
cutoffTime := time.Now().Add(-30 * time.Minute).UTC().
|
|
Truncate(time.Microsecond)
|
|
|
|
// newRun started after cutoffTime with identical message_id values
|
|
// that would otherwise match the delete predicate. It must survive
|
|
// because started_before excludes it.
|
|
newStartedAt := time.Now().UTC().Truncate(time.Microsecond)
|
|
newRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
TriggerMessageID: sql.NullInt64{Int64: cutoff + 1, Valid: true},
|
|
HistoryTipMessageID: sql.NullInt64{Int64: cutoff + 1, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
StartedAt: sql.NullTime{Time: newStartedAt, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: newStartedAt, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
|
|
ChatID: chat.ID,
|
|
MessageID: cutoff,
|
|
StartedBefore: cutoffTime,
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, deletedRows,
|
|
"only the pre-cutoff run should be deleted")
|
|
|
|
// oldRun must be gone.
|
|
_, err = store.GetChatDebugRunByID(ctx, oldRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
// newRun must survive the retry window.
|
|
remaining, err := store.GetChatDebugRunByID(ctx, newRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, newRun.ID, remaining.ID)
|
|
}
|
|
|
|
// TestDeleteChatDebugDataByChatIDStartedBeforeFiltersNewerRuns verifies
|
|
// the started_before bound on DeleteChatDebugDataByChatID. Archive
|
|
// cleanup retries rely on this bound to avoid deleting runs created
|
|
// by a replacement turn that starts after an unarchive races ahead of
|
|
// the retry window.
|
|
func TestDeleteChatDebugDataByChatIDStartedBeforeFiltersNewerRuns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
|
|
providerName := "openai"
|
|
modelName := "debug-model-by-chat-started-before-" + uuid.NewString()
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: providerName,
|
|
DisplayName: "Debug Provider",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: providerName,
|
|
Model: modelName,
|
|
DisplayName: "Debug Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "chat-debug-by-chat-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
oldStartedAt := time.Now().Add(-1 * time.Hour).UTC().
|
|
Truncate(time.Microsecond)
|
|
oldRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
StartedAt: sql.NullTime{Time: oldStartedAt, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: oldStartedAt, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
cutoffTime := time.Now().Add(-30 * time.Minute).UTC().
|
|
Truncate(time.Microsecond)
|
|
|
|
newStartedAt := time.Now().UTC().Truncate(time.Microsecond)
|
|
newRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
|
Kind: "chat_turn",
|
|
Status: "in_progress",
|
|
Provider: sql.NullString{String: providerName, Valid: true},
|
|
Model: sql.NullString{String: modelName, Valid: true},
|
|
StartedAt: sql.NullTime{Time: newStartedAt, Valid: true},
|
|
UpdatedAt: sql.NullTime{Time: newStartedAt, Valid: true},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
deletedRows, err := store.DeleteChatDebugDataByChatID(ctx, database.DeleteChatDebugDataByChatIDParams{
|
|
ChatID: chat.ID,
|
|
StartedBefore: cutoffTime,
|
|
})
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, deletedRows,
|
|
"only the pre-cutoff run should be deleted")
|
|
|
|
_, err = store.GetChatDebugRunByID(ctx, oldRun.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
|
|
remaining, err := store.GetChatDebugRunByID(ctx, newRun.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, newRun.ID, remaining.ID)
|
|
}
|
|
|
|
func TestChatHasUnread(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, _ := dbtestutil.NewDB(t)
|
|
ctx := context.Background()
|
|
|
|
org := dbgen.Organization(t, store, database.Organization{})
|
|
user := dbgen.User(t, store, database.User{})
|
|
dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
|
|
|
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
|
Provider: "openai",
|
|
DisplayName: "OpenAI",
|
|
APIKey: "test-key",
|
|
Enabled: true,
|
|
CentralApiKeyEnabled: true,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
|
Provider: "openai",
|
|
Model: "test-model-" + uuid.NewString(),
|
|
DisplayName: "Test Model",
|
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
|
Enabled: true,
|
|
IsDefault: true,
|
|
ContextLimit: 128000,
|
|
CompressionThreshold: 80,
|
|
Options: json.RawMessage(`{}`),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
|
OrganizationID: org.ID,
|
|
Status: database.ChatStatusWaiting,
|
|
ClientType: database.ChatClientTypeUi,
|
|
OwnerID: user.ID,
|
|
LastModelConfigID: modelCfg.ID,
|
|
Title: "test-chat-" + uuid.NewString(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
getHasUnread := func() bool {
|
|
rows, err := store.GetChats(ctx, database.GetChatsParams{
|
|
OwnerID: user.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
for _, row := range rows {
|
|
if row.Chat.ID == chat.ID {
|
|
return row.HasUnread
|
|
}
|
|
}
|
|
t.Fatal("chat not found in GetChats result")
|
|
return false
|
|
}
|
|
|
|
// New chat with no messages: not unread.
|
|
require.False(t, getHasUnread(), "new chat with no messages should not be unread")
|
|
|
|
// Helper to insert a single chat message.
|
|
insertMsg := func(role database.ChatMessageRole, text string) {
|
|
t.Helper()
|
|
_, err := store.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chat.ID,
|
|
CreatedBy: []uuid.UUID{user.ID},
|
|
ModelConfigID: []uuid.UUID{modelCfg.ID},
|
|
Role: []database.ChatMessageRole{role},
|
|
Content: []string{fmt.Sprintf(`[{"type":"text","text":%q}]`, text)},
|
|
ContentVersion: []int16{0},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
|
InputTokens: []int64{0},
|
|
OutputTokens: []int64{0},
|
|
TotalTokens: []int64{0},
|
|
ReasoningTokens: []int64{0},
|
|
CacheCreationTokens: []int64{0},
|
|
CacheReadTokens: []int64{0},
|
|
ContextLimit: []int64{0},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{0},
|
|
RuntimeMs: []int64{0},
|
|
ProviderResponseID: []string{""},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// Insert an assistant message: becomes unread.
|
|
insertMsg(database.ChatMessageRoleAssistant, "hello")
|
|
require.True(t, getHasUnread(), "chat with unread assistant message should be unread")
|
|
|
|
// Mark as read: no longer unread.
|
|
lastMsg, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
|
ChatID: chat.ID,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
})
|
|
require.NoError(t, err)
|
|
err = store.UpdateChatLastReadMessageID(ctx, database.UpdateChatLastReadMessageIDParams{
|
|
ID: chat.ID,
|
|
LastReadMessageID: lastMsg.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, getHasUnread(), "chat should not be unread after marking as read")
|
|
|
|
// Insert another assistant message: becomes unread again.
|
|
insertMsg(database.ChatMessageRoleAssistant, "new message")
|
|
require.True(t, getHasUnread(), "new assistant message after read should be unread")
|
|
|
|
// Mark as read again, then verify user messages don't
|
|
// trigger unread.
|
|
lastMsg, err = store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
|
ChatID: chat.ID,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
})
|
|
require.NoError(t, err)
|
|
err = store.UpdateChatLastReadMessageID(ctx, database.UpdateChatLastReadMessageIDParams{
|
|
ID: chat.ID,
|
|
LastReadMessageID: lastMsg.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
insertMsg(database.ChatMessageRoleUser, "user msg")
|
|
require.False(t, getHasUnread(), "user messages should not trigger unread")
|
|
}
|