fix(coderd/workspaceapps): prevent race in workspace app audit session updates (#17020)

Fixes coder/internal#520
This commit is contained in:
Mathias Fredriksson
2025-03-20 16:10:45 +02:00
committed by GitHub
parent 68624092a4
commit 72d9876c76
13 changed files with 68 additions and 32 deletions
+2 -2
View File
@@ -4625,9 +4625,9 @@ func (q *querier) UpsertWorkspaceAgentPortShare(ctx context.Context, arg databas
return q.db.UpsertWorkspaceAgentPortShare(ctx, arg)
}
func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (time.Time, error) {
func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (bool, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return time.Time{}, err
return false, err
}
return q.db.UpsertWorkspaceAppAuditSession(ctx, arg)
}
+6 -5
View File
@@ -12298,10 +12298,10 @@ func (q *FakeQuerier) UpsertWorkspaceAgentPortShare(_ context.Context, arg datab
return psl, nil
}
func (q *FakeQuerier) UpsertWorkspaceAppAuditSession(_ context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (time.Time, error) {
func (q *FakeQuerier) UpsertWorkspaceAppAuditSession(_ context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (bool, error) {
err := validateDatabaseType(arg)
if err != nil {
return time.Time{}, err
return false, err
}
q.mutex.Lock()
@@ -12335,10 +12335,11 @@ func (q *FakeQuerier) UpsertWorkspaceAppAuditSession(_ context.Context, arg data
q.workspaceAppAuditSessions[i].UpdatedAt = arg.UpdatedAt
if !fresh {
q.workspaceAppAuditSessions[i].ID = arg.ID
q.workspaceAppAuditSessions[i].StartedAt = arg.StartedAt
return arg.StartedAt, nil
return true, nil
}
return s.StartedAt, nil
return false, nil
}
q.workspaceAppAuditSessions = append(q.workspaceAppAuditSessions, database.WorkspaceAppAuditSession{
@@ -12352,7 +12353,7 @@ func (q *FakeQuerier) UpsertWorkspaceAppAuditSession(_ context.Context, arg data
StartedAt: arg.StartedAt,
UpdatedAt: arg.UpdatedAt,
})
return arg.StartedAt, nil
return true, nil
}
func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
+1 -1
View File
@@ -2992,7 +2992,7 @@ func (m queryMetricsStore) UpsertWorkspaceAgentPortShare(ctx context.Context, ar
return r0, r1
}
func (m queryMetricsStore) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (time.Time, error) {
func (m queryMetricsStore) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (bool, error) {
start := time.Now()
r0, r1 := m.s.UpsertWorkspaceAppAuditSession(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertWorkspaceAppAuditSession").Observe(time.Since(start).Seconds())
+2 -2
View File
@@ -6304,10 +6304,10 @@ func (mr *MockStoreMockRecorder) UpsertWorkspaceAgentPortShare(ctx, arg any) *go
}
// UpsertWorkspaceAppAuditSession mocks base method.
func (m *MockStore) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (time.Time, error) {
func (m *MockStore) UpsertWorkspaceAppAuditSession(ctx context.Context, arg database.UpsertWorkspaceAppAuditSessionParams) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertWorkspaceAppAuditSession", ctx, arg)
ret0, _ := ret[0].(time.Time)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
+5 -1
View File
@@ -1767,7 +1767,8 @@ CREATE UNLOGGED TABLE workspace_app_audit_sessions (
slug_or_port text NOT NULL,
status_code integer NOT NULL,
started_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL
updated_at timestamp with time zone NOT NULL,
id uuid NOT NULL
);
COMMENT ON TABLE workspace_app_audit_sessions IS 'Audit sessions for workspace apps, the data in this table is ephemeral and is used to deduplicate audit log entries for workspace apps. While a session is active, the same data will not be logged again. This table does not store historical data.';
@@ -2279,6 +2280,9 @@ ALTER TABLE ONLY workspace_agents
ALTER TABLE ONLY workspace_app_audit_sessions
ADD CONSTRAINT workspace_app_audit_sessions_agent_id_app_id_user_id_ip_use_key UNIQUE (agent_id, app_id, user_id, ip, user_agent, slug_or_port, status_code);
ALTER TABLE ONLY workspace_app_audit_sessions
ADD CONSTRAINT workspace_app_audit_sessions_pkey PRIMARY KEY (id);
ALTER TABLE ONLY workspace_app_stats
ADD CONSTRAINT workspace_app_stats_pkey PRIMARY KEY (id);
@@ -0,0 +1,2 @@
-- Remove the id column from workspace_app_audit_sessions.
-- NOTE(review): id is declared as the table's PRIMARY KEY when it is added;
-- PostgreSQL is expected to drop that constraint together with the column,
-- confirm no separate DROP CONSTRAINT is needed.
ALTER TABLE workspace_app_audit_sessions
DROP COLUMN id;
@@ -0,0 +1,5 @@
-- Give workspace_app_audit_sessions a UUID primary key column named id.
--
-- Add column with default to fix existing rows.
ALTER TABLE workspace_app_audit_sessions
ADD COLUMN id UUID PRIMARY KEY DEFAULT gen_random_uuid();
-- Remove the default again: it was only there to populate existing rows.
-- Without it, every new row has to supply its id explicitly.
ALTER TABLE workspace_app_audit_sessions
ALTER COLUMN id DROP DEFAULT;
+1
View File
@@ -3454,6 +3454,7 @@ type WorkspaceAppAuditSession struct {
StartedAt time.Time `db:"started_at" json:"started_at"`
// The time the session was last updated.
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
ID uuid.UUID `db:"id" json:"id"`
}
// A record of workspace app usage statistics
+4 -3
View File
@@ -595,9 +595,10 @@ type sqlcQuerier interface {
UpsertTemplateUsageStats(ctx context.Context) error
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
//
// Insert a new workspace app audit session or update an existing one, if
// started_at is updated, it means the session has been restarted.
UpsertWorkspaceAppAuditSession(ctx context.Context, arg UpsertWorkspaceAppAuditSessionParams) (time.Time, error)
// The returned boolean, new_or_stale, can be used to deduce if a new session
// was started. This means that a new row was inserted (no previous session) or
// the updated_at is older than stale interval.
UpsertWorkspaceAppAuditSession(ctx context.Context, arg UpsertWorkspaceAppAuditSessionParams) (bool, error)
}
var _ sqlcQuerier = (*sqlQuerier)(nil)
+20 -9
View File
@@ -14654,6 +14654,7 @@ func (q *sqlQuerier) InsertWorkspaceAgentStats(ctx context.Context, arg InsertWo
const upsertWorkspaceAppAuditSession = `-- name: UpsertWorkspaceAppAuditSession :one
INSERT INTO
workspace_app_audit_sessions (
id,
agent_id,
app_id,
user_id,
@@ -14674,24 +14675,32 @@ VALUES
$6,
$7,
$8,
$9
$9,
$10
)
ON CONFLICT
(agent_id, app_id, user_id, ip, user_agent, slug_or_port, status_code)
DO
UPDATE
SET
-- ID is used to know if session was reset on upsert.
id = CASE
WHEN workspace_app_audit_sessions.updated_at > NOW() - ($11::bigint || ' ms')::interval
THEN workspace_app_audit_sessions.id
ELSE EXCLUDED.id
END,
started_at = CASE
WHEN workspace_app_audit_sessions.updated_at > NOW() - ($10::bigint || ' ms')::interval
WHEN workspace_app_audit_sessions.updated_at > NOW() - ($11::bigint || ' ms')::interval
THEN workspace_app_audit_sessions.started_at
ELSE EXCLUDED.started_at
END,
updated_at = EXCLUDED.updated_at
RETURNING
started_at
id = $1 AS new_or_stale
`
type UpsertWorkspaceAppAuditSessionParams struct {
ID uuid.UUID `db:"id" json:"id"`
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
AppID uuid.UUID `db:"app_id" json:"app_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
@@ -14704,10 +14713,12 @@ type UpsertWorkspaceAppAuditSessionParams struct {
StaleIntervalMS int64 `db:"stale_interval_ms" json:"stale_interval_ms"`
}
// Insert a new workspace app audit session or update an existing one, if
// started_at is updated, it means the session has been restarted.
func (q *sqlQuerier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg UpsertWorkspaceAppAuditSessionParams) (time.Time, error) {
// The returned boolean, new_or_stale, can be used to deduce if a new session
// was started. This means that a new row was inserted (no previous session) or
// the updated_at is older than stale interval.
func (q *sqlQuerier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg UpsertWorkspaceAppAuditSessionParams) (bool, error) {
row := q.db.QueryRowContext(ctx, upsertWorkspaceAppAuditSession,
arg.ID,
arg.AgentID,
arg.AppID,
arg.UserID,
@@ -14719,9 +14730,9 @@ func (q *sqlQuerier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg Ups
arg.UpdatedAt,
arg.StaleIntervalMS,
)
var started_at time.Time
err := row.Scan(&started_at)
return started_at, err
var new_or_stale bool
err := row.Scan(&new_or_stale)
return new_or_stale, err
}
const getWorkspaceAppByAgentIDAndSlug = `-- name: GetWorkspaceAppByAgentIDAndSlug :one
+13 -4
View File
@@ -1,9 +1,11 @@
-- name: UpsertWorkspaceAppAuditSession :one
--
-- Insert a new workspace app audit session or update an existing one, if
-- started_at is updated, it means the session has been restarted.
-- The returned boolean, new_or_stale, can be used to deduce if a new session
-- was started. This means that a new row was inserted (no previous session) or
-- the updated_at is older than stale interval.
INSERT INTO
workspace_app_audit_sessions (
id,
agent_id,
app_id,
user_id,
@@ -24,13 +26,20 @@ VALUES
$6,
$7,
$8,
$9
$9,
$10
)
ON CONFLICT
(agent_id, app_id, user_id, ip, user_agent, slug_or_port, status_code)
DO
UPDATE
SET
-- ID is used to know if session was reset on upsert.
id = CASE
WHEN workspace_app_audit_sessions.updated_at > NOW() - (@stale_interval_ms::bigint || ' ms')::interval
THEN workspace_app_audit_sessions.id
ELSE EXCLUDED.id
END,
started_at = CASE
WHEN workspace_app_audit_sessions.updated_at > NOW() - (@stale_interval_ms::bigint || ' ms')::interval
THEN workspace_app_audit_sessions.started_at
@@ -38,4 +47,4 @@ DO
END,
updated_at = EXCLUDED.updated_at
RETURNING
started_at;
id = $1 AS new_or_stale;
+1
View File
@@ -80,6 +80,7 @@ const (
UniqueWorkspaceAgentVolumeResourceMonitorsPkey UniqueConstraint = "workspace_agent_volume_resource_monitors_pkey" // ALTER TABLE ONLY workspace_agent_volume_resource_monitors ADD CONSTRAINT workspace_agent_volume_resource_monitors_pkey PRIMARY KEY (agent_id, path);
UniqueWorkspaceAgentsPkey UniqueConstraint = "workspace_agents_pkey" // ALTER TABLE ONLY workspace_agents ADD CONSTRAINT workspace_agents_pkey PRIMARY KEY (id);
UniqueWorkspaceAppAuditSessionsAgentIDAppIDUserIDIpUseKey UniqueConstraint = "workspace_app_audit_sessions_agent_id_app_id_user_id_ip_use_key" // ALTER TABLE ONLY workspace_app_audit_sessions ADD CONSTRAINT workspace_app_audit_sessions_agent_id_app_id_user_id_ip_use_key UNIQUE (agent_id, app_id, user_id, ip, user_agent, slug_or_port, status_code);
UniqueWorkspaceAppAuditSessionsPkey UniqueConstraint = "workspace_app_audit_sessions_pkey" // ALTER TABLE ONLY workspace_app_audit_sessions ADD CONSTRAINT workspace_app_audit_sessions_pkey PRIMARY KEY (id);
UniqueWorkspaceAppStatsPkey UniqueConstraint = "workspace_app_stats_pkey" // ALTER TABLE ONLY workspace_app_stats ADD CONSTRAINT workspace_app_stats_pkey PRIMARY KEY (id);
UniqueWorkspaceAppStatsUserIDAgentIDSessionIDKey UniqueConstraint = "workspace_app_stats_user_id_agent_id_session_id_key" // ALTER TABLE ONLY workspace_app_stats ADD CONSTRAINT workspace_app_stats_user_id_agent_id_session_id_key UNIQUE (user_id, agent_id, session_id);
UniqueWorkspaceAppsAgentIDSlugIndex UniqueConstraint = "workspace_apps_agent_id_slug_idx" // ALTER TABLE ONLY workspace_apps ADD CONSTRAINT workspace_apps_agent_id_slug_idx UNIQUE (agent_id, slug);
+6 -5
View File
@@ -447,16 +447,17 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW
slog.F("status_code", statusCode),
)
var startedAt time.Time
var newOrStale bool
err := p.Database.InTx(func(tx database.Store) (err error) {
// nolint:gocritic // System context is needed to write audit sessions.
dangerousSystemCtx := dbauthz.AsSystemRestricted(ctx)
startedAt, err = tx.UpsertWorkspaceAppAuditSession(dangerousSystemCtx, database.UpsertWorkspaceAppAuditSessionParams{
newOrStale, err = tx.UpsertWorkspaceAppAuditSession(dangerousSystemCtx, database.UpsertWorkspaceAppAuditSessionParams{
// Config.
StaleIntervalMS: p.WorkspaceAppAuditSessionTimeout.Milliseconds(),
// Data.
ID: uuid.New(),
AgentID: aReq.dbReq.Agent.ID,
AppID: aReq.dbReq.App.ID, // Can be unset, in which case uuid.Nil is fine.
UserID: userID, // Can be unset, in which case uuid.Nil is fine.
@@ -481,9 +482,9 @@ func (p *DBTokenProvider) auditInitRequest(ctx context.Context, w http.ResponseW
return
}
if !startedAt.Equal(aReq.time) {
// If the unique session wasn't renewed, we don't want to log a new
// audit event for it.
if !newOrStale {
// We either didn't insert a new session, or the session
// didn't time out due to inactivity.
return
}