mirror of
https://github.com/coder/coder.git
synced 2026-06-03 21:18:24 +00:00
bddb808b25
Fixes all our Go file imports to match the preferred spec that we've _mostly_ been using. For example: ``` import ( "context" "time" "github.com/prometheus/client_golang/prometheus" "golang.org/x/xerrors" "gopkg.in/natefinch/lumberjack.v2" "cdr.dev/slog/v3" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/serpent" ) ``` 3 groups: standard library, 3rd-party libs, Coder libs. This PR makes the change across the codebase. The PR in the stack above modifies our formatting to maintain this state of affairs, and is a separate PR so it's possible to review that one in detail.
975 lines
31 KiB
Go
975 lines
31 KiB
Go
package tailnet_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"net/netip"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
"go.uber.org/mock/gomock"
|
|
"golang.org/x/xerrors"
|
|
gProto "google.golang.org/protobuf/proto"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbmock"
|
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
|
"github.com/coder/coder/v2/enterprise/tailnet"
|
|
agpl "github.com/coder/coder/v2/tailnet"
|
|
"github.com/coder/coder/v2/tailnet/proto"
|
|
agpltest "github.com/coder/coder/v2/tailnet/test"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
|
|
}
|
|
|
|
func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
|
|
agentID := uuid.New()
|
|
client := agpltest.NewClient(ctx, t, coordinator, "client", agentID)
|
|
defer client.Close(ctx)
|
|
client.UpdateDERP(10)
|
|
require.Eventually(t, func() bool {
|
|
clients, err := store.GetTailnetTunnelPeerBindings(ctx, agentID)
|
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
|
t.Fatalf("database error: %v", err)
|
|
}
|
|
if len(clients) == 0 {
|
|
return false
|
|
}
|
|
node := new(proto.Node)
|
|
err = gProto.Unmarshal(clients[0].Node, node)
|
|
assert.NoError(t, err)
|
|
assert.EqualValues(t, 10, node.PreferredDerp)
|
|
return true
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
client.UngracefulDisconnect(ctx)
|
|
assertEventuallyLost(ctx, t, store, client.ID)
|
|
}
|
|
|
|
func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
|
|
agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
|
|
defer agent.Close(ctx)
|
|
agent.UpdateDERP(10)
|
|
require.Eventually(t, func() bool {
|
|
agents, err := store.GetTailnetPeers(ctx, agent.ID)
|
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
|
t.Fatalf("database error: %v", err)
|
|
}
|
|
if len(agents) == 0 {
|
|
return false
|
|
}
|
|
node := new(proto.Node)
|
|
err = gProto.Unmarshal(agents[0].Node, node)
|
|
assert.NoError(t, err)
|
|
assert.EqualValues(t, 10, node.PreferredDerp)
|
|
return true
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
agent.UngracefulDisconnect(ctx)
|
|
assertEventuallyLost(ctx, t, store, agent.ID)
|
|
}
|
|
|
|
func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
|
|
agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
|
|
defer agent.Close(ctx)
|
|
prefix := agpl.TailscaleServicePrefix.RandomPrefix()
|
|
agent.UpdateNode(&proto.Node{
|
|
Addresses: []string{prefix.String()},
|
|
PreferredDerp: 10,
|
|
})
|
|
|
|
// The agent connection should be closed immediately after sending an invalid addr
|
|
agent.AssertEventuallyResponsesClosed(
|
|
agpl.AuthorizationError{Wrapped: agpl.InvalidNodeAddressError{Addr: prefix.Addr().String()}}.Error())
|
|
assertEventuallyLost(ctx, t, store, agent.ID)
|
|
}
|
|
|
|
func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
|
|
agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
|
|
defer agent.Close(ctx)
|
|
agent.UpdateNode(&proto.Node{
|
|
Addresses: []string{
|
|
netip.PrefixFrom(agpl.TailscaleServicePrefix.AddrFromUUID(agent.ID), 64).String(),
|
|
},
|
|
PreferredDerp: 10,
|
|
})
|
|
|
|
// The agent connection should be closed immediately after sending an invalid addr
|
|
agent.AssertEventuallyResponsesClosed(
|
|
agpl.AuthorizationError{Wrapped: agpl.InvalidAddressBitsError{Bits: 64}}.Error())
|
|
assertEventuallyLost(ctx, t, store, agent.ID)
|
|
}
|
|
|
|
func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
|
|
agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
|
|
defer agent.Close(ctx)
|
|
agent.UpdateNode(&proto.Node{
|
|
Addresses: []string{
|
|
agpl.TailscaleServicePrefix.PrefixFromUUID(agent.ID).String(),
|
|
},
|
|
PreferredDerp: 10,
|
|
})
|
|
require.Eventually(t, func() bool {
|
|
agents, err := store.GetTailnetPeers(ctx, agent.ID)
|
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
|
t.Fatalf("database error: %v", err)
|
|
}
|
|
if len(agents) == 0 {
|
|
return false
|
|
}
|
|
node := new(proto.Node)
|
|
err = gProto.Unmarshal(agents[0].Node, node)
|
|
assert.NoError(t, err)
|
|
assert.EqualValues(t, 10, node.PreferredDerp)
|
|
return true
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
agent.UngracefulDisconnect(ctx)
|
|
assertEventuallyLost(ctx, t, store, agent.ID)
|
|
}
|
|
|
|
func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
|
|
agent := agpltest.NewAgent(ctx, t, coordinator, "original")
|
|
defer agent.Close(ctx)
|
|
agent.UpdateDERP(10)
|
|
|
|
client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
|
|
defer client.Close(ctx)
|
|
|
|
client.AssertEventuallyHasDERP(agent.ID, 10)
|
|
client.UpdateDERP(11)
|
|
agent.AssertEventuallyHasDERP(client.ID, 11)
|
|
|
|
// Ensure an update to the agent node reaches the connIO!
|
|
agent.UpdateDERP(12)
|
|
client.AssertEventuallyHasDERP(agent.ID, 12)
|
|
|
|
// Close the agent channel so a new one can connect.
|
|
agent.Close(ctx)
|
|
|
|
// Create a new agent connection. This is to simulate a reconnect!
|
|
agent = agpltest.NewPeer(ctx, t, coordinator, "reconnection", agpltest.WithID(agent.ID))
|
|
// Ensure the coordinator sends its client node immediately!
|
|
agent.AssertEventuallyHasDERP(client.ID, 11)
|
|
|
|
// Send a bunch of updates in rapid succession, and test that we eventually get the latest. We don't want the
|
|
// coordinator accidentally reordering things.
|
|
for d := int32(13); d < 36; d++ {
|
|
agent.UpdateDERP(d)
|
|
}
|
|
client.AssertEventuallyHasDERP(agent.ID, 35)
|
|
|
|
agent.UngracefulDisconnect(ctx)
|
|
client.UngracefulDisconnect(ctx)
|
|
assertEventuallyLost(ctx, t, store, agent.ID)
|
|
assertEventuallyLost(ctx, t, store, client.ID)
|
|
}
|
|
|
|
func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
mClock := quartz.NewMock(t)
|
|
afTrap := mClock.Trap().AfterFunc("heartbeats", "recvBeat")
|
|
defer afTrap.Close()
|
|
rstTrap := mClock.Trap().TimerReset("heartbeats", "resetExpiryTimerWithLock")
|
|
defer rstTrap.Close()
|
|
|
|
coordinator, err := tailnet.NewTestPGCoord(ctx, logger, ps, store, mClock)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
|
|
agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
|
|
defer agent.Close(ctx)
|
|
agent.UpdateDERP(10)
|
|
|
|
client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
|
|
defer client.Close(ctx)
|
|
|
|
client.AssertEventuallyHasDERP(agent.ID, 10)
|
|
client.UpdateDERP(11)
|
|
agent.AssertEventuallyHasDERP(client.ID, 11)
|
|
|
|
// simulate a second coordinator via DB calls only --- our goal is to test broken heart-beating, so we can't use a
|
|
// real coordinator
|
|
fCoord2 := &fakeCoordinator{
|
|
ctx: ctx,
|
|
t: t,
|
|
store: store,
|
|
id: uuid.New(),
|
|
}
|
|
|
|
fCoord2.heartbeat()
|
|
afTrap.MustWait(ctx).MustRelease(ctx) // heartbeat timeout started
|
|
|
|
fCoord2.agentNode(agent.ID, &agpl.Node{PreferredDERP: 12})
|
|
client.AssertEventuallyHasDERP(agent.ID, 12)
|
|
|
|
fCoord3 := &fakeCoordinator{
|
|
ctx: ctx,
|
|
t: t,
|
|
store: store,
|
|
id: uuid.New(),
|
|
}
|
|
fCoord3.heartbeat()
|
|
rstTrap.MustWait(ctx).MustRelease(ctx) // timeout gets reset
|
|
fCoord3.agentNode(agent.ID, &agpl.Node{PreferredDERP: 13})
|
|
client.AssertEventuallyHasDERP(agent.ID, 13)
|
|
|
|
// fCoord2 sends in a second heartbeat, one period later (on time)
|
|
mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx)
|
|
fCoord2.heartbeat()
|
|
rstTrap.MustWait(ctx).MustRelease(ctx) // timeout gets reset
|
|
|
|
// when the fCoord3 misses enough heartbeats, the real coordinator should send an update with the
|
|
// node from fCoord2 for the agent.
|
|
mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx)
|
|
w := mClock.Advance(tailnet.HeartbeatPeriod)
|
|
rstTrap.MustWait(ctx).MustRelease(ctx)
|
|
w.MustWait(ctx)
|
|
client.AssertEventuallyHasDERP(agent.ID, 12)
|
|
|
|
// one more heartbeat period will result in fCoord2 being expired, which should cause us to
|
|
// revert to the original agent mapping
|
|
mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx)
|
|
// note that the timeout doesn't get reset because both fCoord2 and fCoord3 are expired
|
|
client.AssertEventuallyHasDERP(agent.ID, 10)
|
|
|
|
// send fCoord3 heartbeat, which should trigger us to consider that mapping valid again.
|
|
fCoord3.heartbeat()
|
|
rstTrap.MustWait(ctx).MustRelease(ctx) // timeout gets reset
|
|
client.AssertEventuallyHasDERP(agent.ID, 13)
|
|
|
|
agent.UngracefulDisconnect(ctx)
|
|
client.UngracefulDisconnect(ctx)
|
|
assertEventuallyLost(ctx, t, store, client.ID)
|
|
}
|
|
|
|
func TestPGCoordinatorSingle_MissedHeartbeats_NoDrop(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
|
|
agentID := uuid.New()
|
|
|
|
client := agpltest.NewPeer(ctx, t, coordinator, "client")
|
|
defer client.Close(ctx)
|
|
client.AddTunnel(agentID)
|
|
|
|
client.UpdateDERP(11)
|
|
|
|
// simulate a second coordinator via DB calls only --- our goal is to test
|
|
// broken heart-beating, so we can't use a real coordinator
|
|
fCoord2 := &fakeCoordinator{
|
|
ctx: ctx,
|
|
t: t,
|
|
store: store,
|
|
id: uuid.New(),
|
|
}
|
|
// simulate a single heartbeat, the coordinator is healthy
|
|
fCoord2.heartbeat()
|
|
|
|
fCoord2.agentNode(agentID, &agpl.Node{PreferredDERP: 12})
|
|
// since it's healthy the client should get the new node.
|
|
client.AssertEventuallyHasDERP(agentID, 12)
|
|
|
|
// the heartbeat should then timeout and we'll get sent a LOST update, NOT a
|
|
// disconnect.
|
|
client.AssertEventuallyLost(agentID)
|
|
|
|
client.UngracefulDisconnect(ctx)
|
|
|
|
assertEventuallyLost(ctx, t, store, client.ID)
|
|
}
|
|
|
|
func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
|
|
mu := sync.Mutex{}
|
|
heartbeats := []time.Time{}
|
|
unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, _ []byte, err error) {
|
|
assert.NoError(t, err)
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
heartbeats = append(heartbeats, time.Now())
|
|
})
|
|
require.NoError(t, err)
|
|
defer unsub()
|
|
|
|
start := time.Now()
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
|
|
require.Eventually(t, func() bool {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
if len(heartbeats) < 2 {
|
|
return false
|
|
}
|
|
assert.Greater(t, heartbeats[0].Sub(start), time.Duration(0))
|
|
assert.Greater(t, heartbeats[1].Sub(start), time.Duration(0))
|
|
return assert.Greater(t, heartbeats[1].Sub(heartbeats[0]), tailnet.HeartbeatPeriod*3/4)
|
|
}, testutil.WaitMedium, testutil.IntervalMedium)
|
|
}
|
|
|
|
// TestPGCoordinatorDual_Mainline tests with 2 coordinators, one agent connected to each, and 2 clients per agent.
|
|
//
|
|
// +---------+
|
|
// agent1 ---> | coord1 | <--- client11 (coord 1, agent 1)
|
|
// | |
|
|
// | | <--- client12 (coord 1, agent 2)
|
|
// +---------+
|
|
// +---------+
|
|
// agent2 ---> | coord2 | <--- client21 (coord 2, agent 1)
|
|
// | |
|
|
// | | <--- client22 (coord2, agent 2)
|
|
// +---------+
|
|
func TestPGCoordinatorDual_Mainline(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
|
require.NoError(t, err)
|
|
defer coord1.Close()
|
|
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
|
|
require.NoError(t, err)
|
|
defer coord2.Close()
|
|
|
|
agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
|
|
defer agent1.Close(ctx)
|
|
t.Logf("agent1=%s", agent1.ID)
|
|
agent2 := agpltest.NewAgent(ctx, t, coord2, "agent2")
|
|
defer agent2.Close(ctx)
|
|
t.Logf("agent2=%s", agent2.ID)
|
|
|
|
client11 := agpltest.NewClient(ctx, t, coord1, "client11", agent1.ID)
|
|
defer client11.Close(ctx)
|
|
t.Logf("client11=%s", client11.ID)
|
|
client12 := agpltest.NewClient(ctx, t, coord1, "client12", agent2.ID)
|
|
defer client12.Close(ctx)
|
|
t.Logf("client12=%s", client12.ID)
|
|
client21 := agpltest.NewClient(ctx, t, coord2, "client21", agent1.ID)
|
|
defer client21.Close(ctx)
|
|
t.Logf("client21=%s", client21.ID)
|
|
client22 := agpltest.NewClient(ctx, t, coord2, "client22", agent2.ID)
|
|
defer client22.Close(ctx)
|
|
t.Logf("client22=%s", client22.ID)
|
|
|
|
t.Log("client11 -> Node 11")
|
|
client11.UpdateDERP(11)
|
|
agent1.AssertEventuallyHasDERP(client11.ID, 11)
|
|
|
|
t.Log("client21 -> Node 21")
|
|
client21.UpdateDERP(21)
|
|
agent1.AssertEventuallyHasDERP(client21.ID, 21)
|
|
|
|
t.Log("client22 -> Node 22")
|
|
client22.UpdateDERP(22)
|
|
agent2.AssertEventuallyHasDERP(client22.ID, 22)
|
|
|
|
t.Log("agent2 -> Node 2")
|
|
agent2.UpdateDERP(2)
|
|
client22.AssertEventuallyHasDERP(agent2.ID, 2)
|
|
client12.AssertEventuallyHasDERP(agent2.ID, 2)
|
|
|
|
t.Log("client12 -> Node 12")
|
|
client12.UpdateDERP(12)
|
|
agent2.AssertEventuallyHasDERP(client12.ID, 12)
|
|
|
|
t.Log("agent1 -> Node 1")
|
|
agent1.UpdateDERP(1)
|
|
client21.AssertEventuallyHasDERP(agent1.ID, 1)
|
|
client11.AssertEventuallyHasDERP(agent1.ID, 1)
|
|
|
|
t.Log("close coord2")
|
|
err = coord2.Close()
|
|
require.NoError(t, err)
|
|
|
|
// this closes agent2, client22, client21
|
|
agent2.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
|
|
client22.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
|
|
client21.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
|
|
assertEventuallyLost(ctx, t, store, agent2.ID)
|
|
assertEventuallyLost(ctx, t, store, client21.ID)
|
|
assertEventuallyLost(ctx, t, store, client22.ID)
|
|
|
|
err = coord1.Close()
|
|
require.NoError(t, err)
|
|
// this closes agent1, client12, client11
|
|
agent1.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
|
|
client12.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
|
|
client11.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
|
|
assertEventuallyLost(ctx, t, store, agent1.ID)
|
|
assertEventuallyLost(ctx, t, store, client11.ID)
|
|
assertEventuallyLost(ctx, t, store, client12.ID)
|
|
}
|
|
|
|
// TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators.
|
|
// We use two agent connections, but they share the same AgentID. This could happen due to a reconnection,
|
|
// or an infrastructure problem where an old workspace is not fully cleaned up before a new one started.
|
|
//
|
|
// +---------+
|
|
// agent1 ---> | coord1 |
|
|
// +---------+
|
|
// +---------+
|
|
// agent2 ---> | coord2 |
|
|
// +---------+
|
|
// +---------+
|
|
// | coord3 | <--- client
|
|
// +---------+
|
|
func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
|
require.NoError(t, err)
|
|
defer coord1.Close()
|
|
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
|
|
require.NoError(t, err)
|
|
defer coord2.Close()
|
|
coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store)
|
|
require.NoError(t, err)
|
|
defer coord3.Close()
|
|
|
|
agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
|
|
defer agent1.Close(ctx)
|
|
agent2 := agpltest.NewPeer(ctx, t, coord2, "agent2",
|
|
agpltest.WithID(agent1.ID), agpltest.WithAuth(agpl.AgentCoordinateeAuth{ID: agent1.ID}),
|
|
)
|
|
defer agent2.Close(ctx)
|
|
|
|
client := agpltest.NewClient(ctx, t, coord3, "client", agent1.ID)
|
|
defer client.Close(ctx)
|
|
|
|
client.UpdateDERP(3)
|
|
agent1.AssertEventuallyHasDERP(client.ID, 3)
|
|
agent2.AssertEventuallyHasDERP(client.ID, 3)
|
|
|
|
agent1.UpdateDERP(1)
|
|
client.AssertEventuallyHasDERP(agent1.ID, 1)
|
|
|
|
// agent2's update overrides agent1 because it is newer
|
|
agent2.UpdateDERP(2)
|
|
client.AssertEventuallyHasDERP(agent1.ID, 2)
|
|
|
|
// agent2 disconnects, and we should revert back to agent1
|
|
agent2.Close(ctx)
|
|
client.AssertEventuallyHasDERP(agent1.ID, 1)
|
|
|
|
agent1.UpdateDERP(11)
|
|
client.AssertEventuallyHasDERP(agent1.ID, 11)
|
|
|
|
client.UpdateDERP(31)
|
|
agent1.AssertEventuallyHasDERP(client.ID, 31)
|
|
|
|
agent1.UngracefulDisconnect(ctx)
|
|
client.UngracefulDisconnect(ctx)
|
|
|
|
assertEventuallyLost(ctx, t, store, client.ID)
|
|
assertEventuallyLost(ctx, t, store, agent1.ID)
|
|
}
|
|
|
|
func TestPGCoordinator_Unhealthy(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
ctrl := gomock.NewController(t)
|
|
mStore := dbmock.NewMockStore(ctrl)
|
|
ps := pubsub.NewInMemory()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
|
|
|
calls := make(chan struct{})
|
|
// first call succeeds, so that our Agent will successfully connect.
|
|
firstSucceeds := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
|
|
Times(1).
|
|
Return(database.TailnetCoordinator{}, nil)
|
|
// next 3 fail, so the Coordinator becomes unhealthy, and we test that it disconnects the agent
|
|
threeMissed := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
|
|
After(firstSucceeds).
|
|
Times(3).
|
|
Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
|
|
Return(database.TailnetCoordinator{}, xerrors.New("test disconnect"))
|
|
mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
|
|
MinTimes(1).
|
|
After(threeMissed).
|
|
Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
|
|
Return(database.TailnetCoordinator{}, nil)
|
|
// extra calls we don't particularly care about for this test
|
|
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
|
|
mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
|
|
mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
|
|
mStore.EXPECT().GetTailnetTunnelPeerIDs(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
|
|
mStore.EXPECT().GetTailnetTunnelPeerBindings(gomock.Any(), gomock.Any()).
|
|
AnyTimes().Return(nil, nil)
|
|
mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()).
|
|
AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil)
|
|
mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
|
|
mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any())
|
|
|
|
uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
err := uut.Close()
|
|
require.NoError(t, err)
|
|
}()
|
|
agent1 := agpltest.NewAgent(ctx, t, uut, "agent1")
|
|
defer agent1.Close(ctx)
|
|
for i := 0; i < 3; i++ {
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Fatalf("timeout waiting for call %d", i+1)
|
|
case calls <- struct{}{}:
|
|
// OK
|
|
}
|
|
}
|
|
// connected agent should be disconnected
|
|
agent1.AssertEventuallyResponsesClosed(tailnet.CloseErrUnhealthy)
|
|
|
|
// new agent should immediately disconnect
|
|
agent2 := agpltest.NewAgent(ctx, t, uut, "agent2")
|
|
defer agent2.Close(ctx)
|
|
agent2.AssertEventuallyResponsesClosed(tailnet.CloseErrUnhealthy)
|
|
|
|
// next heartbeats succeed, so we are healthy
|
|
for i := 0; i < 2; i++ {
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Fatal("timeout")
|
|
case calls <- struct{}{}:
|
|
// OK
|
|
}
|
|
}
|
|
agent3 := agpltest.NewAgent(ctx, t, uut, "agent3")
|
|
defer agent3.Close(ctx)
|
|
agent3.AssertNotClosed(time.Second)
|
|
}
|
|
|
|
func TestPGCoordinator_Node_Empty(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
ctrl := gomock.NewController(t)
|
|
mStore := dbmock.NewMockStore(ctrl)
|
|
ps := pubsub.NewInMemory()
|
|
logger := testutil.Logger(t)
|
|
|
|
id := uuid.New()
|
|
mStore.EXPECT().GetTailnetPeers(gomock.Any(), id).Times(1).Return(nil, nil)
|
|
|
|
// extra calls we don't particularly care about for this test
|
|
mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
|
|
AnyTimes().
|
|
Return(database.TailnetCoordinator{}, nil)
|
|
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
|
|
mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
|
|
mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
|
|
mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any()).Times(1)
|
|
|
|
uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
err := uut.Close()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
node := uut.Node(id)
|
|
require.Nil(t, node)
|
|
}
|
|
|
|
// TestPGCoordinator_BidirectionalTunnels tests when peers create tunnels to each other. We don't
|
|
// do this now, but it's schematically possible, so we should make sure it doesn't break anything.
|
|
func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
agpltest.BidirectionalTunnels(ctx, t, coordinator)
|
|
}
|
|
|
|
func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
agpltest.GracefulDisconnectTest(ctx, t, coordinator)
|
|
}
|
|
|
|
func TestPGCoordinator_Lost(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
agpltest.LostTest(ctx, t, coordinator)
|
|
}
|
|
|
|
func TestPGCoordinator_NoDeleteOnClose(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator.Close()
|
|
|
|
agent := agpltest.NewAgent(ctx, t, coordinator, "original")
|
|
defer agent.Close(ctx)
|
|
agent.UpdateDERP(10)
|
|
|
|
client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
|
|
defer client.Close(ctx)
|
|
|
|
// Simulate some traffic to generate
|
|
// a peer.
|
|
client.AssertEventuallyHasDERP(agent.ID, 10)
|
|
client.UpdateDERP(11)
|
|
|
|
agent.AssertEventuallyHasDERP(client.ID, 11)
|
|
|
|
anode := coordinator.Node(agent.ID)
|
|
require.NotNil(t, anode)
|
|
cnode := coordinator.Node(client.ID)
|
|
require.NotNil(t, cnode)
|
|
|
|
err = coordinator.Close()
|
|
require.NoError(t, err)
|
|
assertEventuallyLost(ctx, t, store, agent.ID)
|
|
assertEventuallyLost(ctx, t, store, client.ID)
|
|
|
|
coordinator2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer coordinator2.Close()
|
|
|
|
anode = coordinator2.Node(agent.ID)
|
|
require.NotNil(t, anode)
|
|
assert.Equal(t, 10, anode.PreferredDERP)
|
|
|
|
cnode = coordinator2.Node(client.ID)
|
|
require.NotNil(t, cnode)
|
|
assert.Equal(t, 11, cnode.PreferredDERP)
|
|
}
|
|
|
|
// TestPGCoordinatorDual_FailedHeartbeat tests that peers
|
|
// disconnect from a coordinator when they are unhealthy,
|
|
// are marked as LOST (not DISCONNECTED), and can reconnect to
|
|
// a new coordinator and reestablish their tunnels.
|
|
func TestPGCoordinatorDual_FailedHeartbeat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dburl, err := dbtestutil.Open(t)
|
|
require.NoError(t, err)
|
|
|
|
store1, ps1, sdb1 := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithURL(dburl))
|
|
defer sdb1.Close()
|
|
store2, ps2, sdb2 := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithURL(dburl))
|
|
defer sdb2.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
t.Cleanup(cancel)
|
|
|
|
// We do this to avoid failing due errors related to the
|
|
// database connection being close.
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
|
|
|
// Create two coordinators, 1 for each peer.
|
|
c1, err := tailnet.NewPGCoord(ctx, logger, ps1, store1)
|
|
require.NoError(t, err)
|
|
c2, err := tailnet.NewPGCoord(ctx, logger, ps2, store2)
|
|
require.NoError(t, err)
|
|
|
|
p1 := agpltest.NewPeer(ctx, t, c1, "peer1")
|
|
p2 := agpltest.NewPeer(ctx, t, c2, "peer2")
|
|
|
|
// Create a binding between the two.
|
|
p1.AddTunnel(p2.ID)
|
|
|
|
// Ensure that messages pass through.
|
|
p1.UpdateDERP(1)
|
|
p2.UpdateDERP(2)
|
|
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
|
p2.AssertEventuallyHasDERP(p1.ID, 1)
|
|
|
|
// Close the underlying database connection to induce
|
|
// a heartbeat failure scenario and assert that
|
|
// we eventually disconnect from the coordinator.
|
|
err = sdb1.Close()
|
|
require.NoError(t, err)
|
|
p1.AssertEventuallyResponsesClosed(tailnet.CloseErrUnhealthy)
|
|
p2.AssertEventuallyLost(p1.ID)
|
|
// This basically checks that peer2 had no update
|
|
// performed on their status since we are connected
|
|
// to coordinator2.
|
|
assertEventuallyStatus(ctx, t, store2, p2.ID, database.TailnetStatusOk)
|
|
|
|
// Connect peer1 to coordinator2.
|
|
p1.ConnectToCoordinator(ctx, c2)
|
|
// Reestablish binding.
|
|
p1.AddTunnel(p2.ID)
|
|
// Ensure messages still flow back and forth.
|
|
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
|
p1.UpdateDERP(3)
|
|
p2.UpdateDERP(4)
|
|
p2.AssertEventuallyHasDERP(p1.ID, 3)
|
|
p1.AssertEventuallyHasDERP(p2.ID, 4)
|
|
// Make sure peer2 never got an update about peer1 disconnecting.
|
|
p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED)
|
|
}
|
|
|
|
func TestPGCoordinatorDual_PeerReconnect(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store, ps := dbtestutil.NewDB(t)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
defer cancel()
|
|
logger := testutil.Logger(t)
|
|
|
|
// Create two coordinators, 1 for each peer.
|
|
c1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
c2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
|
|
p1 := agpltest.NewPeer(ctx, t, c1, "peer1")
|
|
p2 := agpltest.NewPeer(ctx, t, c2, "peer2")
|
|
|
|
// Create a binding between the two.
|
|
p1.AddTunnel(p2.ID)
|
|
|
|
// Ensure that messages pass through.
|
|
p1.UpdateDERP(1)
|
|
p2.UpdateDERP(2)
|
|
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
|
p2.AssertEventuallyHasDERP(p1.ID, 1)
|
|
|
|
// Close coordinator1. Now we will check that we
|
|
// never send a DISCONNECTED update.
|
|
err = c1.Close()
|
|
require.NoError(t, err)
|
|
p1.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
|
|
p2.AssertEventuallyLost(p1.ID)
|
|
// This basically checks that peer2 had no update
|
|
// performed on their status since we are connected
|
|
// to coordinator2.
|
|
assertEventuallyStatus(ctx, t, store, p2.ID, database.TailnetStatusOk)
|
|
|
|
// Connect peer1 to coordinator2.
|
|
p1.ConnectToCoordinator(ctx, c2)
|
|
// Reestablish binding.
|
|
p1.AddTunnel(p2.ID)
|
|
// Ensure messages still flow back and forth.
|
|
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
|
p1.UpdateDERP(3)
|
|
p2.UpdateDERP(4)
|
|
p2.AssertEventuallyHasDERP(p1.ID, 3)
|
|
p1.AssertEventuallyHasDERP(p2.ID, 4)
|
|
// Make sure peer2 never got an update about peer1 disconnecting.
|
|
p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED)
|
|
}
|
|
|
|
// TestPGCoordinatorPropogatedPeerContext tests that the context for a specific peer
|
|
// is propogated through to the `Authorize` method of the coordinatee auth
|
|
func TestPGCoordinatorPropogatedPeerContext(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
store, ps := dbtestutil.NewDB(t)
|
|
logger := testutil.Logger(t)
|
|
|
|
peerCtx := context.WithValue(ctx, agpltest.FakeSubjectKey{}, struct{}{})
|
|
peerID := uuid.UUID{0x01}
|
|
agentID := uuid.UUID{0x02}
|
|
|
|
c1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
err := c1.Close()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
ch := make(chan struct{})
|
|
auth := agpltest.FakeCoordinateeAuth{
|
|
Chan: ch,
|
|
}
|
|
|
|
reqs, _ := c1.Coordinate(peerCtx, peerID, "peer1", auth)
|
|
|
|
testutil.RequireSend(ctx, t, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agpl.UUIDToByteSlice(agentID)}})
|
|
|
|
_ = testutil.TryReceive(ctx, t, ch)
|
|
}
|
|
|
|
func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) {
|
|
t.Helper()
|
|
assert.Eventually(t, func() bool {
|
|
peers, err := store.GetTailnetPeers(ctx, agentID)
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return false
|
|
}
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for _, peer := range peers {
|
|
if peer.Status != status {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
}
|
|
|
|
func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
|
|
t.Helper()
|
|
assertEventuallyStatus(ctx, t, store, agentID, database.TailnetStatusLost)
|
|
}
|
|
|
|
func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
|
|
t.Helper()
|
|
assert.Eventually(t, func() bool {
|
|
clients, err := store.GetTailnetTunnelPeerIDs(ctx, agentID)
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return true
|
|
}
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return len(clients) == 0
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
}
|
|
|
|
type fakeCoordinator struct {
|
|
ctx context.Context
|
|
t *testing.T
|
|
store database.Store
|
|
id uuid.UUID
|
|
}
|
|
|
|
func (c *fakeCoordinator) heartbeat() {
|
|
c.t.Helper()
|
|
_, err := c.store.UpsertTailnetCoordinator(c.ctx, c.id)
|
|
require.NoError(c.t, err)
|
|
}
|
|
|
|
func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) {
|
|
c.t.Helper()
|
|
pNode, err := agpl.NodeToProto(node)
|
|
require.NoError(c.t, err)
|
|
nodeRaw, err := gProto.Marshal(pNode)
|
|
require.NoError(c.t, err)
|
|
_, err = c.store.UpsertTailnetPeer(c.ctx, database.UpsertTailnetPeerParams{
|
|
ID: agentID,
|
|
CoordinatorID: c.id,
|
|
Node: nodeRaw,
|
|
Status: database.TailnetStatusOk,
|
|
})
|
|
require.NoError(c.t, err)
|
|
}
|