From a1d51f0dab76f33b4db025e9760486c865e2aea3 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 3 Apr 2026 15:47:26 -0500 Subject: [PATCH] feat: batch connection logs to avoid DB lock contention (#23727) - Running 30k connections was generating a ton of lock contention in the DB --- coderd/agentapi/connectionlog.go | 2 +- coderd/agentapi/connectionlog_test.go | 2 +- coderd/connectionlog/connectionlog.go | 4 +- coderd/database/dbauthz/dbauthz.go | 14 +- coderd/database/dbauthz/dbauthz_test.go | 7 +- coderd/database/dbgen/dbgen.go | 50 +- coderd/database/dbmetrics/querymetrics.go | 16 +- coderd/database/dbmock/dbmock.go | 29 +- coderd/database/modelmethods.go | 26 + coderd/database/querier.go | 2 +- coderd/database/querier_test.go | 679 +++++++++++++----- coderd/database/queries.sql.go | 231 +++--- coderd/database/queries/connectionlogs.sql | 118 +-- coderd/workspaceapps/db.go | 2 +- coderd/workspaceapps/db_test.go | 2 +- enterprise/coderd/coderd.go | 12 +- .../coderd/connectionlog/connectionlog.go | 490 ++++++++++++- .../connectionlog_internal_test.go | 529 ++++++++++++++ .../connectionlog/connectionlog_test.go | 371 ++++++++++ enterprise/coderd/connectionlog_test.go | 2 +- enterprise/coderd/workspaceproxy_test.go | 6 +- 21 files changed, 2168 insertions(+), 426 deletions(-) create mode 100644 enterprise/coderd/connectionlog/connectionlog_internal_test.go create mode 100644 enterprise/coderd/connectionlog/connectionlog_test.go diff --git a/coderd/agentapi/connectionlog.go b/coderd/agentapi/connectionlog.go index 420d788153..b033a1d8ae 100644 --- a/coderd/agentapi/connectionlog.go +++ b/coderd/agentapi/connectionlog.go @@ -85,7 +85,7 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor AgentName: a.AgentName, Type: connectionType, Code: code, - Ip: logIP, + IP: logIP, ConnectionID: uuid.NullUUID{ UUID: connectionID, Valid: true, diff --git a/coderd/agentapi/connectionlog_test.go b/coderd/agentapi/connectionlog_test.go index 68d910e2e8..94bd223d30 100644 --- a/coderd/agentapi/connectionlog_test.go +++ b/coderd/agentapi/connectionlog_test.go @@ -152,7 +152,7 @@ func TestConnectionLog(t *testing.T) { Int32: tt.status, Valid: *tt.action == agentproto.Connection_DISCONNECT, }, - Ip: expectedIP, + IP: expectedIP, Type: agentProtoConnectionTypeToConnectionLog(t, *tt.typ), DisconnectReason: sql.NullString{ String: tt.reason, diff --git a/coderd/connectionlog/connectionlog.go b/coderd/connectionlog/connectionlog.go index b3d9e9115f..582bcf9c03 100644 --- a/coderd/connectionlog/connectionlog.go +++ b/coderd/connectionlog/connectionlog.go @@ -90,8 +90,8 @@ func (m *FakeConnectionLogger) Contains(t testing.TB, expected database.UpsertCo t.Logf("connection log %d: expected Code %d, got %d", idx+1, expected.Code.Int32, cl.Code.Int32) continue } - if expected.Ip.Valid && cl.Ip.IPNet.String() != expected.Ip.IPNet.String() { - t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.Ip.IPNet, cl.Ip.IPNet) + if expected.IP.Valid && cl.IP.IPNet.String() != expected.IP.IPNet.String() { + t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.IP.IPNet, cl.IP.IPNet) continue } if expected.UserAgent.Valid && cl.UserAgent.String != expected.UserAgent.String { diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 6afdd213cd..03a09965bf 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1627,6 +1627,13 @@ func (q *querier) BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg datab return q.db.BatchUpdateWorkspaceNextStartAt(ctx, arg) } +func (q *querier) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil { + return err + } + return q.db.BatchUpsertConnectionLogs(ctx, arg) +} + func (q *querier) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceNotificationMessage); err != nil { return 0, err @@ -7065,13 +7072,6 @@ func (q *querier) UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl strin return q.db.UpsertChatWorkspaceTTL(ctx, workspaceTtl) } -func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil { - return database.ConnectionLog{}, err - } - return q.db.UpsertConnectionLog(ctx, arg) -} - func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index edadb8f985..26c2cbb654 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -338,10 +338,9 @@ func (s *MethodTestSuite) TestAuditLogs() { } func (s *MethodTestSuite) TestConnectionLogs() { - s.Run("UpsertConnectionLog", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - ws := testutil.Fake(s.T(), faker, database.WorkspaceTable{}) - arg := database.UpsertConnectionLogParams{Ip: defaultIPAddress(), Type: database.ConnectionTypeSsh, WorkspaceID: ws.ID, OrganizationID: ws.OrganizationID, ConnectionStatus: database.ConnectionStatusConnected, WorkspaceOwnerID: ws.OwnerID} - dbm.EXPECT().UpsertConnectionLog(gomock.Any(), arg).Return(database.ConnectionLog{}, nil).AnyTimes() + s.Run("BatchUpsertConnectionLogs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.BatchUpsertConnectionLogsParams{} + dbm.EXPECT().BatchUpsertConnectionLogs(gomock.Any(), arg).Return(nil).AnyTimes() check.Args(arg).Asserts(rbac.ResourceConnectionLog, policy.ActionUpdate) })) s.Run("GetConnectionLogsOffset", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 1dce62c841..69a50ede9c 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -76,7 +76,7 @@ func AuditLog(t testing.TB, db database.Store, seed database.AuditLog) database. } func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnectionLogParams) database.ConnectionLog { - log, err := db.UpsertConnectionLog(genCtx, database.UpsertConnectionLogParams{ + arg := database.UpsertConnectionLogParams{ ID: takeFirst(seed.ID, uuid.New()), Time: takeFirst(seed.Time, dbtime.Now()), OrganizationID: takeFirst(seed.OrganizationID, uuid.New()), @@ -89,7 +89,7 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti Int32: takeFirst(seed.Code.Int32, 0), Valid: takeFirst(seed.Code.Valid, false), }, - Ip: pqtype.Inet{ + IP: pqtype.Inet{ IPNet: net.IPNet{ IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 255, 255, 255), @@ -117,9 +117,53 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti Valid: takeFirst(seed.DisconnectReason.Valid, false), }, ConnectionStatus: takeFirst(seed.ConnectionStatus, database.ConnectionStatusConnected), + } + + var disconnectTime sql.NullTime + if arg.ConnectionStatus == database.ConnectionStatusDisconnected { + disconnectTime = sql.NullTime{Time: arg.Time, Valid: true} + } + + err := db.BatchUpsertConnectionLogs(genCtx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{arg.ID}, + ConnectTime: []time.Time{arg.Time}, + OrganizationID: []uuid.UUID{arg.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{arg.WorkspaceOwnerID}, + WorkspaceID: []uuid.UUID{arg.WorkspaceID}, + WorkspaceName: []string{arg.WorkspaceName}, + AgentName: []string{arg.AgentName}, + Type: []database.ConnectionType{arg.Type}, + Code: []int32{arg.Code.Int32}, + CodeValid: []bool{arg.Code.Valid}, + Ip: []pqtype.Inet{arg.IP}, + UserAgent: []string{arg.UserAgent.String}, + UserID: []uuid.UUID{arg.UserID.UUID}, + SlugOrPort: []string{arg.SlugOrPort.String}, + ConnectionID: []uuid.UUID{arg.ConnectionID.UUID}, + DisconnectReason: []string{arg.DisconnectReason.String}, + DisconnectTime: []time.Time{disconnectTime.Time}, }) require.NoError(t, err, "insert connection log") - return log + + // Query back the actual row from the database. On upsert + // conflict the DB keeps the original row's ID, so we can't + // rely on arg.ID. Match on the conflict key for rows with a + // connection_id, or by primary key for NULL connection_id. + rows, err := db.GetConnectionLogsOffset(genCtx, database.GetConnectionLogsOffsetParams{}) + require.NoError(t, err, "query connection logs") + for _, row := range rows { + if arg.ConnectionID.Valid { + if row.ConnectionLog.ConnectionID == arg.ConnectionID && + row.ConnectionLog.WorkspaceID == arg.WorkspaceID && + row.ConnectionLog.AgentName == arg.AgentName { + return row.ConnectionLog + } + } else if row.ConnectionLog.ID == arg.ID { + return row.ConnectionLog + } + } + require.Failf(t, "connection log not found", "id=%s", arg.ID) + return database.ConnectionLog{} // unreachable } func Template(t testing.TB, db database.Store, seed database.Template) database.Template { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index a611530dc4..1243a9138b 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -208,6 +208,14 @@ func (m queryMetricsStore) BatchUpdateWorkspaceNextStartAt(ctx context.Context, return r0 } +func (m queryMetricsStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error { + start := time.Now() + r0 := m.s.BatchUpsertConnectionLogs(ctx, arg) + m.queryLatencies.WithLabelValues("BatchUpsertConnectionLogs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpsertConnectionLogs").Inc() + return r0 +} + func (m queryMetricsStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) { start := time.Now() r0, r1 := m.s.BulkMarkNotificationMessagesFailed(ctx, arg) @@ -5024,14 +5032,6 @@ func (m queryMetricsStore) UpsertChatWorkspaceTTL(ctx context.Context, workspace return r0 } -func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) { - start := time.Now() - r0, r1 := m.s.UpsertConnectionLog(ctx, arg) - m.queryLatencies.WithLabelValues("UpsertConnectionLog").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertConnectionLog").Inc() - return r0, r1 -} - func (m queryMetricsStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error { start := time.Now() r0 := m.s.UpsertDefaultProxy(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index d329557e0e..6d3a07699f 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -233,6 +233,20 @@ func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceNextStartAt(ctx, arg any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceNextStartAt", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceNextStartAt), ctx, arg) } +// BatchUpsertConnectionLogs mocks base method. +func (m *MockStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BatchUpsertConnectionLogs", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// BatchUpsertConnectionLogs indicates an expected call of BatchUpsertConnectionLogs. +func (mr *MockStoreMockRecorder) BatchUpsertConnectionLogs(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpsertConnectionLogs", reflect.TypeOf((*MockStore)(nil).BatchUpsertConnectionLogs), ctx, arg) +} + // BulkMarkNotificationMessagesFailed mocks base method. func (m *MockStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) { m.ctrl.T.Helper() @@ -9442,21 +9456,6 @@ func (mr *MockStoreMockRecorder) UpsertChatWorkspaceTTL(ctx, workspaceTtl any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatWorkspaceTTL", reflect.TypeOf((*MockStore)(nil).UpsertChatWorkspaceTTL), ctx, workspaceTtl) } -// UpsertConnectionLog mocks base method. -func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpsertConnectionLog", ctx, arg) - ret0, _ := ret[0].(database.ConnectionLog) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpsertConnectionLog indicates an expected call of UpsertConnectionLog. -func (mr *MockStoreMockRecorder) UpsertConnectionLog(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertConnectionLog", reflect.TypeOf((*MockStore)(nil).UpsertConnectionLog), ctx, arg) -} - // UpsertDefaultProxy mocks base method. func (m *MockStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 3cf68883c3..147161f03c 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -10,6 +10,7 @@ import ( "time" "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" "golang.org/x/exp/maps" "golang.org/x/oauth2" "golang.org/x/xerrors" @@ -923,3 +924,28 @@ func WorkspaceIdentityFromWorkspace(w Workspace) WorkspaceIdentity { func (r GetWorkspaceAgentAndWorkspaceByIDRow) RBACObject() rbac.Object { return r.WorkspaceTable.RBACObject() } + +// UpsertConnectionLogParams contains the parameters for upserting a +// connection log entry. This struct is hand-maintained (not generated +// by sqlc) because the single-row UpsertConnectionLog query was +// removed in favor of BatchUpsertConnectionLogs, but the struct is +// still used as the canonical connection log event type throughout +// the codebase. +type UpsertConnectionLogParams struct { + ID uuid.UUID `db:"id" json:"id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"` + WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"` + WorkspaceName string `db:"workspace_name" json:"workspace_name"` + AgentName string `db:"agent_name" json:"agent_name"` + Type ConnectionType `db:"type" json:"type"` + Code sql.NullInt32 `db:"code" json:"code"` + IP pqtype.Inet `db:"ip" json:"ip"` + UserAgent sql.NullString `db:"user_agent" json:"user_agent"` + UserID uuid.NullUUID `db:"user_id" json:"user_id"` + SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"` + ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"` + DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"` + Time time.Time `db:"time" json:"time"` + ConnectionStatus ConnectionStatus `db:"connection_status" json:"connection_status"` +} diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 9714b74999..53072af425 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -65,6 +65,7 @@ type sqlcQuerier interface { BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error + BatchUpsertConnectionLogs(ctx context.Context, arg BatchUpsertConnectionLogsParams) error BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error) BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error) // Calculates the telemetry summary for a given provider, model, and client @@ -991,7 +992,6 @@ type sqlcQuerier interface { UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error) UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error) UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl string) error - UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error) // The default proxy is implied and not actually stored in the database. // So we need to store it's configuration here for display purposes. // The functional values are immutable and controlled implicitly. diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index db31e88c69..dd68a4ce9c 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -3566,9 +3566,11 @@ func connectionOnlyIDs[T database.ConnectionLog | database.GetConnectionLogsOffs return ids } -func TestUpsertConnectionLog(t *testing.T) { +func TestBatchUpsertConnectionLogs(t *testing.T) { t.Parallel() + createWorkspace := func(t *testing.T, db database.Store) database.WorkspaceTable { + t.Helper() u := dbgen.User(t, db, database.User{}) o := dbgen.Organization(t, db, database.Organization{}) tpl := dbgen.Template(t, db, database.Template{ @@ -3584,253 +3586,536 @@ func TestUpsertConnectionLog(t *testing.T) { }) } + // zeroTime is the sentinel value that the SQL treats as "no + // connect/disconnect time provided". + zeroTime := time.Time{} + + defaultIP := pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + Valid: true, + } + + t.Run("SingleConnect", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + ws := createWorkspace(t, db) + connID := uuid.New() + connectTime := dbtime.Now() + + err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{connectTime}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{0}, + CodeValid: []bool{false}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{""}, + DisconnectTime: []time.Time{zeroTime}, + }) + require.NoError(t, err) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime)) + require.False(t, rows[0].ConnectionLog.DisconnectTime.Valid, + "disconnect_time should be NULL for a connect-only event") + }) + t.Run("ConnectThenDisconnect", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := context.Background() - ws := createWorkspace(t, db) - - connectionID := uuid.New() - agentName := "test-agent" - - // 1. Insert a 'connect' event. + connID := uuid.New() connectTime := dbtime.Now() - connectParams := database.UpsertConnectionLogParams{ - ID: uuid.New(), - Time: connectTime, - OrganizationID: ws.OrganizationID, - WorkspaceOwnerID: ws.OwnerID, - WorkspaceID: ws.ID, - WorkspaceName: ws.Name, - AgentName: agentName, - Type: database.ConnectionTypeSsh, - ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, - ConnectionStatus: database.ConnectionStatusConnected, - Ip: pqtype.Inet{ - IPNet: net.IPNet{ - IP: net.IPv4(127, 0, 0, 1), - Mask: net.IPv4Mask(255, 255, 255, 255), - }, - Valid: true, - }, - } - log1, err := db.UpsertConnectionLog(ctx, connectParams) + // Insert connect. + err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{connectTime}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{0}, + CodeValid: []bool{false}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{""}, + DisconnectTime: []time.Time{zeroTime}, + }) + require.NoError(t, err) + + // Insert disconnect for same connection. + disconnectTime := connectTime.Add(time.Second) + err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{zeroTime}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{1}, + CodeValid: []bool{true}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{"test disconnect"}, + DisconnectTime: []time.Time{disconnectTime}, + }) require.NoError(t, err) - require.Equal(t, connectParams.ID, log1.ID) - require.False(t, log1.DisconnectTime.Valid, "DisconnectTime should not be set on connect") - // Check that one row exists. rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) require.NoError(t, err) require.Len(t, rows, 1) - - // 2. Insert a 'disconnected' event for the same connection. - disconnectTime := connectTime.Add(time.Second) - disconnectParams := database.UpsertConnectionLogParams{ - ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, - WorkspaceID: ws.ID, - AgentName: agentName, - ConnectionStatus: database.ConnectionStatusDisconnected, - - // Updated to: - Time: disconnectTime, - DisconnectReason: sql.NullString{String: "test disconnect", Valid: true}, - Code: sql.NullInt32{Int32: 1, Valid: true}, - - // Ignored - ID: uuid.New(), - OrganizationID: ws.OrganizationID, - WorkspaceOwnerID: ws.OwnerID, - WorkspaceName: ws.Name, - Type: database.ConnectionTypeSsh, - Ip: pqtype.Inet{ - IPNet: net.IPNet{ - IP: net.IPv4(127, 0, 0, 1), - Mask: net.IPv4Mask(255, 255, 255, 254), - }, - Valid: true, - }, - } - - log2, err := db.UpsertConnectionLog(ctx, disconnectParams) - require.NoError(t, err) - - // Updated - require.Equal(t, log1.ID, log2.ID) - require.True(t, log2.DisconnectTime.Valid) - require.True(t, disconnectTime.Equal(log2.DisconnectTime.Time)) - require.Equal(t, disconnectParams.DisconnectReason.String, log2.DisconnectReason.String) - - rows, err = db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{}) - require.NoError(t, err) - require.Len(t, rows, 1) + row := rows[0].ConnectionLog + require.True(t, connectTime.Equal(row.ConnectTime)) + require.True(t, row.DisconnectTime.Valid) + require.True(t, disconnectTime.Equal(row.DisconnectTime.Time)) + require.Equal(t, "test disconnect", row.DisconnectReason.String) + require.Equal(t, int32(1), row.Code.Int32) }) - t.Run("ConnectDoesNotUpdate", func(t *testing.T) { + t.Run("DuplicateConnectIsNoOp", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := context.Background() - ws := createWorkspace(t, db) - - connectionID := uuid.New() - agentName := "test-agent" - - // 1. Insert a 'connect' event. + connID := uuid.New() connectTime := dbtime.Now() - connectParams := database.UpsertConnectionLogParams{ - ID: uuid.New(), - Time: connectTime, - OrganizationID: ws.OrganizationID, - WorkspaceOwnerID: ws.OwnerID, - WorkspaceID: ws.ID, - WorkspaceName: ws.Name, - AgentName: agentName, - Type: database.ConnectionTypeSsh, - ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, - ConnectionStatus: database.ConnectionStatusConnected, - Ip: pqtype.Inet{ - IPNet: net.IPNet{ - IP: net.IPv4(127, 0, 0, 1), - Mask: net.IPv4Mask(255, 255, 255, 255), - }, - Valid: true, - }, + + mkParams := func(ct time.Time, ip pqtype.Inet) database.BatchUpsertConnectionLogsParams { + return database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{ct}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{0}, + CodeValid: []bool{false}, + Ip: []pqtype.Inet{ip}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{""}, + DisconnectTime: []time.Time{zeroTime}, + } } - log, err := db.UpsertConnectionLog(ctx, connectParams) + err := db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime, defaultIP)) require.NoError(t, err) - // 2. Insert another 'connect' event for the same connection. - connectTime2 := connectTime.Add(time.Second) - connectParams2 := database.UpsertConnectionLogParams{ - ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, - WorkspaceID: ws.ID, - AgentName: agentName, - ConnectionStatus: database.ConnectionStatusConnected, + rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) + require.NoError(t, err) + require.Len(t, rows1, 1) - // Ignored - ID: uuid.New(), - Time: connectTime2, - OrganizationID: ws.OrganizationID, - WorkspaceOwnerID: ws.OwnerID, - WorkspaceName: ws.Name, - Type: database.ConnectionTypeSsh, - Code: sql.NullInt32{Int32: 0, Valid: false}, - Ip: pqtype.Inet{ - IPNet: net.IPNet{ - IP: net.IPv4(127, 0, 0, 1), - Mask: net.IPv4Mask(255, 255, 255, 254), - }, - Valid: true, + // Second connect with later time and different IP. + otherIP := pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(10, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), }, + Valid: true, } - - origLog, err := db.UpsertConnectionLog(ctx, connectParams2) + err = db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime.Add(time.Second), otherIP)) require.NoError(t, err) - require.Equal(t, log, origLog, "connect update should be a no-op") - // Check that still only one row exists. - rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{}) + rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) require.NoError(t, err) - require.Len(t, rows, 1) - require.Equal(t, log, rows[0].ConnectionLog) + require.Len(t, rows2, 1) + + // The LEAST logic should pick the earlier connect_time; IP and + // other fields are not updated on conflict. + require.True(t, connectTime.Equal(rows2[0].ConnectionLog.ConnectTime), + "connect_time should remain the original (earlier) value") }) - t.Run("DisconnectThenConnect", func(t *testing.T) { + t.Run("OrderIndependentConnectTime", func(t *testing.T) { t.Parallel() - db, _ := dbtestutil.NewDB(t) ctx := context.Background() - ws := createWorkspace(t, db) - - connectionID := uuid.New() - agentName := "test-agent" - - // Insert just a 'disconect' event + connID := uuid.New() disconnectTime := dbtime.Now() - disconnectParams := database.UpsertConnectionLogParams{ - ID: uuid.New(), - Time: disconnectTime, - OrganizationID: ws.OrganizationID, - WorkspaceOwnerID: ws.OwnerID, - WorkspaceID: ws.ID, - WorkspaceName: ws.Name, - AgentName: agentName, - Type: database.ConnectionTypeSsh, - ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, - ConnectionStatus: database.ConnectionStatusDisconnected, - DisconnectReason: sql.NullString{String: "server shutting down", Valid: true}, - Ip: pqtype.Inet{ - IPNet: net.IPNet{ - IP: net.IPv4(127, 0, 0, 1), - Mask: net.IPv4Mask(255, 255, 255, 255), - }, - Valid: true, - }, + connectTime := disconnectTime.Add(-5 * time.Second) + + // Disconnect arrives first. + err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{disconnectTime}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{0}, + CodeValid: []bool{true}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{"bye"}, + DisconnectTime: []time.Time{disconnectTime}, + }) + require.NoError(t, err) + + // Connect arrives second with the real (earlier) connect_time. + err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{connectTime}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{0}, + CodeValid: []bool{false}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{""}, + DisconnectTime: []time.Time{zeroTime}, + }) + require.NoError(t, err) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime), + "LEAST should pick the earlier connect_time") + }) + + t.Run("DisconnectFieldsAreWriteOnce", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + ws := createWorkspace(t, db) + connID := uuid.New() + disconnectTime := dbtime.Now() + + mkDisconnect := func(reason string, code int32) database.BatchUpsertConnectionLogsParams { + return database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{disconnectTime}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{code}, + CodeValid: []bool{true}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{reason}, + DisconnectTime: []time.Time{disconnectTime}, + } } - _, err := db.UpsertConnectionLog(ctx, disconnectParams) + err := db.BatchUpsertConnectionLogs(ctx, mkDisconnect("first reason", 1)) require.NoError(t, err) - firstRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{}) + // Second disconnect with different reason and code. + err = db.BatchUpsertConnectionLogs(ctx, mkDisconnect("second reason", 2)) require.NoError(t, err) - require.Len(t, firstRows, 1) - // We expect the connection event to be marked as closed with the start - // and close time being the same. - require.True(t, firstRows[0].ConnectionLog.DisconnectTime.Valid) - require.Equal(t, disconnectTime, firstRows[0].ConnectionLog.DisconnectTime.Time.UTC()) - require.Equal(t, firstRows[0].ConnectionLog.ConnectTime.UTC(), firstRows[0].ConnectionLog.DisconnectTime.Time.UTC()) + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) + require.NoError(t, err) + require.Len(t, rows, 1) + row := rows[0].ConnectionLog + require.Equal(t, "first reason", row.DisconnectReason.String, + "disconnect_reason should not be overwritten") + require.Equal(t, int32(1), row.Code.Int32, + "code should not be overwritten") + }) - // Now insert a 'connect' event for the same connection. - // This should be a no op - connectTime := disconnectTime.Add(time.Second) - connectParams := database.UpsertConnectionLogParams{ - ID: uuid.New(), - Time: connectTime, - OrganizationID: ws.OrganizationID, - WorkspaceOwnerID: ws.OwnerID, - WorkspaceID: ws.ID, - WorkspaceName: ws.Name, - AgentName: agentName, - Type: database.ConnectionTypeSsh, - ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, - ConnectionStatus: database.ConnectionStatusConnected, - DisconnectReason: sql.NullString{String: "reconnected", Valid: true}, - Code: sql.NullInt32{Int32: 0, Valid: false}, - Ip: pqtype.Inet{ - IPNet: net.IPNet{ - IP: net.IPv4(127, 0, 0, 1), - Mask: net.IPv4Mask(255, 255, 255, 255), - }, - Valid: true, - }, + t.Run("ConnectAfterDisconnectIsNoOp", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + ws := createWorkspace(t, db) + connID := uuid.New() + disconnectTime := dbtime.Now() + + // Insert disconnect first. + err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{disconnectTime}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{42}, + CodeValid: []bool{true}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{"server shutdown"}, + DisconnectTime: []time.Time{disconnectTime}, + }) + require.NoError(t, err) + + rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) + require.NoError(t, err) + require.Len(t, rows1, 1) + require.True(t, rows1[0].ConnectionLog.DisconnectTime.Valid) + require.Equal(t, "server shutdown", rows1[0].ConnectionLog.DisconnectReason.String) + require.Equal(t, int32(42), rows1[0].ConnectionLog.Code.Int32) + + // Insert connect for same connection_id. + err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{disconnectTime.Add(time.Second)}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{0}, + CodeValid: []bool{false}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{""}, + DisconnectTime: []time.Time{zeroTime}, + }) + require.NoError(t, err) + + rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) + require.NoError(t, err) + require.Len(t, rows2, 1) + row := rows2[0].ConnectionLog + require.True(t, row.DisconnectTime.Valid, + "disconnect_time should not be cleared by a later connect") + require.Equal(t, "server shutdown", row.DisconnectReason.String, + "disconnect_reason should not be cleared") + require.Equal(t, int32(42), row.Code.Int32, + "code should not be cleared") + }) + + t.Run("CodeZeroPreserved", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + ws := createWorkspace(t, db) + connID := uuid.New() + now := dbtime.Now() + + err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{now}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{0}, + CodeValid: []bool{true}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{"normal"}, + DisconnectTime: []time.Time{now}, + }) + require.NoError(t, err) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.True(t, rows[0].ConnectionLog.Code.Valid, "code should be non-NULL") + require.Equal(t, int32(0), rows[0].ConnectionLog.Code.Int32, + "code=0 should be preserved, not treated as NULL") + }) + + t.Run("CodeNullWhenInvalid", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + ws := createWorkspace(t, db) + connID := uuid.New() + now := dbtime.Now() + + err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{now}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{99}, + CodeValid: []bool{false}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{""}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{""}, + ConnectionID: []uuid.UUID{connID}, + DisconnectReason: []string{""}, + DisconnectTime: []time.Time{zeroTime}, + }) + require.NoError(t, err) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.False(t, rows[0].ConnectionLog.Code.Valid, + "code should be NULL when code_valid is false") + }) + + t.Run("NullConnectionIDEvents", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + ws := createWorkspace(t, db) + now := dbtime.Now() + + // Insert two web events with NULL connection_id (uuid.Nil → + // NULL via NULLIF) for the same workspace/agent. + for i := range 2 { + err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: []uuid.UUID{uuid.New()}, + ConnectTime: []time.Time{now.Add(time.Duration(i) * time.Second)}, + OrganizationID: []uuid.UUID{ws.OrganizationID}, + WorkspaceOwnerID: []uuid.UUID{ws.OwnerID}, + WorkspaceID: []uuid.UUID{ws.ID}, + WorkspaceName: []string{ws.Name}, + AgentName: []string{"agent"}, + Type: []database.ConnectionType{database.ConnectionTypeSsh}, + Code: []int32{200}, + CodeValid: []bool{true}, + Ip: []pqtype.Inet{defaultIP}, + UserAgent: []string{"Mozilla/5.0"}, + UserID: []uuid.UUID{uuid.Nil}, + SlugOrPort: []string{"web-terminal"}, + ConnectionID: []uuid.UUID{uuid.Nil}, + DisconnectReason: []string{""}, + DisconnectTime: []time.Time{zeroTime}, + }) + require.NoError(t, err) } - _, err = db.UpsertConnectionLog(ctx, connectParams) + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) require.NoError(t, err) + require.Len(t, rows, 2, + "NULL connection_id rows should not conflict with each other") + }) - secondRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{}) - require.NoError(t, err) - require.Len(t, secondRows, 1) - require.Equal(t, firstRows, secondRows) + t.Run("MultipleIndependentConnections", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + ws := createWorkspace(t, db) + now := dbtime.Now() - // Upsert a disconnection, which should also be a no op - disconnectParams.DisconnectReason = sql.NullString{ - String: "updated close reason", - Valid: true, + n := 5 + ids := make([]uuid.UUID, n) + connectTimes := make([]time.Time, n) + orgIDs := make([]uuid.UUID, n) + ownerIDs := make([]uuid.UUID, n) + wsIDs := make([]uuid.UUID, n) + wsNames := make([]string, n) + agentNames := make([]string, n) + types := make([]database.ConnectionType, n) + codes := make([]int32, n) + codeValids := make([]bool, n) + ips := make([]pqtype.Inet, n) + userAgents := make([]string, n) + userIDs := make([]uuid.UUID, n) + slugOrPorts := make([]string, n) + connIDs := make([]uuid.UUID, n) + disconnectReasons := make([]string, n) + disconnectTimes := make([]time.Time, n) + + for i := range n { + ids[i] = uuid.New() + connectTimes[i] = now.Add(time.Duration(i) * time.Second) + orgIDs[i] = ws.OrganizationID + ownerIDs[i] = ws.OwnerID + wsIDs[i] = ws.ID + wsNames[i] = ws.Name + agentNames[i] = "agent" + types[i] = database.ConnectionTypeSsh + codes[i] = 0 + codeValids[i] = false + ips[i] = defaultIP + userAgents[i] = "" + userIDs[i] = uuid.Nil + slugOrPorts[i] = "" + connIDs[i] = uuid.New() + disconnectReasons[i] = "" + disconnectTimes[i] = zeroTime } - _, err = db.UpsertConnectionLog(ctx, disconnectParams) + + err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{ + ID: ids, + ConnectTime: connectTimes, + OrganizationID: orgIDs, + WorkspaceOwnerID: ownerIDs, + WorkspaceID: wsIDs, + WorkspaceName: wsNames, + AgentName: agentNames, + Type: types, + Code: codes, + CodeValid: codeValids, + Ip: ips, + UserAgent: userAgents, + UserID: userIDs, + SlugOrPort: slugOrPorts, + ConnectionID: connIDs, + DisconnectReason: disconnectReasons, + DisconnectTime: disconnectTimes, + }) require.NoError(t, err) - thirdRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{}) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10}) require.NoError(t, err) - require.Len(t, secondRows, 1) - // The close reason shouldn't be updated - require.Equal(t, secondRows, thirdRows) + require.Len(t, rows, n, "each unique connection_id should produce its own row") }) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index abc9f6b2a1..a57c3771ef 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -7338,6 +7338,123 @@ func (q *sqlQuerier) UpsertChatUsageLimitUserOverride(ctx context.Context, arg U return i, err } +const batchUpsertConnectionLogs = `-- name: BatchUpsertConnectionLogs :exec +INSERT INTO connection_logs ( + id, connect_time, organization_id, workspace_owner_id, workspace_id, + workspace_name, agent_name, type, code, ip, user_agent, user_id, + slug_or_port, connection_id, disconnect_reason, disconnect_time +) +SELECT + u.id, + u.connect_time, + u.organization_id, + u.workspace_owner_id, + u.workspace_id, + u.workspace_name, + u.agent_name, + u.type, + -- Use the validity flag to distinguish "no code" (NULL) from a + -- legitimate zero exit code. + CASE WHEN u.code_valid THEN u.code ELSE NULL END, + u.ip, + NULLIF(u.user_agent, ''), + NULLIF(u.user_id, '00000000-0000-0000-0000-000000000000'::uuid), + NULLIF(u.slug_or_port, ''), + NULLIF(u.connection_id, '00000000-0000-0000-0000-000000000000'::uuid), + NULLIF(u.disconnect_reason, ''), + NULLIF(u.disconnect_time, '0001-01-01 00:00:00Z'::timestamptz) +FROM ( + SELECT + unnest($1::uuid[]) AS id, + unnest($2::timestamptz[]) AS connect_time, + unnest($3::uuid[]) AS organization_id, + unnest($4::uuid[]) AS workspace_owner_id, + unnest($5::uuid[]) AS workspace_id, + unnest($6::text[]) AS workspace_name, + unnest($7::text[]) AS agent_name, + unnest($8::connection_type[]) AS type, + unnest($9::int4[]) AS code, + unnest($10::bool[]) AS code_valid, + unnest($11::inet[]) AS ip, + unnest($12::text[]) AS user_agent, + unnest($13::uuid[]) AS user_id, + unnest($14::text[]) AS slug_or_port, + unnest($15::uuid[]) AS connection_id, + unnest($16::text[]) AS disconnect_reason, + unnest($17::timestamptz[]) AS disconnect_time +) AS u +ON CONFLICT (connection_id, workspace_id, agent_name) +DO UPDATE SET + -- Pick the earliest real connect_time. The zero sentinel + -- ('0001-01-01') means the batch didn't know the connect_time + -- (e.g. a pure disconnect event), so we keep the existing value. + connect_time = CASE + WHEN EXCLUDED.connect_time = '0001-01-01 00:00:00Z'::timestamptz + THEN connection_logs.connect_time + WHEN connection_logs.connect_time = '0001-01-01 00:00:00Z'::timestamptz + THEN EXCLUDED.connect_time + ELSE LEAST(connection_logs.connect_time, EXCLUDED.connect_time) + END, + disconnect_time = CASE + WHEN connection_logs.disconnect_time IS NULL + THEN EXCLUDED.disconnect_time + ELSE connection_logs.disconnect_time + END, + disconnect_reason = CASE + WHEN connection_logs.disconnect_reason IS NULL + THEN EXCLUDED.disconnect_reason + ELSE connection_logs.disconnect_reason + END, + code = CASE + WHEN connection_logs.code IS NULL + THEN EXCLUDED.code + ELSE connection_logs.code + END +` + +type BatchUpsertConnectionLogsParams struct { + ID []uuid.UUID `db:"id" json:"id"` + ConnectTime []time.Time `db:"connect_time" json:"connect_time"` + OrganizationID []uuid.UUID `db:"organization_id" json:"organization_id"` + WorkspaceOwnerID []uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"` + WorkspaceID []uuid.UUID `db:"workspace_id" json:"workspace_id"` + WorkspaceName []string `db:"workspace_name" json:"workspace_name"` + AgentName []string `db:"agent_name" json:"agent_name"` + Type []ConnectionType `db:"type" json:"type"` + Code []int32 `db:"code" json:"code"` + CodeValid []bool `db:"code_valid" json:"code_valid"` + Ip []pqtype.Inet `db:"ip" json:"ip"` + UserAgent []string `db:"user_agent" json:"user_agent"` + UserID []uuid.UUID `db:"user_id" json:"user_id"` + SlugOrPort []string `db:"slug_or_port" json:"slug_or_port"` + ConnectionID []uuid.UUID `db:"connection_id" json:"connection_id"` + DisconnectReason []string `db:"disconnect_reason" json:"disconnect_reason"` + DisconnectTime []time.Time `db:"disconnect_time" json:"disconnect_time"` +} + +func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUpsertConnectionLogsParams) error { + _, err := q.db.ExecContext(ctx, batchUpsertConnectionLogs, + pq.Array(arg.ID), + pq.Array(arg.ConnectTime), + pq.Array(arg.OrganizationID), + pq.Array(arg.WorkspaceOwnerID), + pq.Array(arg.WorkspaceID), + pq.Array(arg.WorkspaceName), + pq.Array(arg.AgentName), + pq.Array(arg.Type), + pq.Array(arg.Code), + pq.Array(arg.CodeValid), + pq.Array(arg.Ip), + pq.Array(arg.UserAgent), + pq.Array(arg.UserID), + pq.Array(arg.SlugOrPort), + pq.Array(arg.ConnectionID), + pq.Array(arg.DisconnectReason), + pq.Array(arg.DisconnectTime), + ) + return err +} + const countConnectionLogs = `-- name: CountConnectionLogs :one SELECT COUNT(*) AS count @@ -7753,120 +7870,6 @@ func (q *sqlQuerier) GetConnectionLogsOffset(ctx context.Context, arg GetConnect return items, nil } -const upsertConnectionLog = `-- name: UpsertConnectionLog :one -INSERT INTO connection_logs ( - id, - connect_time, - organization_id, - workspace_owner_id, - workspace_id, - workspace_name, - agent_name, - type, - code, - ip, - user_agent, - user_id, - slug_or_port, - connection_id, - disconnect_reason, - disconnect_time -) VALUES - ($1, $15, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, - -- If we've only received a disconnect event, mark the event as immediately - -- closed. - CASE - WHEN $16::connection_status = 'disconnected' - THEN $15 :: timestamp with time zone - ELSE NULL - END) -ON CONFLICT (connection_id, workspace_id, agent_name) -DO UPDATE SET - -- No-op if the connection is still open. - disconnect_time = CASE - WHEN $16::connection_status = 'disconnected' - -- Can only be set once - AND connection_logs.disconnect_time IS NULL - THEN EXCLUDED.connect_time - ELSE connection_logs.disconnect_time - END, - disconnect_reason = CASE - WHEN $16::connection_status = 'disconnected' - -- Can only be set once - AND connection_logs.disconnect_reason IS NULL - THEN EXCLUDED.disconnect_reason - ELSE connection_logs.disconnect_reason - END, - code = CASE - WHEN $16::connection_status = 'disconnected' - -- Can only be set once - AND connection_logs.code IS NULL - THEN EXCLUDED.code - ELSE connection_logs.code - END -RETURNING id, connect_time, organization_id, workspace_owner_id, workspace_id, workspace_name, agent_name, type, ip, code, user_agent, user_id, slug_or_port, connection_id, disconnect_time, disconnect_reason -` - -type UpsertConnectionLogParams struct { - ID uuid.UUID `db:"id" json:"id"` - OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` - WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"` - WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"` - WorkspaceName string `db:"workspace_name" json:"workspace_name"` - AgentName string `db:"agent_name" json:"agent_name"` - Type ConnectionType `db:"type" json:"type"` - Code sql.NullInt32 `db:"code" json:"code"` - Ip pqtype.Inet `db:"ip" json:"ip"` - UserAgent sql.NullString `db:"user_agent" json:"user_agent"` - UserID uuid.NullUUID `db:"user_id" json:"user_id"` - SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"` - ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"` - DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"` - Time time.Time `db:"time" json:"time"` - ConnectionStatus ConnectionStatus `db:"connection_status" json:"connection_status"` -} - -func (q *sqlQuerier) UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error) { - row := q.db.QueryRowContext(ctx, upsertConnectionLog, - arg.ID, - arg.OrganizationID, - arg.WorkspaceOwnerID, - arg.WorkspaceID, - arg.WorkspaceName, - arg.AgentName, - arg.Type, - arg.Code, - arg.Ip, - arg.UserAgent, - arg.UserID, - arg.SlugOrPort, - arg.ConnectionID, - arg.DisconnectReason, - arg.Time, - arg.ConnectionStatus, - ) - var i ConnectionLog - err := row.Scan( - &i.ID, - &i.ConnectTime, - &i.OrganizationID, - &i.WorkspaceOwnerID, - &i.WorkspaceID, - &i.WorkspaceName, - &i.AgentName, - &i.Type, - &i.Ip, - &i.Code, - &i.UserAgent, - &i.UserID, - &i.SlugOrPort, - &i.ConnectionID, - &i.DisconnectTime, - &i.DisconnectReason, - ) - return i, err -} - const deleteCryptoKey = `-- name: DeleteCryptoKey :one UPDATE crypto_keys SET secret = NULL, secret_key_id = NULL diff --git a/coderd/database/queries/connectionlogs.sql b/coderd/database/queries/connectionlogs.sql index fc38d1af1a..63e0023dcc 100644 --- a/coderd/database/queries/connectionlogs.sql +++ b/coderd/database/queries/connectionlogs.sql @@ -251,55 +251,75 @@ DELETE FROM connection_logs USING old_logs WHERE connection_logs.id = old_logs.id; --- name: UpsertConnectionLog :one +-- name: BatchUpsertConnectionLogs :exec INSERT INTO connection_logs ( - id, - connect_time, - organization_id, - workspace_owner_id, - workspace_id, - workspace_name, - agent_name, - type, - code, - ip, - user_agent, - user_id, - slug_or_port, - connection_id, - disconnect_reason, - disconnect_time -) VALUES - ($1, @time, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, - -- If we've only received a disconnect event, mark the event as immediately - -- closed. - CASE - WHEN @connection_status::connection_status = 'disconnected' - THEN @time :: timestamp with time zone - ELSE NULL - END) + id, connect_time, organization_id, workspace_owner_id, workspace_id, + workspace_name, agent_name, type, code, ip, user_agent, user_id, + slug_or_port, connection_id, disconnect_reason, disconnect_time +) +SELECT + u.id, + u.connect_time, + u.organization_id, + u.workspace_owner_id, + u.workspace_id, + u.workspace_name, + u.agent_name, + u.type, + -- Use the validity flag to distinguish "no code" (NULL) from a + -- legitimate zero exit code. + CASE WHEN u.code_valid THEN u.code ELSE NULL END, + u.ip, + NULLIF(u.user_agent, ''), + NULLIF(u.user_id, '00000000-0000-0000-0000-000000000000'::uuid), + NULLIF(u.slug_or_port, ''), + NULLIF(u.connection_id, '00000000-0000-0000-0000-000000000000'::uuid), + NULLIF(u.disconnect_reason, ''), + NULLIF(u.disconnect_time, '0001-01-01 00:00:00Z'::timestamptz) +FROM ( + SELECT + unnest(sqlc.arg('id')::uuid[]) AS id, + unnest(sqlc.arg('connect_time')::timestamptz[]) AS connect_time, + unnest(sqlc.arg('organization_id')::uuid[]) AS organization_id, + unnest(sqlc.arg('workspace_owner_id')::uuid[]) AS workspace_owner_id, + unnest(sqlc.arg('workspace_id')::uuid[]) AS workspace_id, + unnest(sqlc.arg('workspace_name')::text[]) AS workspace_name, + unnest(sqlc.arg('agent_name')::text[]) AS agent_name, + unnest(sqlc.arg('type')::connection_type[]) AS type, + unnest(sqlc.arg('code')::int4[]) AS code, + unnest(sqlc.arg('code_valid')::bool[]) AS code_valid, + unnest(sqlc.arg('ip')::inet[]) AS ip, + unnest(sqlc.arg('user_agent')::text[]) AS user_agent, + unnest(sqlc.arg('user_id')::uuid[]) AS user_id, + unnest(sqlc.arg('slug_or_port')::text[]) AS slug_or_port, + unnest(sqlc.arg('connection_id')::uuid[]) AS connection_id, + unnest(sqlc.arg('disconnect_reason')::text[]) AS disconnect_reason, + unnest(sqlc.arg('disconnect_time')::timestamptz[]) AS disconnect_time +) AS u ON CONFLICT (connection_id, workspace_id, agent_name) DO UPDATE SET - -- No-op if the connection is still open. - disconnect_time = CASE - WHEN @connection_status::connection_status = 'disconnected' - -- Can only be set once - AND connection_logs.disconnect_time IS NULL - THEN EXCLUDED.connect_time - ELSE connection_logs.disconnect_time - END, - disconnect_reason = CASE - WHEN @connection_status::connection_status = 'disconnected' - -- Can only be set once - AND connection_logs.disconnect_reason IS NULL - THEN EXCLUDED.disconnect_reason - ELSE connection_logs.disconnect_reason - END, - code = CASE - WHEN @connection_status::connection_status = 'disconnected' - -- Can only be set once - AND connection_logs.code IS NULL - THEN EXCLUDED.code - ELSE connection_logs.code - END -RETURNING *; + -- Pick the earliest real connect_time. The zero sentinel + -- ('0001-01-01') means the batch didn't know the connect_time + -- (e.g. a pure disconnect event), so we keep the existing value. + connect_time = CASE + WHEN EXCLUDED.connect_time = '0001-01-01 00:00:00Z'::timestamptz + THEN connection_logs.connect_time + WHEN connection_logs.connect_time = '0001-01-01 00:00:00Z'::timestamptz + THEN EXCLUDED.connect_time + ELSE LEAST(connection_logs.connect_time, EXCLUDED.connect_time) + END, + disconnect_time = CASE + WHEN connection_logs.disconnect_time IS NULL + THEN EXCLUDED.disconnect_time + ELSE connection_logs.disconnect_time + END, + disconnect_reason = CASE + WHEN connection_logs.disconnect_reason IS NULL + THEN EXCLUDED.disconnect_reason + ELSE connection_logs.disconnect_reason + END, + code = CASE + WHEN connection_logs.code IS NULL + THEN EXCLUDED.code + ELSE connection_logs.code + END; diff --git a/coderd/workspaceapps/db.go b/coderd/workspaceapps/db.go index 08b3cd8426..02a4859cca 100644 --- a/coderd/workspaceapps/db.go +++ b/coderd/workspaceapps/db.go @@ -535,7 +535,7 @@ func (p *DBTokenProvider) connLogInitRequest(w http.ResponseWriter, r *http.Requ Int32: statusCode, Valid: true, }, - Ip: database.ParseIP(ip), + IP: database.ParseIP(ip), UserAgent: sql.NullString{Valid: userAgent != "", String: userAgent}, UserID: uuid.NullUUID{ UUID: userID, diff --git a/coderd/workspaceapps/db_test.go b/coderd/workspaceapps/db_test.go index 5d5370661f..d59160f1b5 100644 --- a/coderd/workspaceapps/db_test.go +++ b/coderd/workspaceapps/db_test.go @@ -1281,7 +1281,7 @@ func assertConnLogContains(t *testing.T, rr *httptest.ResponseRecorder, r *http. WorkspaceName: workspace.Name, AgentName: agentName, Type: typ, - Ip: database.ParseIP(r.RemoteAddr), + IP: database.ParseIP(r.RemoteAddr), UserAgent: sql.NullString{Valid: r.UserAgent() != "", String: r.UserAgent()}, Code: sql.NullInt32{ Int32: int32(resp.StatusCode), // nolint:gosec diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 070d496dba..f2bb7dbdfc 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -5,6 +5,7 @@ import ( "crypto/ed25519" "crypto/tls" "fmt" + "io" "math" "net/http" "net/url" @@ -144,10 +145,11 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } if options.ConnectionLogger == nil { - options.ConnectionLogger = connectionlog.NewConnectionLogger( - connectionlog.NewDBBackend(options.Database), + connLogger := connectionlog.New( + connectionlog.NewDBBatcher(ctx, options.Database, options.Logger), connectionlog.NewSlogBackend(options.Logger), ) + options.ConnectionLogger = connLogger } meshTLSConfig, err := replicasync.CreateDERPMeshTLSConfig(options.AccessURL.Hostname(), options.TLSCertificates) @@ -822,6 +824,12 @@ func (api *API) Close() error { api.Options.CheckInactiveUsersCancelFunc() } + // Close the connection logger to flush any remaining batched + // entries before shutting down the database connection. + if cl, ok := api.Options.ConnectionLogger.(io.Closer); ok { + _ = cl.Close() + } + return api.AGPL.Close() } diff --git a/enterprise/coderd/connectionlog/connectionlog.go b/enterprise/coderd/connectionlog/connectionlog.go index 4b24ba402c..9cd36e6505 100644 --- a/enterprise/coderd/connectionlog/connectionlog.go +++ b/enterprise/coderd/connectionlog/connectionlog.go @@ -2,31 +2,70 @@ package connectionlog import ( "context" + "io" + "sync" + "time" + "github.com/google/uuid" "github.com/hashicorp/go-multierror" + "github.com/sqlc-dev/pqtype" "cdr.dev/slog/v3" - agpl "github.com/coder/coder/v2/coderd/connectionlog" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" auditbackends "github.com/coder/coder/v2/enterprise/audit/backends" + "github.com/coder/quartz" ) +const ( + // defaultBatchSize is the maximum number of connection log entries + // to batch before forcing a flush. + defaultBatchSize = 1000 + + // defaultFlushInterval is how frequently to flush batched connection + // log entries to the database. Five seconds balances near-real-time + // audit visibility with write efficiency. + defaultFlushInterval = 5 * time.Second + + // retryQueueSize is the capacity of the bounded retry channel. + // Failed batches beyond this limit are dropped. + retryQueueSize = 10 + + // shutdownWriteTimeout bounds how long a final write attempt + // can take during shutdown when the batcher context is already + // canceled. + shutdownWriteTimeout = 10 * time.Second + + // maxRetries is the number of times to retry a failed batch + // write before dropping it and moving on. + maxRetries = 3 + + // retryInterval is the fixed delay between retry attempts. + retryInterval = time.Second +) + +// Backend is a destination for connection log events. Backends that +// also implement io.Closer will be closed when the ConnectionLogger +// is closed. type Backend interface { Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error } -func NewConnectionLogger(backends ...Backend) agpl.ConnectionLogger { - return &connectionLogger{ +// ConnectionLogger fans out each connection log event to every +// registered backend. +type ConnectionLogger struct { + backends []Backend +} + +// New creates a ConnectionLogger that dispatches to the given +// backends. +func New(backends ...Backend) *ConnectionLogger { + return &ConnectionLogger{ backends: backends, } } -type connectionLogger struct { - backends []Backend -} - -func (c *connectionLogger) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error { +func (c *ConnectionLogger) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error { var errs error for _, backend := range c.backends { err := backend.Upsert(ctx, clog) @@ -37,24 +76,443 @@ func (c *connectionLogger) Upsert(ctx context.Context, clog database.UpsertConne return errs } -type dbBackend struct { - db database.Store +// Close closes all backends that implement io.Closer. +func (c *ConnectionLogger) Close() error { + var errs error + for _, backend := range c.backends { + if closer, ok := backend.(io.Closer); ok { + if err := closer.Close(); err != nil { + errs = multierror.Append(errs, err) + } + } + } + return errs } -func NewDBBackend(db database.Store) Backend { - return &dbBackend{db: db} +// DBBatcherOption is a functional option for configuring a DBBatcher. +type DBBatcherOption func(b *DBBatcher) + +// WithBatchSize sets the maximum number of entries to accumulate +// before forcing a flush. +func WithBatchSize(size int) DBBatcherOption { + return func(b *DBBatcher) { + b.maxBatchSize = size + } } -func (b *dbBackend) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error { - //nolint:gocritic // This is the Connection Logger - _, err := b.db.UpsertConnectionLog(dbauthz.AsConnectionLogger(ctx), clog) - return err +// WithFlushInterval sets how frequently the batcher flushes to the +// database. +func WithFlushInterval(d time.Duration) DBBatcherOption { + return func(b *DBBatcher) { + b.interval = d + } +} + +// WithClock sets the clock, useful for testing. +func WithClock(clock quartz.Clock) DBBatcherOption { + return func(b *DBBatcher) { + b.clock = clock + } +} + +// DBBatcher batches connection log upserts and periodically flushes +// them to the database to reduce per-event write pressure. +type DBBatcher struct { + store database.Store + log slog.Logger + + itemCh chan database.UpsertConnectionLogParams + + // dedupedBatch holds entries keyed by connection ID so that + // PostgreSQL never sees the same row twice in one INSERT … + // ON CONFLICT DO UPDATE. Connection IDs are globally unique + // (each new session gets a fresh UUID). Entries with a NULL + // connection_id (web events) go into nullConnIDBatch instead + // because NULL != NULL in SQL unique constraints. + dedupedBatch map[uuid.UUID]batchEntry + nullConnIDBatch []batchEntry + maxBatchSize int + + // retryCh is a bounded channel of failed batches awaiting + // retry. A single retry worker goroutine processes this + // channel, retrying each batch up to maxRetries times before + // dropping it. If the channel is full, new failures are + // dropped immediately. + retryCh chan database.BatchUpsertConnectionLogsParams + + clock quartz.Clock + timer *quartz.Timer + interval time.Duration + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +// NewDBBatcher creates a DBBatcher that batches writes to the database +// and starts its background processing loop. Close must be called to +// flush remaining entries on shutdown. +func NewDBBatcher(ctx context.Context, store database.Store, log slog.Logger, opts ...DBBatcherOption) *DBBatcher { + b := &DBBatcher{ + store: store, + log: log, + clock: quartz.NewReal(), + } + + for _, opt := range opts { + opt(b) + } + + if b.interval == 0 { + b.interval = defaultFlushInterval + } + if b.maxBatchSize == 0 { + b.maxBatchSize = defaultBatchSize + } + + b.timer = b.clock.NewTimer(b.interval) + b.itemCh = make(chan database.UpsertConnectionLogParams, b.maxBatchSize) + b.dedupedBatch = make(map[uuid.UUID]batchEntry, b.maxBatchSize) + b.retryCh = make(chan database.BatchUpsertConnectionLogsParams, retryQueueSize) + + b.ctx, b.cancel = context.WithCancel(ctx) + b.wg.Add(2) + go func() { + defer b.wg.Done() + b.run(b.ctx) + }() + go func() { + defer b.wg.Done() + b.retryLoop() + }() + + return b +} + +// Upsert enqueues a connection log entry for batched writing. It +// blocks if the internal buffer is full, ensuring no logs are dropped. +// It returns an error if the batcher or caller context is canceled. +func (b *DBBatcher) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error { + if b.ctx.Err() != nil { + return b.ctx.Err() + } + + select { + case b.itemCh <- clog: + return nil + case <-b.ctx.Done(): + return b.ctx.Err() + case <-ctx.Done(): + return ctx.Err() + } +} + +// Close cancels the batcher context, waits for the run loop and +// retry worker to exit. +func (b *DBBatcher) Close() error { + b.cancel() + if b.timer != nil { + b.timer.Stop() + } + b.wg.Wait() + return nil +} + +// addToBatch inserts an item into the batch, deduplicating by conflict +// key on the fly. For entries with the same key, disconnect events are +// preferred over connect events, and later events are preferred over +// earlier ones. +// +// This is safe because each new connection gets a fresh UUID (see +// agent/agent.go and agent/agentssh), so the only duplicate for the +// same (connection_id, workspace_id, agent_name) is a connect/disconnect +// pair for the same session. A "reconnect" always uses a new ID. +func (b *DBBatcher) addToBatch(item database.UpsertConnectionLogParams) { + entry := batchEntry{ + UpsertConnectionLogParams: item, + } + if item.ConnectionStatus == database.ConnectionStatusDisconnected { + // For standalone disconnect events, use the disconnect + // time as both connect and disconnect time. This matches + // the single-row UpsertConnectionLog behavior which uses + // @time for connect_time regardless of status. The SQL + // LEAST logic will correct connect_time if the real + // connect event arrives in a later batch. + entry.connectTime = item.Time + entry.disconnectTime = item.Time + } else { + entry.connectTime = item.Time + } + + if !item.ConnectionID.Valid { + b.nullConnIDBatch = append(b.nullConnIDBatch, entry) + return + } + connID := item.ConnectionID.UUID + existing, ok := b.dedupedBatch[connID] + if !ok { + b.dedupedBatch[connID] = entry + return + } + // When merging entries for the same connection, always preserve + // the earliest non-zero connect_time and latest disconnect_time + // so the row records the full session span. + if !existing.connectTime.IsZero() && existing.connectTime.Before(entry.connectTime) { + entry.connectTime = existing.connectTime + } + if existing.disconnectTime.After(entry.disconnectTime) { + entry.disconnectTime = existing.disconnectTime + } + + // Prefer disconnect over connect (superset of info). + // If same status, prefer the later event. + if item.ConnectionStatus == database.ConnectionStatusDisconnected && + existing.ConnectionStatus != database.ConnectionStatusDisconnected { + b.dedupedBatch[connID] = entry + } else if item.Time.After(existing.Time) { + b.dedupedBatch[connID] = entry + } +} + +// batchLen returns the total number of entries currently buffered. +func (b *DBBatcher) batchLen() int { + return len(b.dedupedBatch) + len(b.nullConnIDBatch) +} + +func (b *DBBatcher) run(ctx context.Context) { + //nolint:gocritic // System-level batch operation for connection logs. + authCtx := dbauthz.AsConnectionLogger(ctx) + for ctx.Err() == nil { + select { + case item := <-b.itemCh: + b.addToBatch(item) + + if b.batchLen() >= b.maxBatchSize { + b.flush(authCtx) + b.timer.Reset(b.interval, "connectionLogBatcher", "capacityFlush") + } + + case <-b.timer.C: + b.flush(authCtx) + b.timer.Reset(b.interval, "connectionLogBatcher", "scheduledFlush") + + case <-ctx.Done(): + } + } + + b.log.Debug(ctx, "context done, flushing before exit") + + // Drain any remaining items from the channel. + for { + select { + case item := <-b.itemCh: + b.addToBatch(item) + default: + if b.batchLen() > 0 { + b.shutdownBatch(b.buildParams()) + } + // Signal the retry worker to skip delays and close + // the channel so it exits after processing any + // remaining items. + // Mark the batcher as closed so that any subsequent + // Upsert calls fail immediately instead of sending + // into itemCh after the run loop has exited. + close(b.retryCh) + return + } + } +} + +// batchEntry wraps a connection log event with explicit connect and +// disconnect times. When a connect and disconnect for the same session +// are merged into one entry, connectTime preserves the original +// session start while disconnectTime records when it ended. +type batchEntry struct { + database.UpsertConnectionLogParams + connectTime time.Time + disconnectTime time.Time +} + +// flush builds the batch params, clears the in-memory batch, and +// writes to the database. On failure, the batch is queued for retry +// by the single retry worker goroutine. If the retry queue is full, +// the batch is dropped. +func (b *DBBatcher) flush(ctx context.Context) { + count := b.batchLen() + if count == 0 { + return + } + + params := b.buildParams() + + // Clear the batch before writing so the run loop can start + // accumulating new entries. + b.dedupedBatch = make(map[uuid.UUID]batchEntry, b.maxBatchSize) + b.nullConnIDBatch = nil + + // Use the batcher's context for normal operation so Close() + // can cancel hung writes. During shutdown (ctx already canceled), + // fall back to a bounded timeout. + writeCtx := b.ctx + if writeCtx.Err() != nil { + var cancel context.CancelFunc + writeCtx, cancel = context.WithTimeout(context.Background(), shutdownWriteTimeout) + defer cancel() + } + //nolint:gocritic // System-level batch operation for connection logs. + err := b.store.BatchUpsertConnectionLogs(dbauthz.AsConnectionLogger(writeCtx), params) + if err == nil { + return + } + + b.log.Error(ctx, "batch upsert failed, queueing for retry", + slog.Error(err), slog.F("count", count)) + + // Don't retry on shutdown. + if ctx.Err() != nil { + return + } + + select { + case b.retryCh <- params: + default: + b.log.Error(ctx, "retry queue full, dropping batch", + slog.F("dropped", count)) + } +} + +func (b *DBBatcher) buildParams() database.BatchUpsertConnectionLogsParams { + count := b.batchLen() + var ( + ids = make([]uuid.UUID, 0, count) + connectTime = make([]time.Time, 0, count) + organizationID = make([]uuid.UUID, 0, count) + workspaceOwnerID = make([]uuid.UUID, 0, count) + workspaceID = make([]uuid.UUID, 0, count) + workspaceName = make([]string, 0, count) + agentName = make([]string, 0, count) + connType = make([]database.ConnectionType, 0, count) + code = make([]int32, 0, count) + codeValid = make([]bool, 0, count) + ip = make([]pqtype.Inet, 0, count) + userAgent = make([]string, 0, count) + userID = make([]uuid.UUID, 0, count) + slugOrPort = make([]string, 0, count) + connectionID = make([]uuid.UUID, 0, count) + disconnectReason = make([]string, 0, count) + disconnectTime = make([]time.Time, 0, count) + ) + + appendEntry := func(e batchEntry) { + ids = append(ids, e.ID) + connectTime = append(connectTime, e.connectTime) + organizationID = append(organizationID, e.OrganizationID) + workspaceOwnerID = append(workspaceOwnerID, e.WorkspaceOwnerID) + workspaceID = append(workspaceID, e.WorkspaceID) + workspaceName = append(workspaceName, e.WorkspaceName) + agentName = append(agentName, e.AgentName) + connType = append(connType, e.Type) + code = append(code, e.Code.Int32) + codeValid = append(codeValid, e.Code.Valid) + ip = append(ip, e.IP) + userAgent = append(userAgent, e.UserAgent.String) + userID = append(userID, e.UserID.UUID) + slugOrPort = append(slugOrPort, e.SlugOrPort.String) + connectionID = append(connectionID, e.ConnectionID.UUID) + disconnectReason = append(disconnectReason, e.DisconnectReason.String) + disconnectTime = append(disconnectTime, e.disconnectTime) + } + + for _, entry := range b.dedupedBatch { + appendEntry(entry) + } + for _, entry := range b.nullConnIDBatch { + appendEntry(entry) + } + + return database.BatchUpsertConnectionLogsParams{ + ID: ids, + ConnectTime: connectTime, + OrganizationID: organizationID, + WorkspaceOwnerID: workspaceOwnerID, + WorkspaceID: workspaceID, + WorkspaceName: workspaceName, + AgentName: agentName, + Type: connType, + Code: code, + CodeValid: codeValid, + Ip: ip, + UserAgent: userAgent, + UserID: userID, + SlugOrPort: slugOrPort, + ConnectionID: connectionID, + DisconnectReason: disconnectReason, + DisconnectTime: disconnectTime, + } +} + +// retryLoop is a single background goroutine that processes failed +// batches from retryCh. Each batch is retried up to maxRetries times +// with a fixed delay between attempts. When draining is set (shutdown), +// batches get a single immediate write attempt instead. The loop exits +// when retryCh is closed by the run goroutine. +func (b *DBBatcher) retryLoop() { + for params := range b.retryCh { + b.retryBatch(params) + } +} + +// retryBatch retries writing a batch up to maxRetries times with a +// fixed delay between attempts. If the batcher context is canceled +// during a wait, one final attempt is made before returning. +func (b *DBBatcher) retryBatch(params database.BatchUpsertConnectionLogsParams) { + count := len(params.ID) + for attempt := range maxRetries { + t := time.NewTimer(retryInterval) + select { + case <-b.ctx.Done(): + b.shutdownBatch(params) + return + case <-t.C: + } + + //nolint:gocritic // System-level batch operation for connection logs. + err := b.store.BatchUpsertConnectionLogs(dbauthz.AsConnectionLogger(b.ctx), params) + if err == nil { + return + } + + b.log.Warn(b.ctx, "batch retry failed", + slog.Error(err), + slog.F("count", count), + slog.F("attempt", attempt+1), + slog.F("max_attempts", maxRetries), + ) + } + + b.log.Error(b.ctx, "batch retries exhausted, dropping batch", + slog.F("dropped", count)) +} + +// shutdownBatch makes a single write attempt during shutdown with a +// bounded timeout so it can't hang indefinitely. +func (b *DBBatcher) shutdownBatch(params database.BatchUpsertConnectionLogsParams) { + ctx, cancel := context.WithTimeout(context.Background(), shutdownWriteTimeout) + defer cancel() + //nolint:gocritic // System-level batch operation for connection logs. + err := b.store.BatchUpsertConnectionLogs(dbauthz.AsConnectionLogger(ctx), params) + if err != nil { + b.log.Error(b.ctx, "batch write failed on shutdown, dropping batch", + slog.Error(err), slog.F("dropped", len(params.ID))) + } } type connectionSlogBackend struct { exporter *auditbackends.SlogExporter } +// NewSlogBackend returns a Backend that logs connection events via +// the structured logger. func NewSlogBackend(logger slog.Logger) Backend { return &connectionSlogBackend{ exporter: auditbackends.NewSlogExporter(logger), diff --git a/enterprise/coderd/connectionlog/connectionlog_internal_test.go b/enterprise/coderd/connectionlog/connectionlog_internal_test.go new file mode 100644 index 0000000000..5804e5ae3f --- /dev/null +++ b/enterprise/coderd/connectionlog/connectionlog_internal_test.go @@ -0,0 +1,529 @@ +package connectionlog + +import ( + "context" + "database/sql" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func Test_addToBatch(t *testing.T) { + t.Parallel() + + t.Run("ConnectThenDisconnect", func(t *testing.T) { + t.Parallel() + + b := &DBBatcher{ + maxBatchSize: 100, + dedupedBatch: make(map[uuid.UUID]batchEntry), + } + + wsID := uuid.New() + connID := uuid.New() + + connect := fakeConnectEvent(wsID, "agent1", connID) + disconnect := fakeDisconnectEvent(wsID, "agent1", connID) + + b.addToBatch(connect) + b.addToBatch(disconnect) + + require.Equal(t, 1, b.batchLen()) + key := connID + got := b.dedupedBatch[key] + require.Equal(t, disconnect.ID, got.ID) + require.Equal(t, database.ConnectionStatusDisconnected, got.ConnectionStatus) + // The connect_time should be preserved from the original + // connect event, not overwritten by the disconnect's + // timestamp. + require.Equal(t, connect.Time, got.connectTime) + require.Equal(t, disconnect.Time, got.disconnectTime) + }) + + t.Run("DisconnectThenLaterConnect", func(t *testing.T) { + t.Parallel() + + b := &DBBatcher{ + maxBatchSize: 100, + dedupedBatch: make(map[uuid.UUID]batchEntry), + } + + wsID := uuid.New() + connID := uuid.New() + + disconnect := fakeDisconnectEvent(wsID, "agent1", connID) + connect := fakeConnectEvent(wsID, "agent1", connID) + connect.Time = disconnect.Time.Add(time.Second) + + b.addToBatch(disconnect) + b.addToBatch(connect) + + require.Equal(t, 1, b.batchLen()) + key := connID + // The later event wins when the incoming item is not a + // disconnect. In practice, this case doesn't occur because + // connection IDs are never reused. + got := b.dedupedBatch[key] + require.Equal(t, connect.ID, got.ID) + // The disconnect's time should be preserved even though + // the connect event replaced it. + require.Equal(t, disconnect.Time, got.disconnectTime) + }) + + t.Run("DisconnectThenEarlierConnect", func(t *testing.T) { + t.Parallel() + + b := &DBBatcher{ + maxBatchSize: 100, + dedupedBatch: make(map[uuid.UUID]batchEntry), + } + + wsID := uuid.New() + connID := uuid.New() + + disconnect := fakeDisconnectEvent(wsID, "agent1", connID) + connect := fakeConnectEvent(wsID, "agent1", connID) + connect.Time = disconnect.Time.Add(-time.Second) + + b.addToBatch(disconnect) + b.addToBatch(connect) + + require.Equal(t, 1, b.batchLen()) + key := connID + require.Equal(t, disconnect.ID, b.dedupedBatch[key].ID) + }) + + t.Run("SameStatusKeepsLater", func(t *testing.T) { + t.Parallel() + + b := &DBBatcher{ + maxBatchSize: 100, + dedupedBatch: make(map[uuid.UUID]batchEntry), + } + + wsID := uuid.New() + connID := uuid.New() + + early := fakeConnectEvent(wsID, "agent1", connID) + early.Time = time.Now() + late := fakeConnectEvent(wsID, "agent1", connID) + late.Time = early.Time.Add(time.Second) + + b.addToBatch(early) + b.addToBatch(late) + + require.Equal(t, 1, b.batchLen()) + key := connID + require.Equal(t, late.ID, b.dedupedBatch[key].ID) + }) + + t.Run("NullConnIDsNeverDedup", func(t *testing.T) { + t.Parallel() + + b := &DBBatcher{ + maxBatchSize: 100, + dedupedBatch: make(map[uuid.UUID]batchEntry), + } + + evt1 := fakeNullConnIDEvent() + evt2 := fakeNullConnIDEvent() + evt2.WorkspaceID = evt1.WorkspaceID + evt2.AgentName = evt1.AgentName + + b.addToBatch(evt1) + b.addToBatch(evt2) + + require.Equal(t, 2, b.batchLen()) + require.Len(t, b.nullConnIDBatch, 2) + require.Empty(t, b.dedupedBatch) + }) + + t.Run("MixedNullAndNonNull", func(t *testing.T) { + t.Parallel() + + b := &DBBatcher{ + maxBatchSize: 100, + dedupedBatch: make(map[uuid.UUID]batchEntry), + } + + wsID := uuid.New() + regular := fakeConnectEvent(wsID, "agent1", uuid.New()) + nullEvt := fakeNullConnIDEvent() + nullEvt.WorkspaceID = wsID + nullEvt.AgentName = "agent1" + + b.addToBatch(regular) + b.addToBatch(nullEvt) + + require.Equal(t, 2, b.batchLen()) + require.Len(t, b.dedupedBatch, 1) + require.Len(t, b.nullConnIDBatch, 1) + }) + + t.Run("StandaloneDisconnectUsesTimeAsConnectTime", func(t *testing.T) { + t.Parallel() + + b := &DBBatcher{ + maxBatchSize: 100, + dedupedBatch: make(map[uuid.UUID]batchEntry), + } + + connID := uuid.New() + disconnect := fakeDisconnectEvent(uuid.New(), "agent1", connID) + + b.addToBatch(disconnect) + + got := b.dedupedBatch[connID] + // A standalone disconnect must not leave connectTime as + // zero — that would insert a year-0001 connect_time in + // the DB. It should use the disconnect's own timestamp, + // matching the single-row UpsertConnectionLog behavior. + require.False(t, got.connectTime.IsZero(), + "standalone disconnect must have non-zero connectTime") + require.Equal(t, disconnect.Time, got.connectTime) + require.Equal(t, disconnect.Time, got.disconnectTime) + }) + + t.Run("DuplicateDisconnectsPreserveConnectTime", func(t *testing.T) { + t.Parallel() + + b := &DBBatcher{ + maxBatchSize: 100, + dedupedBatch: make(map[uuid.UUID]batchEntry), + } + + wsID := uuid.New() + connID := uuid.New() + + connect := fakeConnectEvent(wsID, "agent1", connID) + disconnect1 := fakeDisconnectEvent(wsID, "agent1", connID) + disconnect2 := fakeDisconnectEvent(wsID, "agent1", connID) + disconnect2.Time = disconnect1.Time.Add(time.Second) + + b.addToBatch(connect) + b.addToBatch(disconnect1) + b.addToBatch(disconnect2) + + require.Equal(t, 1, b.batchLen()) + got := b.dedupedBatch[connID] + // The second disconnect should win (later event) but the + // original connect_time from the connect event must be + // preserved, not regressed to the disconnect's timestamp. + require.Equal(t, disconnect2.ID, got.ID) + require.Equal(t, connect.Time, got.connectTime, + "connect_time must not regress to disconnect timestamp") + require.Equal(t, disconnect2.Time, got.disconnectTime) + }) +} + +func Test_batcherFlush(t *testing.T) { + t.Parallel() + + t.Run("DeduplicatesConnectDisconnect", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + clock := quartz.NewMock(t) + + b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100)) + + wsID := uuid.New() + connID := uuid.New() + connect := fakeConnectEvent(wsID, "agent1", connID) + disconnect := fakeDisconnectEvent(wsID, "agent1", connID) + + // Expect a single batch with only the disconnect event. + store.EXPECT(). + BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{ + expectedCount: 1, + mustContainIDs: []uuid.UUID{disconnect.ID}, + mustNotContainIDs: []uuid.UUID{connect.ID}, + }). + Return(nil). + Times(1) + + require.NoError(t, b.Upsert(ctx, connect)) + require.NoError(t, b.Upsert(ctx, disconnect)) + require.NoError(t, b.Close()) + }) + + t.Run("DoesNotDeduplicateNullConnIDs", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + clock := quartz.NewMock(t) + + b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100)) + + evt1 := fakeNullConnIDEvent() + evt2 := fakeNullConnIDEvent() + evt2.WorkspaceID = evt1.WorkspaceID + evt2.AgentName = evt1.AgentName + + store.EXPECT(). + BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{ + expectedCount: 2, + mustContainIDs: []uuid.UUID{evt1.ID, evt2.ID}, + }). + Return(nil). + Times(1) + + require.NoError(t, b.Upsert(ctx, evt1)) + require.NoError(t, b.Upsert(ctx, evt2)) + require.NoError(t, b.Close()) + }) + + t.Run("DoesNotDeduplicateDifferentConnectionIDs", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + clock := quartz.NewMock(t) + + b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100)) + + wsID := uuid.New() + evt1 := fakeConnectEvent(wsID, "agent1", uuid.New()) + evt2 := fakeConnectEvent(wsID, "agent1", uuid.New()) + + store.EXPECT(). + BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{ + expectedCount: 2, + mustContainIDs: []uuid.UUID{evt1.ID, evt2.ID}, + }). + Return(nil). + Times(1) + + require.NoError(t, b.Upsert(ctx, evt1)) + require.NoError(t, b.Upsert(ctx, evt2)) + require.NoError(t, b.Close()) + }) + + t.Run("CloseFlushesMultipleEvents", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + clock := quartz.NewMock(t) + + b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100)) + + evt1 := fakeConnectEvent(uuid.New(), "agent1", uuid.New()) + evt2 := fakeConnectEvent(uuid.New(), "agent2", uuid.New()) + + store.EXPECT(). + BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{ + expectedCount: 2, + mustContainIDs: []uuid.UUID{evt1.ID, evt2.ID}, + }). + Return(nil). + Times(1) + + require.NoError(t, b.Upsert(ctx, evt1)) + require.NoError(t, b.Upsert(ctx, evt2)) + require.NoError(t, b.Close()) + }) + + t.Run("RetriesOnTransientFailure", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + clock := quartz.NewMock(t) + + scheduledTrap := clock.Trap().TimerReset("connectionLogBatcher", "scheduledFlush") + defer scheduledTrap.Close() + + b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100)) + + evt := fakeConnectEvent(uuid.New(), "agent1", uuid.New()) + + // First call (synchronous in flush) fails, then the + // retry worker retries after the backoff and succeeds. + gomock.InOrder( + store.EXPECT(). + BatchUpsertConnectionLogs(gomock.Any(), gomock.Any()). + Return(xerrors.New("transient error")). + Times(1), + store.EXPECT(). + BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{ + expectedCount: 1, + mustContainIDs: []uuid.UUID{evt.ID}, + }). + Return(nil). + Times(1), + ) + + require.NoError(t, b.Upsert(ctx, evt)) + + // Trigger a scheduled flush while the batcher is still + // running. The synchronous write fails and queues to + // retryCh. The retry worker picks it up after a real- + // time 1s delay and succeeds. + clock.Advance(defaultFlushInterval).MustWait(ctx) + scheduledTrap.MustWait(ctx).MustRelease(ctx) + + // Wait for the retry to complete (real-time 1s delay). + require.NoError(t, b.Close()) + }) + + t.Run("ShutdownDrainsRetryQueue", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + clock := quartz.NewMock(t) + + scheduledTrap := clock.Trap().TimerReset("connectionLogBatcher", "scheduledFlush") + defer scheduledTrap.Close() + + b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100)) + + evt := fakeConnectEvent(uuid.New(), "agent1", uuid.New()) + + // Track all successfully written IDs. + var writtenIDs []uuid.UUID + var mu sync.Mutex + firstCall := true + store.EXPECT(). + BatchUpsertConnectionLogs(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, p database.BatchUpsertConnectionLogsParams) error { + mu.Lock() + defer mu.Unlock() + // First call (synchronous flush) fails, queueing + // the batch for retry. + if firstCall { + firstCall = false + return xerrors.New("transient error") + } + // Drain/retry attempts succeed. + writtenIDs = append(writtenIDs, p.ID...) + return nil + }). + AnyTimes() + + // Send event and trigger flush — fails, queues. + require.NoError(t, b.Upsert(ctx, evt)) + clock.Advance(defaultFlushInterval).MustWait(ctx) + scheduledTrap.MustWait(ctx).MustRelease(ctx) + + // Close triggers shutdown. The retry worker drains + // retryCh and writes the batch via writeBatch. + require.NoError(t, b.Close()) + + mu.Lock() + defer mu.Unlock() + require.Contains(t, writtenIDs, evt.ID, + "event should be written during shutdown drain") + }) +} + +// batchParamsMatcher validates BatchUpsertConnectionLogsParams by +// checking count and specific IDs. +type batchParamsMatcher struct { + expectedCount int + mustContainIDs []uuid.UUID + mustNotContainIDs []uuid.UUID +} + +func (m batchParamsMatcher) Matches(x interface{}) bool { + params, ok := x.(database.BatchUpsertConnectionLogsParams) + if !ok { + return false + } + if m.expectedCount > 0 && len(params.ID) != m.expectedCount { + return false + } + idSet := make(map[uuid.UUID]struct{}, len(params.ID)) + for _, id := range params.ID { + idSet[id] = struct{}{} + } + for _, id := range m.mustContainIDs { + if _, ok := idSet[id]; !ok { + return false + } + } + for _, id := range m.mustNotContainIDs { + if _, ok := idSet[id]; ok { + return false + } + } + return true +} + +func (batchParamsMatcher) String() string { + return "batch upsert params matcher" +} + +func fakeConnectEvent(workspaceID uuid.UUID, agentName string, connectionID uuid.UUID) database.UpsertConnectionLogParams { + return database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: time.Now(), + OrganizationID: uuid.New(), + WorkspaceOwnerID: uuid.New(), + WorkspaceID: workspaceID, + WorkspaceName: "test-workspace", + AgentName: agentName, + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, + ConnectionStatus: database.ConnectionStatusConnected, + } +} + +func fakeDisconnectEvent(workspaceID uuid.UUID, agentName string, connectionID uuid.UUID) database.UpsertConnectionLogParams { + return database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: time.Now().Add(time.Second), + OrganizationID: uuid.New(), + WorkspaceOwnerID: uuid.New(), + WorkspaceID: workspaceID, + WorkspaceName: "test-workspace", + AgentName: agentName, + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true}, + ConnectionStatus: database.ConnectionStatusDisconnected, + Code: sql.NullInt32{Int32: 0, Valid: true}, + DisconnectReason: sql.NullString{String: "normal", Valid: true}, + } +} + +func fakeNullConnIDEvent() database.UpsertConnectionLogParams { + return database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: time.Now(), + OrganizationID: uuid.New(), + WorkspaceOwnerID: uuid.New(), + WorkspaceID: uuid.New(), + WorkspaceName: "test-workspace", + AgentName: "test-agent", + Type: database.ConnectionTypeWorkspaceApp, + ConnectionID: uuid.NullUUID{}, + ConnectionStatus: database.ConnectionStatusConnected, + } +} diff --git a/enterprise/coderd/connectionlog/connectionlog_test.go b/enterprise/coderd/connectionlog/connectionlog_test.go new file mode 100644 index 0000000000..416bec7885 --- /dev/null +++ b/enterprise/coderd/connectionlog/connectionlog_test.go @@ -0,0 +1,371 @@ +package connectionlog_test + +import ( + "database/sql" + "net" + "testing" + "time" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/enterprise/coderd/connectionlog" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func createWorkspace(t *testing.T, db database.Store) database.WorkspaceTable { + t.Helper() + u := dbgen.User(t, db, database.User{}) + o := dbgen.Organization(t, db, database.Organization{}) + tpl := dbgen.Template(t, db, database.Template{ + OrganizationID: o.ID, + CreatedBy: u.ID, + }) + return dbgen.Workspace(t, db, database.WorkspaceTable{ + ID: uuid.New(), + OwnerID: u.ID, + OrganizationID: o.ID, + AutomaticUpdates: database.AutomaticUpdatesNever, + TemplateID: tpl.ID, + }) +} + +func testIP() pqtype.Inet { + return pqtype.Inet{ + IPNet: net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + Valid: true, + } +} + +func TestDBBackendIntegration(t *testing.T) { + t.Parallel() + + t.Run("SingleConnect", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + clock := quartz.NewMock(t) + + ws := createWorkspace(t, db) + + //nolint:gocritic // Test needs system context for the batcher. + backend := connectionlog.NewDBBatcher( + dbauthz.AsConnectionLogger(ctx), db, log, + connectionlog.WithClock(clock), + connectionlog.WithBatchSize(100), + ) + + connID := uuid.New() + connectTime := dbtime.Now() + err := backend.Upsert(ctx, database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: connectTime, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: "main", + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connID, Valid: true}, + ConnectionStatus: database.ConnectionStatusConnected, + IP: testIP(), + }) + require.NoError(t, err) + + err = backend.Close() + require.NoError(t, err) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{ + LimitOpt: 10, + }) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, connID, rows[0].ConnectionLog.ConnectionID.UUID) + require.False(t, rows[0].ConnectionLog.DisconnectTime.Valid) + }) + + t.Run("ConnectThenDisconnectSeparateBatches", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + clock := quartz.NewMock(t) + + ws := createWorkspace(t, db) + + connID := uuid.New() + connectTime := dbtime.Now() + + // First batcher: insert connect, close to flush. + //nolint:gocritic // Test needs system context for the batcher. + b1 := connectionlog.NewDBBatcher( + dbauthz.AsConnectionLogger(ctx), db, log, + connectionlog.WithClock(clock), + connectionlog.WithBatchSize(100), + ) + err := b1.Upsert(ctx, database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: connectTime, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: "main", + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connID, Valid: true}, + ConnectionStatus: database.ConnectionStatusConnected, + IP: testIP(), + }) + require.NoError(t, err) + require.NoError(t, b1.Close()) + + // Second batcher: insert disconnect, close to flush. + //nolint:gocritic // Test needs system context for the batcher. + b2 := connectionlog.NewDBBatcher( + dbauthz.AsConnectionLogger(ctx), db, log, + connectionlog.WithClock(clock), + connectionlog.WithBatchSize(100), + ) + disconnectTime := connectTime.Add(5 * time.Second) + err = b2.Upsert(ctx, database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: disconnectTime, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: "main", + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connID, Valid: true}, + ConnectionStatus: database.ConnectionStatusDisconnected, + Code: sql.NullInt32{Int32: 0, Valid: true}, + DisconnectReason: sql.NullString{String: "client left", Valid: true}, + IP: testIP(), + }) + require.NoError(t, err) + require.NoError(t, b2.Close()) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{ + LimitOpt: 10, + }) + require.NoError(t, err) + require.Len(t, rows, 1, "connect+disconnect should produce one row") + require.True(t, rows[0].ConnectionLog.DisconnectTime.Valid) + require.Equal(t, "client left", rows[0].ConnectionLog.DisconnectReason.String) + }) + + t.Run("ConnectAndDisconnectSameBatch", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + clock := quartz.NewMock(t) + + ws := createWorkspace(t, db) + + //nolint:gocritic // Test needs system context for the batcher. + backend := connectionlog.NewDBBatcher( + dbauthz.AsConnectionLogger(ctx), db, log, + connectionlog.WithClock(clock), + connectionlog.WithBatchSize(100), + ) + + connID := uuid.New() + connectTime := dbtime.Now() + disconnectTime := connectTime.Add(time.Second) + + // Both events in the same batch window. + err := backend.Upsert(ctx, database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: connectTime, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: "main", + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connID, Valid: true}, + ConnectionStatus: database.ConnectionStatusConnected, + IP: testIP(), + }) + require.NoError(t, err) + + err = backend.Upsert(ctx, database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: disconnectTime, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: "main", + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: connID, Valid: true}, + ConnectionStatus: database.ConnectionStatusDisconnected, + Code: sql.NullInt32{Int32: 0, Valid: true}, + DisconnectReason: sql.NullString{String: "done", Valid: true}, + IP: testIP(), + }) + require.NoError(t, err) + + // Close drains channel and flushes — dedup keeps disconnect. + err = backend.Close() + require.NoError(t, err) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{ + LimitOpt: 10, + }) + require.NoError(t, err) + require.Len(t, rows, 1) + require.True(t, rows[0].ConnectionLog.DisconnectTime.Valid) + require.Equal(t, "done", rows[0].ConnectionLog.DisconnectReason.String) + }) + + t.Run("MultipleIndependentConnections", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + clock := quartz.NewMock(t) + + ws := createWorkspace(t, db) + + //nolint:gocritic // Test needs system context for the batcher. + backend := connectionlog.NewDBBatcher( + dbauthz.AsConnectionLogger(ctx), db, log, + connectionlog.WithClock(clock), + connectionlog.WithBatchSize(100), + ) + + now := dbtime.Now() + for i := 0; i < 5; i++ { + err := backend.Upsert(ctx, database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: now, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: "main", + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + ConnectionStatus: database.ConnectionStatusConnected, + IP: testIP(), + }) + require.NoError(t, err) + } + + err := backend.Close() + require.NoError(t, err) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{ + LimitOpt: 10, + }) + require.NoError(t, err) + require.Len(t, rows, 5) + }) + + t.Run("NullConnectionIDWebEvents", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + clock := quartz.NewMock(t) + + ws := createWorkspace(t, db) + + //nolint:gocritic // Test needs system context for the batcher. + backend := connectionlog.NewDBBatcher( + dbauthz.AsConnectionLogger(ctx), db, log, + connectionlog.WithClock(clock), + connectionlog.WithBatchSize(100), + ) + + now := dbtime.Now() + for i := 0; i < 2; i++ { + err := backend.Upsert(ctx, database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: now, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: "main", + Type: database.ConnectionTypeWorkspaceApp, + ConnectionID: uuid.NullUUID{}, + ConnectionStatus: database.ConnectionStatusConnected, + IP: testIP(), + }) + require.NoError(t, err) + } + + err := backend.Close() + require.NoError(t, err) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{ + LimitOpt: 10, + }) + require.NoError(t, err) + require.Len(t, rows, 2, "null connection_id events should not be deduplicated") + }) + + t.Run("CloseFlushesToDB", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + clock := quartz.NewMock(t) + + ws := createWorkspace(t, db) + + //nolint:gocritic // Test needs system context for the batcher. + backend := connectionlog.NewDBBatcher( + dbauthz.AsConnectionLogger(ctx), db, log, + connectionlog.WithClock(clock), + connectionlog.WithBatchSize(100), + ) + + err := backend.Upsert(ctx, database.UpsertConnectionLogParams{ + ID: uuid.New(), + Time: dbtime.Now(), + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + WorkspaceID: ws.ID, + WorkspaceName: ws.Name, + AgentName: "main", + Type: database.ConnectionTypeSsh, + ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + ConnectionStatus: database.ConnectionStatusConnected, + IP: testIP(), + }) + require.NoError(t, err) + + // Close without advancing clock — final flush should write. + err = backend.Close() + require.NoError(t, err) + + rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{ + LimitOpt: 10, + }) + require.NoError(t, err) + require.Len(t, rows, 1) + }) +} diff --git a/enterprise/coderd/connectionlog_test.go b/enterprise/coderd/connectionlog_test.go index 59ff1b780e..fc7a0ea902 100644 --- a/enterprise/coderd/connectionlog_test.go +++ b/enterprise/coderd/connectionlog_test.go @@ -227,7 +227,7 @@ func TestConnectionLogs(t *testing.T) { Int32: 0, Valid: false, }, - Ip: pqtype.Inet{IPNet: net.IPNet{ + IP: pqtype.Inet{IPNet: net.IPNet{ IP: net.ParseIP("192.168.0.1"), Mask: net.CIDRMask(8, 32), }, Valid: true}, diff --git a/enterprise/coderd/workspaceproxy_test.go b/enterprise/coderd/workspaceproxy_test.go index 73bef29337..1d0b664b5f 100644 --- a/enterprise/coderd/workspaceproxy_test.go +++ b/enterprise/coderd/workspaceproxy_test.go @@ -784,7 +784,7 @@ func TestIssueSignedAppToken(t *testing.T) { require.NoError(t, err) require.True(t, connectionLogger.Contains(t, database.UpsertConnectionLogParams{ - Ip: parsedFakeClientIP, + IP: parsedFakeClientIP, })) }) @@ -812,7 +812,7 @@ func TestIssueSignedAppToken(t *testing.T) { } require.True(t, connectionLogger.Contains(t, database.UpsertConnectionLogParams{ - Ip: parsedFakeClientIP, + IP: parsedFakeClientIP, })) }) } @@ -1020,7 +1020,7 @@ func TestReconnectingPTYSignedToken(t *testing.T) { // validate it here. require.True(t, connectionLogger.Contains(t, database.UpsertConnectionLogParams{ - Ip: pqtype.Inet{ + IP: pqtype.Inet{ Valid: true, IPNet: net.IPNet{ IP: net.ParseIP("127.0.0.1"), Mask: net.CIDRMask(32, 32),