diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 7cf4bf227b..a724c5959d 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1853,9 +1853,9 @@ func (q *querier) DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.U return q.db.DeleteAllChatQueuedMessages(ctx, chatID) } -func (q *querier) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error { +func (q *querier) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) ([]database.DeleteAllTailnetTunnelsRow, error) { if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { - return err + return nil, err } return q.db.DeleteAllTailnetTunnels(ctx, arg) } @@ -6676,9 +6676,9 @@ func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaP return q.db.UpdateReplica(ctx, arg) } -func (q *querier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) error { +func (q *querier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) ([]uuid.UUID, error) { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { - return err + return nil, err } return q.db.UpdateTailnetPeerStatusByCoordinator(ctx, arg) } diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 58dbfa3a87..35b5f13815 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -400,12 +400,12 @@ func (m queryMetricsStore) DeleteAllChatQueuedMessages(ctx context.Context, chat return r0 } -func (m queryMetricsStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error { +func (m queryMetricsStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) ([]database.DeleteAllTailnetTunnelsRow, error) { start := time.Now() - r0 := m.s.DeleteAllTailnetTunnels(ctx, arg) + r0, r1 := m.s.DeleteAllTailnetTunnels(ctx, arg) m.queryLatencies.WithLabelValues("DeleteAllTailnetTunnels").Observe(time.Since(start).Seconds()) m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteAllTailnetTunnels").Inc() - return r0 + return r0, r1 } func (m queryMetricsStore) DeleteAllWebpushSubscriptions(ctx context.Context) error { @@ -4784,12 +4784,12 @@ func (m queryMetricsStore) UpdateReplica(ctx context.Context, arg database.Updat return r0, r1 } -func (m queryMetricsStore) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) error { +func (m queryMetricsStore) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) ([]uuid.UUID, error) { start := time.Now() - r0 := m.s.UpdateTailnetPeerStatusByCoordinator(ctx, arg) + r0, r1 := m.s.UpdateTailnetPeerStatusByCoordinator(ctx, arg) m.queryLatencies.WithLabelValues("UpdateTailnetPeerStatusByCoordinator").Observe(time.Since(start).Seconds()) m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateTailnetPeerStatusByCoordinator").Inc() - return r0 + return r0, r1 } func (m queryMetricsStore) UpdateTaskPrompt(ctx context.Context, arg database.UpdateTaskPromptParams) (database.TaskTable, error) { diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index baa9a4ab93..45c0d4a97c 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -645,11 +645,12 @@ func (mr *MockStoreMockRecorder) DeleteAllChatQueuedMessages(ctx, chatID any) *g } // DeleteAllTailnetTunnels mocks base method. -func (m *MockStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) error { +func (m *MockStore) DeleteAllTailnetTunnels(ctx context.Context, arg database.DeleteAllTailnetTunnelsParams) ([]database.DeleteAllTailnetTunnelsRow, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeleteAllTailnetTunnels", ctx, arg) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].([]database.DeleteAllTailnetTunnelsRow) + ret1, _ := ret[1].(error) + return ret0, ret1 } // DeleteAllTailnetTunnels indicates an expected call of DeleteAllTailnetTunnels. @@ -9033,11 +9034,12 @@ func (mr *MockStoreMockRecorder) UpdateReplica(ctx, arg any) *gomock.Call { } // UpdateTailnetPeerStatusByCoordinator mocks base method. -func (m *MockStore) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) error { +func (m *MockStore) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg database.UpdateTailnetPeerStatusByCoordinatorParams) ([]uuid.UUID, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateTailnetPeerStatusByCoordinator", ctx, arg) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].([]uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 } // UpdateTailnetPeerStatusByCoordinator indicates an expected call of UpdateTailnetPeerStatusByCoordinator. diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 6a92612c1a..01ed07de6b 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1061,44 +1061,6 @@ BEGIN END; $$; -CREATE FUNCTION tailnet_notify_coordinator_heartbeat() RETURNS trigger - LANGUAGE plpgsql - AS $$ -BEGIN - PERFORM pg_notify('tailnet_coordinator_heartbeat', NEW.id::text); - RETURN NULL; -END; -$$; - -CREATE FUNCTION tailnet_notify_peer_change() RETURNS trigger - LANGUAGE plpgsql - AS $$ -BEGIN - IF (OLD IS NOT NULL) THEN - PERFORM pg_notify('tailnet_peer_update', OLD.id::text); - RETURN NULL; - END IF; - IF (NEW IS NOT NULL) THEN - PERFORM pg_notify('tailnet_peer_update', NEW.id::text); - RETURN NULL; - END IF; -END; -$$; - -CREATE FUNCTION tailnet_notify_tunnel_change() RETURNS trigger - LANGUAGE plpgsql - AS $$ -BEGIN - IF (NEW IS NOT NULL) THEN - PERFORM pg_notify('tailnet_tunnel_update', NEW.src_id || ',' || NEW.dst_id); - RETURN NULL; - ELSIF (OLD IS NOT NULL) THEN - PERFORM pg_notify('tailnet_tunnel_update', OLD.src_id || ',' || OLD.dst_id); - RETURN NULL; - END IF; -END; -$$; - CREATE TABLE ai_seat_state ( user_id uuid NOT NULL, first_used_at timestamp with time zone NOT NULL, @@ -4101,12 +4063,6 @@ CREATE TRIGGER remove_organization_member_custom_role BEFORE DELETE ON custom_ro COMMENT ON TRIGGER remove_organization_member_custom_role ON custom_roles IS 'When a custom_role is deleted, this trigger removes the role from all organization members.'; -CREATE TRIGGER tailnet_notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION tailnet_notify_coordinator_heartbeat(); - -CREATE TRIGGER tailnet_notify_peer_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_peers FOR EACH ROW EXECUTE FUNCTION tailnet_notify_peer_change(); - -CREATE TRIGGER tailnet_notify_tunnel_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_tunnels FOR EACH ROW EXECUTE FUNCTION tailnet_notify_tunnel_change(); - CREATE TRIGGER trigger_aggregate_usage_event AFTER INSERT ON usage_events FOR EACH ROW EXECUTE FUNCTION aggregate_usage_event(); CREATE TRIGGER trigger_delete_group_members_on_org_member_delete BEFORE DELETE ON organization_members FOR EACH ROW EXECUTE FUNCTION delete_group_members_on_org_member_delete(); diff --git a/coderd/database/migrations/000483_drop_tailnet_notify_triggers.down.sql b/coderd/database/migrations/000483_drop_tailnet_notify_triggers.down.sql new file mode 100644 index 0000000000..ea0117340f --- /dev/null +++ b/coderd/database/migrations/000483_drop_tailnet_notify_triggers.down.sql @@ -0,0 +1,43 @@ +CREATE FUNCTION tailnet_notify_coordinator_heartbeat() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + PERFORM pg_notify('tailnet_coordinator_heartbeat', NEW.id::text); + RETURN NULL; +END; +$$; + +CREATE FUNCTION tailnet_notify_peer_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_peer_update', OLD.id::text); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_peer_update', NEW.id::text); + RETURN NULL; + END IF; +END; +$$; + +CREATE FUNCTION tailnet_notify_tunnel_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_tunnel_update', NEW.src_id || ',' || NEW.dst_id); + RETURN NULL; + ELSIF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_tunnel_update', OLD.src_id || ',' || OLD.dst_id); + RETURN NULL; + END IF; +END; +$$; + +CREATE TRIGGER tailnet_notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION tailnet_notify_coordinator_heartbeat(); + +CREATE TRIGGER tailnet_notify_peer_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_peers FOR EACH ROW EXECUTE FUNCTION tailnet_notify_peer_change(); + +CREATE TRIGGER tailnet_notify_tunnel_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_tunnels FOR EACH ROW EXECUTE FUNCTION tailnet_notify_tunnel_change(); diff --git a/coderd/database/migrations/000483_drop_tailnet_notify_triggers.up.sql b/coderd/database/migrations/000483_drop_tailnet_notify_triggers.up.sql new file mode 100644 index 0000000000..937a0c8ffd --- /dev/null +++ b/coderd/database/migrations/000483_drop_tailnet_notify_triggers.up.sql @@ -0,0 +1,6 @@ +DROP TRIGGER IF EXISTS tailnet_notify_peer_change ON tailnet_peers; +DROP TRIGGER IF EXISTS tailnet_notify_tunnel_change ON tailnet_tunnels; +DROP TRIGGER IF EXISTS tailnet_notify_coordinator_heartbeat ON tailnet_coordinators; +DROP FUNCTION IF EXISTS tailnet_notify_peer_change(); +DROP FUNCTION IF EXISTS tailnet_notify_tunnel_change(); +DROP FUNCTION IF EXISTS tailnet_notify_coordinator_heartbeat(); diff --git a/coderd/database/querier.go b/coderd/database/querier.go index cd28f742fd..85749324c2 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -101,7 +101,7 @@ type sqlcQuerier interface { DeleteAPIKeyByID(ctx context.Context, id string) error DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error DeleteAllChatQueuedMessages(ctx context.Context, chatID uuid.UUID) error - DeleteAllTailnetTunnels(ctx context.Context, arg DeleteAllTailnetTunnelsParams) error + DeleteAllTailnetTunnels(ctx context.Context, arg DeleteAllTailnetTunnelsParams) ([]DeleteAllTailnetTunnelsRow, error) // Deletes all existing webpush subscriptions. // This should be called when the VAPID keypair is regenerated, as the old // keypair will no longer be valid and all existing subscriptions will need to @@ -1117,7 +1117,7 @@ type sqlcQuerier interface { UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error UpdateProvisionerJobWithCompleteWithStartedAtByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteWithStartedAtByIDParams) error UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error) - UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg UpdateTailnetPeerStatusByCoordinatorParams) error + UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg UpdateTailnetPeerStatusByCoordinatorParams) ([]uuid.UUID, error) UpdateTaskPrompt(ctx context.Context, arg UpdateTaskPromptParams) (TaskTable, error) UpdateTaskWorkspaceID(ctx context.Context, arg UpdateTaskWorkspaceIDParams) (TaskTable, error) UpdateTemplateACLByID(ctx context.Context, arg UpdateTemplateACLByIDParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index e3f652c8eb..ac061d90b5 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -21263,10 +21263,11 @@ func (q *sqlQuerier) CleanTailnetTunnels(ctx context.Context) error { return err } -const deleteAllTailnetTunnels = `-- name: DeleteAllTailnetTunnels :exec +const deleteAllTailnetTunnels = `-- name: DeleteAllTailnetTunnels :many DELETE FROM tailnet_tunnels WHERE coordinator_id = $1 and src_id = $2 +RETURNING src_id, dst_id ` type DeleteAllTailnetTunnelsParams struct { @@ -21274,9 +21275,32 @@ type DeleteAllTailnetTunnelsParams struct { SrcID uuid.UUID `db:"src_id" json:"src_id"` } -func (q *sqlQuerier) DeleteAllTailnetTunnels(ctx context.Context, arg DeleteAllTailnetTunnelsParams) error { - _, err := q.db.ExecContext(ctx, deleteAllTailnetTunnels, arg.CoordinatorID, arg.SrcID) - return err +type DeleteAllTailnetTunnelsRow struct { + SrcID uuid.UUID `db:"src_id" json:"src_id"` + DstID uuid.UUID `db:"dst_id" json:"dst_id"` +} + +func (q *sqlQuerier) DeleteAllTailnetTunnels(ctx context.Context, arg DeleteAllTailnetTunnelsParams) ([]DeleteAllTailnetTunnelsRow, error) { + rows, err := q.db.QueryContext(ctx, deleteAllTailnetTunnels, arg.CoordinatorID, arg.SrcID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []DeleteAllTailnetTunnelsRow + for rows.Next() { + var i DeleteAllTailnetTunnelsRow + if err := rows.Scan(&i.SrcID, &i.DstID); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } const deleteTailnetPeer = `-- name: DeleteTailnetPeer :one @@ -21551,13 +21575,14 @@ func (q *sqlQuerier) GetTailnetTunnelPeerIDsBatch(ctx context.Context, ids []uui return items, nil } -const updateTailnetPeerStatusByCoordinator = `-- name: UpdateTailnetPeerStatusByCoordinator :exec +const updateTailnetPeerStatusByCoordinator = `-- name: UpdateTailnetPeerStatusByCoordinator :many UPDATE tailnet_peers SET status = $2 WHERE coordinator_id = $1 +RETURNING id ` type UpdateTailnetPeerStatusByCoordinatorParams struct { @@ -21565,9 +21590,27 @@ type UpdateTailnetPeerStatusByCoordinatorParams struct { Status TailnetStatus `db:"status" json:"status"` } -func (q *sqlQuerier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg UpdateTailnetPeerStatusByCoordinatorParams) error { - _, err := q.db.ExecContext(ctx, updateTailnetPeerStatusByCoordinator, arg.CoordinatorID, arg.Status) - return err +func (q *sqlQuerier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg UpdateTailnetPeerStatusByCoordinatorParams) ([]uuid.UUID, error) { + rows, err := q.db.QueryContext(ctx, updateTailnetPeerStatusByCoordinator, arg.CoordinatorID, arg.Status) + if err != nil { + return nil, err + } + defer rows.Close() + var items []uuid.UUID + for rows.Next() { + var id uuid.UUID + if err := rows.Scan(&id); err != nil { + return nil, err + } + items = append(items, id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } const upsertTailnetCoordinator = `-- name: UpsertTailnetCoordinator :one diff --git a/coderd/database/queries/tailnet.sql b/coderd/database/queries/tailnet.sql index 4620a31e6c..ce7cad98d6 100644 --- a/coderd/database/queries/tailnet.sql +++ b/coderd/database/queries/tailnet.sql @@ -50,13 +50,14 @@ DO UPDATE SET updated_at = now() at time zone 'utc' RETURNING *; --- name: UpdateTailnetPeerStatusByCoordinator :exec +-- name: UpdateTailnetPeerStatusByCoordinator :many UPDATE tailnet_peers SET status = $2 WHERE - coordinator_id = $1; + coordinator_id = $1 +RETURNING id; -- name: DeleteTailnetPeer :one DELETE @@ -91,10 +92,11 @@ FROM tailnet_tunnels WHERE coordinator_id = $1 and src_id = $2 and dst_id = $3 RETURNING coordinator_id, src_id, dst_id; --- name: DeleteAllTailnetTunnels :exec +-- name: DeleteAllTailnetTunnels :many DELETE FROM tailnet_tunnels -WHERE coordinator_id = $1 and src_id = $2; +WHERE coordinator_id = $1 and src_id = $2 +RETURNING src_id, dst_id; -- For PG Coordinator HTMLDebug diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 81dc970f13..6f8abd701c 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -42,6 +42,27 @@ const ( CloseErrUnhealthy = "coordinator unhealthy" ) +func publishPeerUpdate(ctx context.Context, ps pubsub.Pubsub, logger slog.Logger, peerID uuid.UUID) { + if err := ps.Publish(eventPeerUpdate, []byte(peerID.String())); err != nil { + logger.Warn(ctx, "failed to publish peer update", slog.F("peer_id", peerID), slog.Error(err)) + } +} + +func publishTunnelUpdate(ctx context.Context, ps pubsub.Pubsub, logger slog.Logger, srcID, dstID uuid.UUID) { + if err := ps.Publish(eventTunnelUpdate, []byte(srcID.String()+","+dstID.String())); err != nil { + logger.Warn(ctx, "failed to publish tunnel update", + slog.F("src_id", srcID), slog.F("dst_id", dstID), slog.Error(err)) + } +} + +func publishCoordinatorHeartbeat(ctx context.Context, ps pubsub.Pubsub, logger slog.Logger, id uuid.UUID) { + if err := ps.Publish(EventHeartbeats, []byte(id.String())); err != nil { + logger.Warn(ctx, "failed to publish coordinator heartbeat", slog.F("coordinator_id", id), slog.Error(err)) + } else { + logger.Debug(ctx, "sent heartbeat", slog.F("coordinator_id", id)) + } +} + // pgCoord is a postgres-backed coordinator // // ┌────────────┐ @@ -152,11 +173,11 @@ func newPGCoordInternal( logger: logger, pubsub: ps, store: store, - binder: newBinder(ctx, logger, id, store, bCh, fHB), + binder: newBinder(ctx, logger, id, store, ps, bCh, fHB), bindings: bCh, newConnections: cCh, closeConnections: ccCh, - tunneler: newTunneler(ctx, logger, id, store, sCh, fHB), + tunneler: newTunneler(ctx, logger, id, store, ps, sCh, fHB), tunnelerCh: sCh, handshaker: newHandshaker(ctx, logger, id, ps, rfhCh, fHB), handshakerCh: rfhCh, @@ -273,6 +294,7 @@ type tunneler struct { logger slog.Logger coordinatorID uuid.UUID store database.Store + pubsub pubsub.Pubsub updates <-chan tunnel mu sync.Mutex @@ -286,6 +308,7 @@ func newTunneler(ctx context.Context, logger slog.Logger, id uuid.UUID, store database.Store, + ps pubsub.Pubsub, updates <-chan tunnel, startWorkers <-chan struct{}, ) *tunneler { @@ -294,6 +317,7 @@ func newTunneler(ctx context.Context, logger: logger, coordinatorID: id, store: store, + pubsub: ps, updates: updates, latest: make(map[uuid.UUID]map[uuid.UUID]tunnel), workQ: newWorkQ[tKey](ctx), @@ -396,7 +420,8 @@ func (t *tunneler) writeOne(tun tunnel) error { var err error switch { case tun.dst == uuid.Nil: - err = t.store.DeleteAllTailnetTunnels(t.ctx, database.DeleteAllTailnetTunnelsParams{ + var deleted []database.DeleteAllTailnetTunnelsRow + deleted, err = t.store.DeleteAllTailnetTunnels(t.ctx, database.DeleteAllTailnetTunnelsParams{ SrcID: tun.src, CoordinatorID: t.coordinatorID, }) @@ -404,6 +429,11 @@ func (t *tunneler) writeOne(tun tunnel) error { slog.F("src_id", tun.src), slog.Error(err), ) + if err == nil { + for _, row := range deleted { + publishTunnelUpdate(t.ctx, t.pubsub, t.logger, row.SrcID, row.DstID) + } + } case tun.active: _, err = t.store.UpsertTailnetTunnel(t.ctx, database.UpsertTailnetTunnelParams{ CoordinatorID: t.coordinatorID, @@ -415,6 +445,9 @@ func (t *tunneler) writeOne(tun tunnel) error { slog.F("dst_id", tun.dst), slog.Error(err), ) + if err == nil { + publishTunnelUpdate(t.ctx, t.pubsub, t.logger, tun.src, tun.dst) + } case !tun.active: _, err = t.store.DeleteTailnetTunnel(t.ctx, database.DeleteTailnetTunnelParams{ CoordinatorID: t.coordinatorID, @@ -428,7 +461,10 @@ func (t *tunneler) writeOne(tun tunnel) error { ) // writeOne should be idempotent if xerrors.Is(err, sql.ErrNoRows) { - err = nil + return nil // No row deleted, skip publish. + } + if err == nil { + publishTunnelUpdate(t.ctx, t.pubsub, t.logger, tun.src, tun.dst) } default: panic("unreachable") @@ -459,6 +495,7 @@ type binder struct { logger slog.Logger coordinatorID uuid.UUID store database.Store + pubsub pubsub.Pubsub bindings <-chan binding mu sync.Mutex @@ -473,6 +510,7 @@ func newBinder(ctx context.Context, logger slog.Logger, id uuid.UUID, store database.Store, + ps pubsub.Pubsub, bindings <-chan binding, startWorkers <-chan struct{}, ) *binder { @@ -481,6 +519,7 @@ func newBinder(ctx context.Context, logger: logger, coordinatorID: id, store: store, + pubsub: ps, bindings: bindings, latest: make(map[bKey]binding), workQ: newWorkQ[bKey](ctx), @@ -508,13 +547,16 @@ func newBinder(ctx context.Context, ctx, cancel := context.WithTimeout(dbauthz.As(context.Background(), pgCoordSubject), time.Second*15) defer cancel() - err := b.store.UpdateTailnetPeerStatusByCoordinator(ctx, database.UpdateTailnetPeerStatusByCoordinatorParams{ + peerIDs, err := b.store.UpdateTailnetPeerStatusByCoordinator(ctx, database.UpdateTailnetPeerStatusByCoordinatorParams{ CoordinatorID: b.coordinatorID, Status: database.TailnetStatusLost, }) if err != nil { b.logger.Error(b.ctx, "update peer status to lost", slog.Error(err)) } + for _, peerID := range peerIDs { + publishPeerUpdate(ctx, b.pubsub, b.logger, peerID) + } }() return b } @@ -593,6 +635,9 @@ func (b *binder) writeOne(bnd binding) error { slog.F("node", bnd.node), slog.Error(err)) } + if err == nil { + publishPeerUpdate(b.ctx, b.pubsub, b.logger, uuid.UUID(bnd.bKey)) + } return err } @@ -1299,9 +1344,11 @@ func (q *querier) listenReadyForHandshake(_ context.Context, msg []byte, err err func (q *querier) resyncPeerMappings() { q.mu.Lock() defer q.mu.Unlock() + keys := make([]mKey, 0, len(q.mappers)) for mk := range q.mappers { - q.mappingQ.enqueue(mk) + keys = append(keys, mk) } + q.mappingQ.enqueue(keys...) } func (q *querier) handleUpdates() { @@ -1710,11 +1757,17 @@ func (h *heartbeats) checkExpiry() { expired := false for id, t := range h.coordinators { lastHB := now.Sub(t) - h.logger.Debug(h.ctx, "last heartbeat from coordinator", slog.F("other_coordinator_id", id), slog.F("last_heartbeat", lastHB)) + h.logger.Debug(h.ctx, "last heartbeat from coordinator", + slog.F("other_coordinator_id", id), + slog.F("last_heartbeat", lastHB), + ) if lastHB >= MissedHeartbeats*HeartbeatPeriod { expired = true delete(h.coordinators, id) - h.logger.Info(h.ctx, "coordinator failed heartbeat check", slog.F("other_coordinator_id", id), slog.F("last_heartbeat", lastHB)) + h.logger.Info(h.ctx, "coordinator failed heartbeat check", + slog.F("other_coordinator_id", id), + slog.F("last_heartbeat", lastHB), + ) } } if expired { @@ -1754,7 +1807,7 @@ func (h *heartbeats) sendBeat() { } return } - h.logger.Debug(h.ctx, "sent heartbeat") + publishCoordinatorHeartbeat(h.ctx, h.pubsub, h.logger, h.self) if h.failedHeartbeats >= 3 { h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy") _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy}) diff --git a/enterprise/tailnet/pgcoord_internal_test.go b/enterprise/tailnet/pgcoord_internal_test.go index 3c9ad786f7..975e499278 100644 --- a/enterprise/tailnet/pgcoord_internal_test.go +++ b/enterprise/tailnet/pgcoord_internal_test.go @@ -76,6 +76,8 @@ func TestHeartbeats_recvBeat_resetSkew(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) logger := testutil.Logger(t) + ctrl := gomock.NewController(t) + mStore := dbmock.NewMockStore(ctrl) mClock := quartz.NewMock(t) trap := mClock.Trap().Until("heartbeats", "resetExpiryTimerWithLock") defer trap.Close() @@ -83,12 +85,12 @@ func TestHeartbeats_recvBeat_resetSkew(t *testing.T) { uut := heartbeats{ ctx: ctx, logger: logger, + store: mStore, clock: mClock, self: uuid.UUID{1}, update: make(chan hbUpdate, 4), coordinators: make(map[uuid.UUID]time.Time), } - coord2 := uuid.UUID{2} coord3 := uuid.UUID{3} @@ -397,7 +399,7 @@ func TestPGCoordinatorUnhealthy(t *testing.T) { mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil) mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil) mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil) - mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any()) + mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any()).Return(nil, nil) coordinator, err := newPGCoordInternal(ctx, logger, ps, mStore, mClock) require.NoError(t, err) diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index ccb1fe2016..3ec874ad17 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -268,6 +268,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { ctx: ctx, t: t, store: store, + ps: ps, id: uuid.New(), } @@ -281,6 +282,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { ctx: ctx, t: t, store: store, + ps: ps, id: uuid.New(), } fCoord3.heartbeat() @@ -304,7 +306,6 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { // one more heartbeat period will result in fCoord2 being expired, which should cause us to // revert to the original agent mapping mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx) - // note that the timeout doesn't get reset because both fCoord2 and fCoord3 are expired client.AssertEventuallyHasDERP(agent.ID, 10) // send fCoord3 heartbeat, which should trigger us to consider that mapping valid again. @@ -343,6 +344,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats_NoDrop(t *testing.T) { ctx: ctx, t: t, store: store, + ps: ps, id: uuid.New(), } // simulate a single heartbeat, the coordinator is healthy @@ -594,7 +596,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { mStore.EXPECT().GetTailnetTunnelPeerBindingsBatch(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil) mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()). AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil) - mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil) mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any()) uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore) @@ -948,6 +950,7 @@ type fakeCoordinator struct { ctx context.Context t *testing.T store database.Store + ps pubsub.Pubsub id uuid.UUID } @@ -955,6 +958,8 @@ func (c *fakeCoordinator) heartbeat() { c.t.Helper() _, err := c.store.UpsertTailnetCoordinator(c.ctx, c.id) require.NoError(c.t, err) + err = c.ps.Publish(tailnet.EventHeartbeats, []byte(c.id.String())) + require.NoError(c.t, err) } func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) { @@ -970,4 +975,6 @@ func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) { Status: database.TailnetStatusOk, }) require.NoError(c.t, err) + err = c.ps.Publish("tailnet_peer_update", []byte(agentID.String())) + require.NoError(c.t, err) }