Files
coder/enterprise/tailnet/pgcoord_test.go
T
Spike Curtis 0cab6e7763 feat: support graceful disconnect in PGCoordinator (#10937)
Adds support for graceful disconnect to PGCoordinator.  When peers gracefully disconnect, they send a disconnect message.  This triggers the peer to be disconnected from all tunneled peers.

The Multi-Agent Client supports graceful disconnect, since it is in memory and we know that when it is closed, we really mean to disconnect.

The v1 agent and client Websocket connections do not support graceful disconnect, since the v1 protocol doesn't have this feature.  That means that if a v1 peer connects to a v2 peer, when the v1 peer's coordinator connection is closed, the v2 peer will
see it as "lost" since we don't know whether the v1 peer meant to disconnect, or it just lost connectivity to the coordinator.
2023-12-01 09:55:25 +04:00

1117 lines
30 KiB
Go

package tailnet_test
import (
"context"
"database/sql"
"io"
"net"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
gProto "google.golang.org/protobuf/proto"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/enterprise/tailnet"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agentID := uuid.New()
client := newTestClient(t, coordinator, agentID)
defer client.close()
client.sendNode(&agpl.Node{PreferredDERP: 10})
require.Eventually(t, func() bool {
clients, err := store.GetTailnetTunnelPeerBindings(ctx, agentID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
t.Fatalf("database error: %v", err)
}
if len(clients) == 0 {
return false
}
node := new(proto.Node)
err = gProto.Unmarshal(clients[0].Node, node)
assert.NoError(t, err)
assert.EqualValues(t, 10, node.PreferredDerp)
return true
}, testutil.WaitShort, testutil.IntervalFast)
err = client.close()
require.NoError(t, err)
<-client.errChan
<-client.closeChan
assertEventuallyLost(ctx, t, store, client.id)
}
func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{PreferredDERP: 10})
require.Eventually(t, func() bool {
agents, err := store.GetTailnetPeers(ctx, agent.id)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
t.Fatalf("database error: %v", err)
}
if len(agents) == 0 {
return false
}
node := new(proto.Node)
err = gProto.Unmarshal(agents[0].Node, node)
assert.NoError(t, err)
assert.EqualValues(t, 10, node.PreferredDerp)
return true
}, testutil.WaitShort, testutil.IntervalFast)
err = agent.close()
require.NoError(t, err)
<-agent.errChan
<-agent.closeChan
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "original")
defer agent.close()
agent.sendNode(&agpl.Node{PreferredDERP: 10})
client := newTestClient(t, coordinator, agent.id)
defer client.close()
agentNodes := client.recvNodes(ctx, t)
require.Len(t, agentNodes, 1)
assert.Equal(t, 10, agentNodes[0].PreferredDERP)
client.sendNode(&agpl.Node{PreferredDERP: 11})
clientNodes := agent.recvNodes(ctx, t)
require.Len(t, clientNodes, 1)
assert.Equal(t, 11, clientNodes[0].PreferredDERP)
// Ensure an update to the agent node reaches the connIO!
agent.sendNode(&agpl.Node{PreferredDERP: 12})
agentNodes = client.recvNodes(ctx, t)
require.Len(t, agentNodes, 1)
assert.Equal(t, 12, agentNodes[0].PreferredDERP)
// Close the agent WebSocket so a new one can connect.
err = agent.close()
require.NoError(t, err)
_ = agent.recvErr(ctx, t)
agent.waitForClose(ctx, t)
// Create a new agent connection. This is to simulate a reconnect!
agent = newTestAgent(t, coordinator, "reconnection", agent.id)
// Ensure the existing listening connIO sends its node immediately!
clientNodes = agent.recvNodes(ctx, t)
require.Len(t, clientNodes, 1)
assert.Equal(t, 11, clientNodes[0].PreferredDERP)
// Send a bunch of updates in rapid succession, and test that we eventually get the latest. We don't want the
// coordinator accidentally reordering things.
for d := 13; d < 36; d++ {
agent.sendNode(&agpl.Node{PreferredDERP: d})
}
for {
nodes := client.recvNodes(ctx, t)
if !assert.Len(t, nodes, 1) {
break
}
if nodes[0].PreferredDERP == 35 {
// got latest!
break
}
}
err = agent.close()
require.NoError(t, err)
_ = agent.recvErr(ctx, t)
agent.waitForClose(ctx, t)
err = client.close()
require.NoError(t, err)
_ = client.recvErr(ctx, t)
client.waitForClose(ctx, t)
assertEventuallyLost(ctx, t, store, agent.id)
assertEventuallyLost(ctx, t, store, client.id)
}
func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{PreferredDERP: 10})
client := newTestClient(t, coordinator, agent.id)
defer client.close()
assertEventuallyHasDERPs(ctx, t, client, 10)
client.sendNode(&agpl.Node{PreferredDERP: 11})
assertEventuallyHasDERPs(ctx, t, agent, 11)
// simulate a second coordinator via DB calls only --- our goal is to test broken heart-beating, so we can't use a
// real coordinator
fCoord2 := &fakeCoordinator{
ctx: ctx,
t: t,
store: store,
id: uuid.New(),
}
// heatbeat until canceled
ctx2, cancel2 := context.WithCancel(ctx)
go func() {
t := time.NewTicker(tailnet.HeartbeatPeriod)
defer t.Stop()
for {
select {
case <-ctx2.Done():
return
case <-t.C:
fCoord2.heartbeat()
}
}
}()
fCoord2.heartbeat()
fCoord2.agentNode(agent.id, &agpl.Node{PreferredDERP: 12})
assertEventuallyHasDERPs(ctx, t, client, 12)
fCoord3 := &fakeCoordinator{
ctx: ctx,
t: t,
store: store,
id: uuid.New(),
}
start := time.Now()
fCoord3.heartbeat()
fCoord3.agentNode(agent.id, &agpl.Node{PreferredDERP: 13})
assertEventuallyHasDERPs(ctx, t, client, 13)
// when the fCoord3 misses enough heartbeats, the real coordinator should send an update with the
// node from fCoord2 for the agent.
assertEventuallyHasDERPs(ctx, t, client, 12)
assert.Greater(t, time.Since(start), tailnet.HeartbeatPeriod*tailnet.MissedHeartbeats)
// stop fCoord2 heartbeats, which should cause us to revert to the original agent mapping
cancel2()
assertEventuallyHasDERPs(ctx, t, client, 10)
// send fCoord3 heartbeat, which should trigger us to consider that mapping valid again.
fCoord3.heartbeat()
assertEventuallyHasDERPs(ctx, t, client, 13)
err = agent.close()
require.NoError(t, err)
_ = agent.recvErr(ctx, t)
agent.waitForClose(ctx, t)
err = client.close()
require.NoError(t, err)
_ = client.recvErr(ctx, t)
client.waitForClose(ctx, t)
assertEventuallyLost(ctx, t, store, client.id)
}
func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
mu := sync.Mutex{}
heartbeats := []time.Time{}
unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) {
assert.NoError(t, err)
mu.Lock()
defer mu.Unlock()
heartbeats = append(heartbeats, time.Now())
})
require.NoError(t, err)
defer unsub()
start := time.Now()
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
require.Eventually(t, func() bool {
mu.Lock()
defer mu.Unlock()
if len(heartbeats) < 2 {
return false
}
require.Greater(t, heartbeats[0].Sub(start), time.Duration(0))
require.Greater(t, heartbeats[1].Sub(start), time.Duration(0))
return assert.Greater(t, heartbeats[1].Sub(heartbeats[0]), tailnet.HeartbeatPeriod*9/10)
}, testutil.WaitMedium, testutil.IntervalMedium)
}
// TestPGCoordinatorDual_Mainline tests with 2 coordinators, one agent connected to each, and 2 clients per agent.
//
// +---------+
// agent1 ---> | coord1 | <--- client11 (coord 1, agent 1)
// | |
// | | <--- client12 (coord 1, agent 2)
// +---------+
// +---------+
// agent2 ---> | coord2 | <--- client21 (coord 2, agent 1)
// | |
// | | <--- client22 (coord2, agent 2)
// +---------+
func TestPGCoordinatorDual_Mainline(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
require.NoError(t, err)
defer coord2.Close()
agent1 := newTestAgent(t, coord1, "agent1")
defer agent1.close()
t.Logf("agent1=%s", agent1.id)
agent2 := newTestAgent(t, coord2, "agent2")
defer agent2.close()
t.Logf("agent2=%s", agent2.id)
client11 := newTestClient(t, coord1, agent1.id)
defer client11.close()
t.Logf("client11=%s", client11.id)
client12 := newTestClient(t, coord1, agent2.id)
defer client12.close()
t.Logf("client12=%s", client12.id)
client21 := newTestClient(t, coord2, agent1.id)
defer client21.close()
t.Logf("client21=%s", client21.id)
client22 := newTestClient(t, coord2, agent2.id)
defer client22.close()
t.Logf("client22=%s", client22.id)
t.Logf("client11 -> Node 11")
client11.sendNode(&agpl.Node{PreferredDERP: 11})
assertEventuallyHasDERPs(ctx, t, agent1, 11)
t.Logf("client21 -> Node 21")
client21.sendNode(&agpl.Node{PreferredDERP: 21})
assertEventuallyHasDERPs(ctx, t, agent1, 21)
t.Logf("client22 -> Node 22")
client22.sendNode(&agpl.Node{PreferredDERP: 22})
assertEventuallyHasDERPs(ctx, t, agent2, 22)
t.Logf("agent2 -> Node 2")
agent2.sendNode(&agpl.Node{PreferredDERP: 2})
assertEventuallyHasDERPs(ctx, t, client22, 2)
assertEventuallyHasDERPs(ctx, t, client12, 2)
t.Logf("client12 -> Node 12")
client12.sendNode(&agpl.Node{PreferredDERP: 12})
assertEventuallyHasDERPs(ctx, t, agent2, 12)
t.Logf("agent1 -> Node 1")
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertEventuallyHasDERPs(ctx, t, client21, 1)
assertEventuallyHasDERPs(ctx, t, client11, 1)
t.Logf("close coord2")
err = coord2.Close()
require.NoError(t, err)
// this closes agent2, client22, client21
err = agent2.recvErr(ctx, t)
require.ErrorIs(t, err, io.EOF)
err = client22.recvErr(ctx, t)
require.ErrorIs(t, err, io.EOF)
err = client21.recvErr(ctx, t)
require.ErrorIs(t, err, io.EOF)
assertEventuallyNoAgents(ctx, t, store, agent2.id)
t.Logf("close coord1")
err = coord1.Close()
require.NoError(t, err)
// this closes agent1, client12, client11
err = agent1.recvErr(ctx, t)
require.ErrorIs(t, err, io.EOF)
err = client12.recvErr(ctx, t)
require.ErrorIs(t, err, io.EOF)
err = client11.recvErr(ctx, t)
require.ErrorIs(t, err, io.EOF)
// wait for all connections to close
err = agent1.close()
require.NoError(t, err)
agent1.waitForClose(ctx, t)
err = agent2.close()
require.NoError(t, err)
agent2.waitForClose(ctx, t)
err = client11.close()
require.NoError(t, err)
client11.waitForClose(ctx, t)
err = client12.close()
require.NoError(t, err)
client12.waitForClose(ctx, t)
err = client21.close()
require.NoError(t, err)
client21.waitForClose(ctx, t)
err = client22.close()
require.NoError(t, err)
client22.waitForClose(ctx, t)
assertEventuallyNoAgents(ctx, t, store, agent1.id)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id)
}
// TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators.
// We use two agent connections, but they share the same AgentID. This could happen due to a reconnection,
// or an infrastructure problem where an old workspace is not fully cleaned up before a new one started.
//
// +---------+
// agent1 ---> | coord1 |
// +---------+
// +---------+
// agent2 ---> | coord2 |
// +---------+
// +---------+
// | coord3 | <--- client
// +---------+
func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
require.NoError(t, err)
defer coord2.Close()
coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store)
require.NoError(t, err)
defer coord3.Close()
agent1 := newTestAgent(t, coord1, "agent1")
defer agent1.close()
agent2 := newTestAgent(t, coord2, "agent2", agent1.id)
defer agent2.close()
client := newTestClient(t, coord3, agent1.id)
defer client.close()
client.sendNode(&agpl.Node{PreferredDERP: 3})
assertEventuallyHasDERPs(ctx, t, agent1, 3)
assertEventuallyHasDERPs(ctx, t, agent2, 3)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertEventuallyHasDERPs(ctx, t, client, 1)
// agent2's update overrides agent1 because it is newer
agent2.sendNode(&agpl.Node{PreferredDERP: 2})
assertEventuallyHasDERPs(ctx, t, client, 2)
// agent2 disconnects, and we should revert back to agent1
err = agent2.close()
require.NoError(t, err)
err = agent2.recvErr(ctx, t)
require.ErrorIs(t, err, io.ErrClosedPipe)
agent2.waitForClose(ctx, t)
assertEventuallyHasDERPs(ctx, t, client, 1)
agent1.sendNode(&agpl.Node{PreferredDERP: 11})
assertEventuallyHasDERPs(ctx, t, client, 11)
client.sendNode(&agpl.Node{PreferredDERP: 31})
assertEventuallyHasDERPs(ctx, t, agent1, 31)
err = agent1.close()
require.NoError(t, err)
err = agent1.recvErr(ctx, t)
require.ErrorIs(t, err, io.ErrClosedPipe)
agent1.waitForClose(ctx, t)
err = client.close()
require.NoError(t, err)
err = client.recvErr(ctx, t)
require.ErrorIs(t, err, io.ErrClosedPipe)
client.waitForClose(ctx, t)
assertEventuallyLost(ctx, t, store, client.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
func TestPGCoordinator_Unhealthy(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
ctrl := gomock.NewController(t)
mStore := dbmock.NewMockStore(ctrl)
ps := pubsub.NewInMemory()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
calls := make(chan struct{})
threeMissed := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
Times(3).
Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
Return(database.TailnetCoordinator{}, xerrors.New("test disconnect"))
mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
MinTimes(1).
After(threeMissed).
Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
Return(database.TailnetCoordinator{}, nil)
// extra calls we don't particularly care about for this test
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().GetTailnetTunnelPeerIDs(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
mStore.EXPECT().GetTailnetTunnelPeerBindings(gomock.Any(), gomock.Any()).
AnyTimes().Return(nil, nil)
mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()).
AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil)
mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
require.NoError(t, err)
defer func() {
err := uut.Close()
require.NoError(t, err)
}()
agent1 := newTestAgent(t, uut, "agent1")
defer agent1.close()
for i := 0; i < 3; i++ {
select {
case <-ctx.Done():
t.Fatal("timeout")
case calls <- struct{}{}:
// OK
}
}
// connected agent should be disconnected
agent1.waitForClose(ctx, t)
// new agent should immediately disconnect
agent2 := newTestAgent(t, uut, "agent2")
defer agent2.close()
agent2.waitForClose(ctx, t)
// next heartbeats succeed, so we are healthy
for i := 0; i < 2; i++ {
select {
case <-ctx.Done():
t.Fatal("timeout")
case calls <- struct{}{}:
// OK
}
}
agent3 := newTestAgent(t, uut, "agent3")
defer agent3.close()
select {
case <-agent3.closeChan:
t.Fatal("agent conn closed after we are healthy")
case <-time.After(time.Second):
// OK
}
}
// TestPGCoordinator_BidirectionalTunnels tests when peers create tunnels to each other. We don't
// do this now, but it's schematically possible, so we should make sure it doesn't break anything.
func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
p1 := newTestPeer(ctx, t, coordinator, "p1")
defer p1.close(ctx)
p2 := newTestPeer(ctx, t, coordinator, "p2")
defer p2.close(ctx)
p1.addTunnel(p2.id)
p2.addTunnel(p1.id)
p1.updateDERP(1)
p2.updateDERP(2)
p1.assertEventuallyHasDERP(p2.id, 2)
p2.assertEventuallyHasDERP(p1.id, 1)
}
func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
p1 := newTestPeer(ctx, t, coordinator, "p1")
defer p1.close(ctx)
p2 := newTestPeer(ctx, t, coordinator, "p2")
defer p2.close(ctx)
p1.addTunnel(p2.id)
p1.updateDERP(1)
p2.updateDERP(2)
p1.assertEventuallyHasDERP(p2.id, 2)
p2.assertEventuallyHasDERP(p1.id, 1)
p2.disconnect()
p1.assertEventuallyDisconnected(p2.id)
p2.assertEventuallyResponsesClosed()
}
func TestPGCoordinator_Lost(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
p1 := newTestPeer(ctx, t, coordinator, "p1")
defer p1.close(ctx)
p2 := newTestPeer(ctx, t, coordinator, "p2")
defer p2.close(ctx)
p1.addTunnel(p2.id)
p1.updateDERP(1)
p2.updateDERP(2)
p1.assertEventuallyHasDERP(p2.id, 2)
p2.assertEventuallyHasDERP(p1.id, 1)
p2.close(ctx)
p1.assertEventuallyLost(p2.id)
}
type testConn struct {
ws, serverWS net.Conn
nodeChan chan []*agpl.Node
sendNode func(node *agpl.Node)
errChan <-chan error
id uuid.UUID
closeChan chan struct{}
}
func newTestConn(ids []uuid.UUID) *testConn {
a := &testConn{}
a.ws, a.serverWS = net.Pipe()
a.nodeChan = make(chan []*agpl.Node)
a.sendNode, a.errChan = agpl.ServeCoordinator(a.ws, func(nodes []*agpl.Node) error {
a.nodeChan <- nodes
return nil
})
if len(ids) > 1 {
panic("too many")
}
if len(ids) == 1 {
a.id = ids[0]
} else {
a.id = uuid.New()
}
a.closeChan = make(chan struct{})
return a
}
func newTestAgent(t *testing.T, coord agpl.Coordinator, name string, id ...uuid.UUID) *testConn {
a := newTestConn(id)
go func() {
err := coord.ServeAgent(a.serverWS, a.id, name)
assert.NoError(t, err)
close(a.closeChan)
}()
return a
}
func (c *testConn) close() error {
return c.ws.Close()
}
func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node {
t.Helper()
select {
case <-ctx.Done():
t.Fatalf("testConn id %s: timeout receiving nodes ", c.id)
return nil
case nodes := <-c.nodeChan:
return nodes
}
}
func (c *testConn) recvErr(ctx context.Context, t *testing.T) error {
t.Helper()
// pgCoord works on eventual consistency, so it sometimes sends extra node
// updates, and these block errors if not read from the nodes channel.
for {
select {
case nodes := <-c.nodeChan:
t.Logf("ignoring nodes update while waiting for error; id=%s, nodes=%+v",
c.id.String(), nodes)
continue
case <-ctx.Done():
t.Fatal("timeout receiving error")
return ctx.Err()
case err := <-c.errChan:
return err
}
}
}
func (c *testConn) waitForClose(ctx context.Context, t *testing.T) {
t.Helper()
select {
case <-ctx.Done():
t.Fatal("timeout waiting for connection to close")
return
case <-c.closeChan:
return
}
}
func newTestClient(t *testing.T, coord agpl.Coordinator, agentID uuid.UUID, id ...uuid.UUID) *testConn {
c := newTestConn(id)
go func() {
err := coord.ServeClient(c.serverWS, c.id, agentID)
assert.NoError(t, err)
close(c.closeChan)
}()
return c
}
func assertEventuallyHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) {
t.Helper()
for {
nodes := c.recvNodes(ctx, t)
if len(nodes) != len(expected) {
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) {
t.Helper()
for {
select {
case <-ctx.Done():
return
case nodes := <-c.nodeChan:
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if slices.Contains(derps, e) {
t.Fatalf("expected not to get DERP %d, but received it", e)
return
}
}
}
}
}
func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
t.Helper()
for {
nodes, ok := ma.NextUpdate(ctx)
require.True(t, ok)
if len(nodes) != len(expected) {
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
t.Helper()
for {
nodes, ok := ma.NextUpdate(ctx)
if !ok {
return
}
if len(nodes) != len(expected) {
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper()
assert.Eventually(t, func() bool {
agents, err := store.GetTailnetPeers(ctx, agentID)
if xerrors.Is(err, sql.ErrNoRows) {
return true
}
if err != nil {
t.Fatal(err)
}
return len(agents) == 0
}, testutil.WaitShort, testutil.IntervalFast)
}
func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper()
assert.Eventually(t, func() bool {
peers, err := store.GetTailnetPeers(ctx, agentID)
if xerrors.Is(err, sql.ErrNoRows) {
return false
}
if err != nil {
t.Fatal(err)
}
for _, peer := range peers {
if peer.Status == database.TailnetStatusOk {
return false
}
}
return true
}, testutil.WaitShort, testutil.IntervalFast)
}
func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper()
assert.Eventually(t, func() bool {
clients, err := store.GetTailnetTunnelPeerIDs(ctx, agentID)
if xerrors.Is(err, sql.ErrNoRows) {
return true
}
if err != nil {
t.Fatal(err)
}
return len(clients) == 0
}, testutil.WaitShort, testutil.IntervalFast)
}
type peerStatus struct {
preferredDERP int32
status proto.CoordinateResponse_PeerUpdate_Kind
}
type testPeer struct {
ctx context.Context
cancel context.CancelFunc
t testing.TB
id uuid.UUID
name string
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
peers map[uuid.UUID]peerStatus
}
func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer {
p := &testPeer{t: t, name: name, peers: make(map[uuid.UUID]peerStatus)}
p.ctx, p.cancel = context.WithCancel(ctx)
if len(id) > 1 {
t.Fatal("too many")
}
if len(id) == 1 {
p.id = id[0]
} else {
p.id = uuid.New()
}
// SingleTailnetTunnelAuth allows connections to arbitrary peers
p.reqs, p.resps = coord.Coordinate(p.ctx, p.id, name, agpl.SingleTailnetTunnelAuth{})
return p
}
func (p *testPeer) addTunnel(other uuid.UUID) {
p.t.Helper()
req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(other)}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout adding tunnel for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *testPeer) updateDERP(derp int32) {
p.t.Helper()
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout updating node for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *testPeer) disconnect() {
p.t.Helper()
req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout updating node for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) {
p.t.Helper()
for {
o, ok := p.peers[other]
if ok && o.preferredDERP == derp {
return
}
if err := p.handleOneResp(); err != nil {
assert.NoError(p.t, err)
return
}
}
}
func (p *testPeer) assertEventuallyDisconnected(other uuid.UUID) {
p.t.Helper()
for {
_, ok := p.peers[other]
if !ok {
return
}
if err := p.handleOneResp(); err != nil {
assert.NoError(p.t, err)
return
}
}
}
func (p *testPeer) assertEventuallyLost(other uuid.UUID) {
p.t.Helper()
for {
o := p.peers[other]
if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
return
}
if err := p.handleOneResp(); err != nil {
assert.NoError(p.t, err)
return
}
}
}
func (p *testPeer) assertEventuallyResponsesClosed() {
p.t.Helper()
for {
err := p.handleOneResp()
if xerrors.Is(err, responsesClosed) {
return
}
if !assert.NoError(p.t, err) {
return
}
}
}
var responsesClosed = xerrors.New("responses closed")
func (p *testPeer) handleOneResp() error {
select {
case <-p.ctx.Done():
return p.ctx.Err()
case resp, ok := <-p.resps:
if !ok {
return responsesClosed
}
for _, update := range resp.PeerUpdates {
id, err := uuid.FromBytes(update.Uuid)
if err != nil {
return err
}
switch update.Kind {
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
p.peers[id] = peerStatus{
preferredDERP: update.GetNode().GetPreferredDerp(),
status: update.Kind,
}
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
delete(p.peers, id)
default:
return xerrors.Errorf("unhandled update kind %s", update.Kind)
}
}
}
return nil
}
func (p *testPeer) close(ctx context.Context) {
p.t.Helper()
p.cancel()
for {
select {
case <-ctx.Done():
p.t.Errorf("timeout waiting for responses to close for %s", p.name)
return
case _, ok := <-p.resps:
if ok {
continue
}
return
}
}
}
type fakeCoordinator struct {
ctx context.Context
t *testing.T
store database.Store
id uuid.UUID
}
func (c *fakeCoordinator) heartbeat() {
c.t.Helper()
_, err := c.store.UpsertTailnetCoordinator(c.ctx, c.id)
require.NoError(c.t, err)
}
func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) {
c.t.Helper()
pNode, err := agpl.NodeToProto(node)
require.NoError(c.t, err)
nodeRaw, err := gProto.Marshal(pNode)
require.NoError(c.t, err)
_, err = c.store.UpsertTailnetPeer(c.ctx, database.UpsertTailnetPeerParams{
ID: agentID,
CoordinatorID: c.id,
Node: nodeRaw,
Status: database.TailnetStatusOk,
})
require.NoError(c.t, err)
}