Files
coder/enterprise/tailnet/pgcoord_internal_test.go
T
Hugo Dutka e62c5db678 chore: remove references to dbtestutil.WillUsePostgres (#20436)
Addresses https://github.com/coder/internal/issues/758.

This PR only cleans up dead code, it makes no changes to test logic.
2025-10-23 14:24:54 +02:00

437 lines
13 KiB
Go

package tailnet
import (
"bytes"
"context"
"flag"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
gProto "google.golang.org/protobuf/proto"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/quartz"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
// UpdateGoldenFiles indicates golden files should be updated.
// To update the golden files:
// make gen/golden-files
var UpdateGoldenFiles = flag.Bool("update", false, "update .golden files")
// TestHeartbeats_Cleanup tests the cleanup loop
func TestHeartbeats_Cleanup(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mStore := dbmock.NewMockStore(ctrl)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logger := testutil.Logger(t)
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).Times(2).Return(nil)
mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).Times(2).Return(nil)
mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).Times(2).Return(nil)
mClock := quartz.NewMock(t)
trap := mClock.Trap().TickerFunc("heartbeats", "cleanupLoop")
defer trap.Close()
uut := &heartbeats{
ctx: ctx,
logger: logger,
store: mStore,
clock: mClock,
}
uut.wg.Add(1)
go uut.cleanupLoop()
call := trap.MustWait(ctx)
call.MustRelease(ctx)
require.Equal(t, cleanupPeriod, call.Duration)
mClock.Advance(cleanupPeriod).MustWait(ctx)
}
// TestHeartbeats_recvBeat_resetSkew is a regression test for a bug where heartbeats from two
// coordinators slightly skewed from one another could result in one coordinator failing to get
// expired
func TestHeartbeats_recvBeat_resetSkew(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
mClock := quartz.NewMock(t)
trap := mClock.Trap().Until("heartbeats", "resetExpiryTimerWithLock")
defer trap.Close()
uut := heartbeats{
ctx: ctx,
logger: logger,
clock: mClock,
self: uuid.UUID{1},
update: make(chan hbUpdate, 4),
coordinators: make(map[uuid.UUID]time.Time),
}
coord2 := uuid.UUID{2}
coord3 := uuid.UUID{3}
uut.listen(ctx, []byte(coord2.String()), nil)
// coord 3 heartbeat comes very soon after
mClock.Advance(time.Millisecond).MustWait(ctx)
go uut.listen(ctx, []byte(coord3.String()), nil)
trap.MustWait(ctx).MustRelease(ctx)
// both coordinators are present
uut.lock.RLock()
require.Contains(t, uut.coordinators, coord2)
require.Contains(t, uut.coordinators, coord3)
uut.lock.RUnlock()
// no more heartbeats arrive, and coord2 expires
w := mClock.Advance(MissedHeartbeats*HeartbeatPeriod - time.Millisecond)
// however, several ms pass between expiring 2 and computing the time until 3 expires
c := trap.MustWait(ctx)
mClock.Advance(2 * time.Millisecond).MustWait(ctx) // 3 has now expired _in the past_
c.MustRelease(ctx)
w.MustWait(ctx)
// expired in the past means we immediately reschedule checkExpiry, so we get another call
trap.MustWait(ctx).MustRelease(ctx)
uut.lock.RLock()
require.NotContains(t, uut.coordinators, coord2)
require.NotContains(t, uut.coordinators, coord3)
uut.lock.RUnlock()
}
func TestHeartbeats_LostCoordinator_MarkLost(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mStore := dbmock.NewMockStore(ctrl)
mClock := quartz.NewMock(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logger := testutil.Logger(t)
uut := &heartbeats{
ctx: ctx,
logger: logger,
store: mStore,
coordinators: map[uuid.UUID]time.Time{
uuid.New(): mClock.Now(),
},
clock: mClock,
}
mpngs := []mapping{{
peer: uuid.New(),
coordinator: uuid.New(),
updatedAt: mClock.Now(),
node: &proto.Node{},
kind: proto.CoordinateResponse_PeerUpdate_NODE,
}}
// Filter should still return the mapping without a coordinator, but marked
// as LOST.
got := uut.filter(mpngs)
require.Len(t, got, 1)
assert.Equal(t, proto.CoordinateResponse_PeerUpdate_LOST, got[0].kind)
}
// TestLostPeerCleanupQueries tests that our SQL queries to clean up lost peers do what we expect,
// that is, clean up peers and associated tunnels that have been lost for over 24 hours.
func TestLostPeerCleanupQueries(t *testing.T) {
t.Parallel()
store, _, sqlDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
coordID := uuid.New()
_, err := store.UpsertTailnetCoordinator(ctx, coordID)
require.NoError(t, err)
peerID := uuid.New()
_, err = store.UpsertTailnetPeer(ctx, database.UpsertTailnetPeerParams{
ID: peerID,
CoordinatorID: coordID,
Node: []byte("test"),
Status: database.TailnetStatusLost,
})
require.NoError(t, err)
otherID := uuid.New()
_, err = store.UpsertTailnetTunnel(ctx, database.UpsertTailnetTunnelParams{
CoordinatorID: coordID,
SrcID: peerID,
DstID: otherID,
})
require.NoError(t, err)
peers, err := store.GetAllTailnetPeers(ctx)
require.NoError(t, err)
require.Len(t, peers, 1)
require.Equal(t, peerID, peers[0].ID)
tunnels, err := store.GetAllTailnetTunnels(ctx)
require.NoError(t, err)
require.Len(t, tunnels, 1)
require.Equal(t, peerID, tunnels[0].SrcID)
require.Equal(t, otherID, tunnels[0].DstID)
// this clean is a noop since the peer and tunnel are less than 24h old
err = store.CleanTailnetLostPeers(ctx)
require.NoError(t, err)
err = store.CleanTailnetTunnels(ctx)
require.NoError(t, err)
peers, err = store.GetAllTailnetPeers(ctx)
require.NoError(t, err)
require.Len(t, peers, 1)
require.Equal(t, peerID, peers[0].ID)
tunnels, err = store.GetAllTailnetTunnels(ctx)
require.NoError(t, err)
require.Len(t, tunnels, 1)
require.Equal(t, peerID, tunnels[0].SrcID)
require.Equal(t, otherID, tunnels[0].DstID)
// set the age of the tunnel to >24h
sqlDB.Exec("UPDATE tailnet_tunnels SET updated_at = $1", time.Now().Add(-25*time.Hour))
// this clean is still a noop since the peer hasn't been lost for 24 hours
err = store.CleanTailnetLostPeers(ctx)
require.NoError(t, err)
err = store.CleanTailnetTunnels(ctx)
require.NoError(t, err)
peers, err = store.GetAllTailnetPeers(ctx)
require.NoError(t, err)
require.Len(t, peers, 1)
require.Equal(t, peerID, peers[0].ID)
tunnels, err = store.GetAllTailnetTunnels(ctx)
require.NoError(t, err)
require.Len(t, tunnels, 1)
require.Equal(t, peerID, tunnels[0].SrcID)
require.Equal(t, otherID, tunnels[0].DstID)
// set the age of the tunnel to >24h
sqlDB.Exec("UPDATE tailnet_peers SET updated_at = $1", time.Now().Add(-25*time.Hour))
// this clean removes the peer and the associated tunnel
err = store.CleanTailnetLostPeers(ctx)
require.NoError(t, err)
err = store.CleanTailnetTunnels(ctx)
require.NoError(t, err)
peers, err = store.GetAllTailnetPeers(ctx)
require.NoError(t, err)
require.Len(t, peers, 0)
tunnels, err = store.GetAllTailnetTunnels(ctx)
require.NoError(t, err)
require.Len(t, tunnels, 0)
}
func TestDebugTemplate(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("newlines screw up golden files on windows")
}
c1 := uuid.MustParse("01000000-1111-1111-1111-111111111111")
c2 := uuid.MustParse("02000000-1111-1111-1111-111111111111")
p1 := uuid.MustParse("01000000-2222-2222-2222-222222222222")
p2 := uuid.MustParse("02000000-2222-2222-2222-222222222222")
in := HTMLDebug{
Coordinators: []*HTMLCoordinator{
{
ID: c1,
HeartbeatAge: 2 * time.Second,
},
{
ID: c2,
HeartbeatAge: time.Second,
},
},
Peers: []*HTMLPeer{
{
ID: p1,
CoordinatorID: c1,
LastWriteAge: 5 * time.Second,
Status: database.TailnetStatusOk,
Node: `id:1 preferred_derp:999 endpoints:"192.168.0.49:4449"`,
},
{
ID: p2,
CoordinatorID: c2,
LastWriteAge: 7 * time.Second,
Status: database.TailnetStatusLost,
Node: `id:2 preferred_derp:999 endpoints:"192.168.0.33:4449"`,
},
},
Tunnels: []*HTMLTunnel{
{
CoordinatorID: c1,
SrcID: p1,
DstID: p2,
LastWriteAge: 3 * time.Second,
},
},
}
buf := new(bytes.Buffer)
err := debugTempl.Execute(buf, in)
require.NoError(t, err)
actual := buf.Bytes()
goldenPath := filepath.Join("testdata", "debug.golden.html")
if *UpdateGoldenFiles {
t.Logf("update golden file %s", goldenPath)
err := os.WriteFile(goldenPath, actual, 0o600)
require.NoError(t, err, "update golden file")
}
expected, err := os.ReadFile(goldenPath)
require.NoError(t, err, "read golden file, run \"make gen/golden-files\" and commit the changes")
require.Equal(
t, string(expected), string(actual),
"golden file mismatch: %s, run \"make gen/golden-files\", verify and commit the changes",
goldenPath,
)
}
func TestGetDebug(t *testing.T) {
t.Parallel()
store, _ := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
coordID := uuid.New()
_, err := store.UpsertTailnetCoordinator(ctx, coordID)
require.NoError(t, err)
peerID := uuid.New()
node := &proto.Node{PreferredDerp: 44}
nodeb, err := gProto.Marshal(node)
require.NoError(t, err)
_, err = store.UpsertTailnetPeer(ctx, database.UpsertTailnetPeerParams{
ID: peerID,
CoordinatorID: coordID,
Node: nodeb,
Status: database.TailnetStatusLost,
})
require.NoError(t, err)
dstID := uuid.New()
_, err = store.UpsertTailnetTunnel(ctx, database.UpsertTailnetTunnelParams{
CoordinatorID: coordID,
SrcID: peerID,
DstID: dstID,
})
require.NoError(t, err)
debug, err := getDebug(ctx, store)
require.NoError(t, err)
require.Len(t, debug.Coordinators, 1)
require.Len(t, debug.Peers, 1)
require.Len(t, debug.Tunnels, 1)
require.Equal(t, coordID, debug.Coordinators[0].ID)
require.Equal(t, peerID, debug.Peers[0].ID)
require.Equal(t, coordID, debug.Peers[0].CoordinatorID)
require.Equal(t, database.TailnetStatusLost, debug.Peers[0].Status)
require.Equal(t, node.String(), debug.Peers[0].Node)
require.Equal(t, coordID, debug.Tunnels[0].CoordinatorID)
require.Equal(t, peerID, debug.Tunnels[0].SrcID)
require.Equal(t, dstID, debug.Tunnels[0].DstID)
}
// TestPGCoordinatorUnhealthy tests that when the coordinator fails to send heartbeats and is
// unhealthy it disconnects any peers and does not send any extraneous database queries.
func TestPGCoordinatorUnhealthy(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctrl := gomock.NewController(t)
mStore := dbmock.NewMockStore(ctrl)
ps := pubsub.NewInMemory()
mClock := quartz.NewMock(t)
tfTrap := mClock.Trap().TickerFunc("heartbeats", "sendBeats")
defer tfTrap.Close()
// after 3 failed heartbeats, the coordinator is unhealthy
mStore.EXPECT().
UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
Times(3).
Return(database.TailnetCoordinator{}, xerrors.New("badness"))
// But, in particular we DO NOT want the coordinator to call DeleteTailnetPeer, as this is
// unnecessary and can spam the database. c.f. https://github.com/coder/coder/issues/12923
// these cleanup queries run, but we don't care for this test
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any())
coordinator, err := newPGCoordInternal(ctx, logger, ps, mStore, mClock)
require.NoError(t, err)
expectedPeriod := HeartbeatPeriod
tfCall, err := tfTrap.Wait(ctx)
require.NoError(t, err)
tfCall.MustRelease(ctx)
require.Equal(t, expectedPeriod, tfCall.Duration)
// Now that the ticker has started, we can advance 2 more beats to get to 3
// failed heartbeats
mClock.Advance(HeartbeatPeriod).MustWait(ctx)
mClock.Advance(HeartbeatPeriod).MustWait(ctx)
// The querier is informed async about being unhealthy, so we need to wait
// until it is.
require.Eventually(t, func() bool {
return !coordinator.querier.isHealthy()
}, testutil.WaitShort, testutil.IntervalFast)
pID := uuid.UUID{5}
_, resps := coordinator.Coordinate(ctx, pID, "test", agpl.AgentCoordinateeAuth{ID: pID})
resp := testutil.RequireReceive(ctx, t, resps)
require.Equal(t, CloseErrUnhealthy, resp.Error)
resp = testutil.TryReceive(ctx, t, resps)
require.Nil(t, resp, "channel should be closed")
// give the coordinator some time to process any pending work. We are
// testing here that a database call is absent, so we don't want to race to
// shut down the test.
time.Sleep(testutil.IntervalMedium)
_ = coordinator.Close()
require.Eventually(t, ctrl.Satisfied, testutil.WaitShort, testutil.IntervalFast)
}