Files
coder/tailnet/coordinator_test.go
T
Spike Curtis d6154c4310 chore: remove tailnet v1 API support (#14641)
Drops support for v1 of the tailnet API, the original coordination protocol, in which we only sent node updates and never marked nodes as lost or disconnected.

v2 of the tailnet API went GA for CLI clients in Coder 2.8.0, so clients older than that would stop working.
2024-09-12 07:56:31 +04:00

515 lines
15 KiB
Go

package tailnet_test
import (
"context"
"net"
"net/netip"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)
// TestCoordinator exercises tailnet.NewCoordinator directly using peers from
// the tailnet/test package: node registration, address validation, node
// exchange between an agent and a client, agent reconnection, and the
// ReadyForHandshake scenarios. Every subtest closes its coordinator when it
// finishes.
func TestCoordinator(t *testing.T) {
	t.Parallel()
	t.Run("ClientWithoutAgent", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
		ctx := testutil.Context(t, testutil.WaitShort)
		coordinator := tailnet.NewCoordinator(logger)
		defer func() {
			err := coordinator.Close()
			require.NoError(t, err)
		}()
		client := test.NewClient(ctx, t, coordinator, "client", uuid.New())
		defer client.Close(ctx)
		client.UpdateNode(&proto.Node{
			Addresses:     []string{netip.PrefixFrom(tailnet.IP(), 128).String()},
			PreferredDerp: 10,
		})
		require.Eventually(t, func() bool {
			return coordinator.Node(client.ID) != nil
		}, testutil.WaitShort, testutil.IntervalFast)
	})
	// A /64 prefix (instead of /128) is invalid, so the coordinator is expected
	// to drop the client and close its response stream.
	t.Run("ClientWithoutAgent_InvalidIPBits", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
		ctx := testutil.Context(t, testutil.WaitShort)
		coordinator := tailnet.NewCoordinator(logger)
		defer func() {
			err := coordinator.Close()
			require.NoError(t, err)
		}()
		client := test.NewClient(ctx, t, coordinator, "client", uuid.New())
		defer client.Close(ctx)
		client.UpdateNode(&proto.Node{
			Addresses: []string{
				netip.PrefixFrom(tailnet.IP(), 64).String(),
			},
			PreferredDerp: 10,
		})
		client.AssertEventuallyResponsesClosed()
	})
	t.Run("AgentWithoutClients", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
		ctx := testutil.Context(t, testutil.WaitShort)
		coordinator := tailnet.NewCoordinator(logger)
		defer func() {
			err := coordinator.Close()
			require.NoError(t, err)
		}()
		agent := test.NewAgent(ctx, t, coordinator, "agent")
		defer agent.Close(ctx)
		agent.UpdateNode(&proto.Node{
			Addresses: []string{
				netip.PrefixFrom(tailnet.IPFromUUID(agent.ID), 128).String(),
			},
			PreferredDerp: 10,
		})
		require.Eventually(t, func() bool {
			return coordinator.Node(agent.ID) != nil
		}, testutil.WaitShort, testutil.IntervalFast)
	})
	// An agent must advertise the address derived from its own ID; a random
	// address is expected to get the agent disconnected.
	t.Run("AgentWithoutClients_InvalidIP", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
		ctx := testutil.Context(t, testutil.WaitShort)
		coordinator := tailnet.NewCoordinator(logger)
		defer func() {
			err := coordinator.Close()
			require.NoError(t, err)
		}()
		agent := test.NewAgent(ctx, t, coordinator, "agent")
		defer agent.Close(ctx)
		agent.UpdateNode(&proto.Node{
			Addresses: []string{
				netip.PrefixFrom(tailnet.IP(), 128).String(),
			},
			PreferredDerp: 10,
		})
		agent.AssertEventuallyResponsesClosed()
	})
	// The right address with a /64 prefix is also expected to be rejected.
	t.Run("AgentWithoutClients_InvalidBits", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
		ctx := testutil.Context(t, testutil.WaitShort)
		coordinator := tailnet.NewCoordinator(logger)
		defer func() {
			err := coordinator.Close()
			require.NoError(t, err)
		}()
		agent := test.NewAgent(ctx, t, coordinator, "agent")
		defer agent.Close(ctx)
		agent.UpdateNode(&proto.Node{
			Addresses: []string{
				netip.PrefixFrom(tailnet.IPFromUUID(agent.ID), 64).String(),
			},
			PreferredDerp: 10,
		})
		agent.AssertEventuallyResponsesClosed()
	})
	t.Run("AgentWithClient", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
		coordinator := tailnet.NewCoordinator(logger)
		defer func() {
			err := coordinator.Close()
			require.NoError(t, err)
		}()
		// Same helper as the sibling subtests; it cancels on test cleanup.
		ctx := testutil.Context(t, testutil.WaitShort)
		agent := test.NewAgent(ctx, t, coordinator, "agent")
		// This deferred Close is bound to the first agent connection, even
		// though the variable is reassigned below.
		defer agent.Close(ctx)
		agent.UpdateDERP(1)
		require.Eventually(t, func() bool {
			return coordinator.Node(agent.ID) != nil
		}, testutil.WaitShort, testutil.IntervalFast)
		client := test.NewClient(ctx, t, coordinator, "client", agent.ID)
		defer client.Close(ctx)
		client.AssertEventuallyHasDERP(agent.ID, 1)
		client.UpdateDERP(2)
		agent.AssertEventuallyHasDERP(client.ID, 2)
		// Ensure an update to the agent node reaches the client!
		agent.UpdateDERP(3)
		client.AssertEventuallyHasDERP(agent.ID, 3)
		// Close the agent so a new one can connect.
		agent.Close(ctx)
		// Create a new agent connection. This is to simulate a reconnect!
		agent = test.NewPeer(ctx, t, coordinator, "agent", test.WithID(agent.ID))
		defer agent.Close(ctx)
		// Ensure the agent gets the existing client node immediately!
		agent.AssertEventuallyHasDERP(client.ID, 2)
	})
	t.Run("AgentDoubleConnect", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
		coordinator := tailnet.NewCoordinator(logger)
		defer func() {
			err := coordinator.Close()
			require.NoError(t, err)
		}()
		ctx := testutil.Context(t, testutil.WaitShort)
		agentID := uuid.New()
		agent1 := test.NewPeer(ctx, t, coordinator, "agent1", test.WithID(agentID))
		defer agent1.Close(ctx)
		agent1.UpdateDERP(1)
		require.Eventually(t, func() bool {
			return coordinator.Node(agentID) != nil
		}, testutil.WaitShort, testutil.IntervalFast)
		client := test.NewPeer(ctx, t, coordinator, "client")
		defer client.Close(ctx)
		client.AddTunnel(agentID)
		client.AssertEventuallyHasDERP(agent1.ID, 1)
		client.UpdateDERP(2)
		agent1.AssertEventuallyHasDERP(client.ID, 2)
		// Ensure an update to the agent node reaches the client!
		agent1.UpdateDERP(3)
		client.AssertEventuallyHasDERP(agent1.ID, 3)
		// Create a new agent connection without disconnecting the old one.
		agent2 := test.NewPeer(ctx, t, coordinator, "agent2", test.WithID(agentID))
		defer agent2.Close(ctx)
		// Ensure the existing client node gets sent immediately!
		agent2.AssertEventuallyHasDERP(client.ID, 2)
		// The original agent's channels should have been closed forcefully.
		agent1.AssertEventuallyResponsesClosed()
	})
	t.Run("AgentAck", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
		coordinator := tailnet.NewCoordinator(logger)
		// The error is ignored because we cannot see from here whether the
		// shared scenario already closes the coordinator. Closing twice is
		// exercised elsewhere in this file.
		defer coordinator.Close()
		ctx := testutil.Context(t, testutil.WaitShort)
		test.ReadyForHandshakeTest(ctx, t, coordinator)
	})
	t.Run("AgentAck_NoPermission", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
		coordinator := tailnet.NewCoordinator(logger)
		// See AgentAck for why the Close error is ignored.
		defer coordinator.Close()
		ctx := testutil.Context(t, testutil.WaitShort)
		test.ReadyForHandshakeNoPermissionTest(ctx, t, coordinator)
	})
}
// TestCoordinator_BidirectionalTunnels runs the shared bidirectional-tunnel
// scenario from the tailnet/test package against the in-memory coordinator.
func TestCoordinator_BidirectionalTunnels(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coordinator := tailnet.NewCoordinator(logger)
	// Close the coordinator when the test ends, as the other coordinator tests
	// do. The error is ignored because we cannot see from here whether the
	// shared scenario closes it itself.
	defer coordinator.Close()
	ctx := testutil.Context(t, testutil.WaitShort)
	test.BidirectionalTunnels(ctx, t, coordinator)
}
// TestCoordinator_GracefulDisconnect runs the shared graceful-disconnect
// scenario from the tailnet/test package against the in-memory coordinator.
func TestCoordinator_GracefulDisconnect(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coordinator := tailnet.NewCoordinator(logger)
	// Close the coordinator when the test ends, as the other coordinator tests
	// do. The error is ignored because we cannot see from here whether the
	// shared scenario closes it itself.
	defer coordinator.Close()
	ctx := testutil.Context(t, testutil.WaitShort)
	test.GracefulDisconnectTest(ctx, t, coordinator)
}
// TestCoordinator_Lost runs the shared lost-peer scenario from the
// tailnet/test package against the in-memory coordinator.
func TestCoordinator_Lost(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coordinator := tailnet.NewCoordinator(logger)
	// Close the coordinator when the test ends, as the other coordinator tests
	// do. The error is ignored because we cannot see from here whether the
	// shared scenario closes it itself.
	defer coordinator.Close()
	ctx := testutil.Context(t, testutil.WaitShort)
	test.LostTest(ctx, t, coordinator)
}
// TestCoordinator_MultiAgent_CoordClose checks that a multi-agent attached to a
// coordinator is eventually closed once that coordinator is closed.
func TestCoordinator_MultiAgent_CoordClose(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()

	coord := tailnet.NewCoordinator(logger.Named("coord1"))
	// Safety net in case the explicit Close below is never reached; the second
	// Close's error is deliberately ignored.
	defer coord.Close()

	multiAgent := tailnettest.NewTestMultiAgent(t, coord)
	defer multiAgent.Close()

	require.NoError(t, coord.Close())
	multiAgent.RequireEventuallyClosed(ctx)
}
// TestInMemoryCoordination runs the common client coordination checks against
// an in-memory coordination backed by a mock coordinator. It then verifies that
// any error already reported by the coordination is nil.
func TestInMemoryCoordination(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	clientID, agentID := uuid.UUID{1}, uuid.UUID{2}
	fConn := &fakeCoordinatee{}
	reqs := make(chan *proto.CoordinateRequest, 100)
	resps := make(chan *proto.CoordinateResponse, 100)

	// The coordination must call Coordinate exactly once, as a client that is
	// authorized to reach agentID.
	mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
	mCoord.EXPECT().
		Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
		Times(1).
		Return(reqs, resps)

	uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
	defer uut.Close()

	coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)

	// Having nothing on the error channel is fine; anything that is there must
	// be nil.
	select {
	case err := <-uut.Error():
		require.NoError(t, err)
	default:
	}
}
// TestRemoteCoordination runs the common client coordination checks against a
// remote coordination. The coordination speaks DRPC over an in-memory pipe to
// a ClientService, which is backed by a mock coordinator.
func TestRemoteCoordination(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	clientID, agentID := uuid.UUID{1}, uuid.UUID{2}
	fConn := &fakeCoordinatee{}
	reqs := make(chan *proto.CoordinateRequest, 100)
	resps := make(chan *proto.CoordinateResponse, 100)

	mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
	mCoord.EXPECT().
		Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
		Times(1).
		Return(reqs, resps)

	// ClientServiceOptions takes the coordinator through an atomic pointer.
	var coord tailnet.Coordinator = mCoord
	var coordPtr atomic.Pointer[tailnet.Coordinator]
	coordPtr.Store(&coord)

	// The DERP map and telemetry hooks are not exercised by this test.
	svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
		Logger:                  logger.Named("svc"),
		CoordPtr:                &coordPtr,
		DERPMapUpdateFrequency:  time.Hour,
		DERPMapFn:               func() *tailcfg.DERPMap { panic("not implemented") },
		NetworkTelemetryHandler: func(_ []*proto.TelemetryEvent) { panic("not implemented") },
		ResumeTokenProvider:     tailnet.NewInsecureTestResumeTokenProvider(),
	})
	require.NoError(t, err)

	serverSide, clientSide := net.Pipe()
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- svc.ServeClient(ctx, proto.CurrentVersion.String(), serverSide, clientID, agentID)
	}()

	client, err := tailnet.NewDRPCClient(clientSide, logger)
	require.NoError(t, err)
	protocol, err := client.Coordinate(ctx)
	require.NoError(t, err)

	uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
	defer uut.Close()

	coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)

	// The only error tolerated after closing is the normal end of the stream.
	select {
	case err := <-uut.Error():
		require.ErrorContains(t, err, "stream terminated by sending close")
	default:
	}
}
// TestRemoteCoordination_SendsReadyForHandshake checks the reply to a peer
// NODE update. When a remote coordination receives such an update, it answers
// the coordinator with a ReadyForHandshake request that names that peer.
func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	clientID, agentID := uuid.UUID{1}, uuid.UUID{2}
	fConn := &fakeCoordinatee{}
	reqs := make(chan *proto.CoordinateRequest, 100)
	resps := make(chan *proto.CoordinateResponse, 100)

	mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
	mCoord.EXPECT().
		Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
		Times(1).
		Return(reqs, resps)

	var coord tailnet.Coordinator = mCoord
	var coordPtr atomic.Pointer[tailnet.Coordinator]
	coordPtr.Store(&coord)

	// The DERP map and telemetry hooks are not exercised by this test.
	svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
		Logger:                  logger.Named("svc"),
		CoordPtr:                &coordPtr,
		DERPMapUpdateFrequency:  time.Hour,
		DERPMapFn:               func() *tailcfg.DERPMap { panic("not implemented") },
		NetworkTelemetryHandler: func(_ []*proto.TelemetryEvent) { panic("not implemented") },
		ResumeTokenProvider:     tailnet.NewInsecureTestResumeTokenProvider(),
	})
	require.NoError(t, err)

	serverSide, clientSide := net.Pipe()
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- svc.ServeClient(ctx, proto.CurrentVersion.String(), serverSide, clientID, agentID)
	}()

	client, err := tailnet.NewDRPCClient(clientSide, logger)
	require.NoError(t, err)
	protocol, err := client.Coordinate(ctx)
	require.NoError(t, err)

	// Unlike TestRemoteCoordination, no tunnel target is configured (zero UUID).
	uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
	defer uut.Close()

	nodeKey, err := key.NewNode().Public().MarshalBinary()
	require.NoError(t, err)
	discoKey, err := key.NewDisco().Public().MarshalText()
	require.NoError(t, err)

	peerUpdate := &proto.CoordinateResponse_PeerUpdate{
		Id:   clientID[:],
		Kind: proto.CoordinateResponse_PeerUpdate_NODE,
		Node: &proto.Node{
			Id:    3,
			Key:   nodeKey,
			Disco: string(discoKey),
		},
	}
	testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{
		PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{peerUpdate},
	})

	ack := testutil.RequireRecvCtx(ctx, t, reqs)
	require.NotNil(t, ack.ReadyForHandshake)
	require.Len(t, ack.ReadyForHandshake, 1)
	require.Equal(t, clientID[:], ack.ReadyForHandshake[0].Id)

	require.NoError(t, uut.Close())

	// The only error tolerated after closing is the normal end of the stream.
	select {
	case err := <-uut.Error():
		require.ErrorContains(t, err, "stream terminated by sending close")
	default:
	}
}
// coordinationTest drives a client-side tailnet.Coordination through its basic
// lifecycle and checks each step:
//
//  1. On start, it asks the coordinator for a tunnel to agentID.
//  2. A node reported through the coordinatee's callback is sent upstream as
//     an UpdateSelf request.
//  3. A peer update received from the coordinator is passed on to the
//     coordinatee.
//  4. Closing it sends a Disconnect request and marks all peers lost on the
//     coordinatee exactly once.
//
// reqs carries what the coordination sends to the coordinator, and resps
// carries coordinator responses back to it. fConn may be written from the
// coordination's own goroutines, so its fields are only read here while its
// mutex is held.
func coordinationTest(
	ctx context.Context, t *testing.T,
	uut tailnet.Coordination, fConn *fakeCoordinatee,
	reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
	agentID uuid.UUID,
) {
	// It should add the tunnel, since we configured as a client.
	req := testutil.RequireRecvCtx(ctx, t, reqs)
	require.Equal(t, agentID[:], req.GetAddTunnel().GetId())

	// When we call the callback, it should send a node update. Copy the
	// callback out under the lock, then call it unlocked in case it calls back
	// into fConn.
	fConn.Lock()
	callback := fConn.callback
	fConn.Unlock()
	require.NotNil(t, callback)
	callback(&tailnet.Node{PreferredDERP: 1})
	req = testutil.RequireRecvCtx(ctx, t, reqs)
	require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())

	// When we send a peer update, it should update the coordinatee.
	nk, err := key.NewNode().Public().MarshalBinary()
	require.NoError(t, err)
	dk, err := key.NewDisco().Public().MarshalText()
	require.NoError(t, err)
	updates := []*proto.CoordinateResponse_PeerUpdate{
		{
			Id:   agentID[:],
			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
			Node: &proto.Node{
				Id:    2,
				Key:   nk,
				Disco: string(dk),
			},
		},
	}
	testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
	require.Eventually(t, func() bool {
		fConn.Lock()
		defer fConn.Unlock()
		return len(fConn.updates) > 0
	}, testutil.WaitShort, testutil.IntervalFast)
	fConn.Lock()
	firstBatch := fConn.updates[0]
	fConn.Unlock()
	require.Len(t, firstBatch, 1)
	require.Equal(t, agentID[:], firstBatch[0].Id)

	// When we close, it should gracefully disconnect.
	err = uut.Close()
	require.NoError(t, err)
	req = testutil.RequireRecvCtx(ctx, t, reqs)
	require.NotNil(t, req.Disconnect)

	// It should set all peers lost on the coordinatee.
	fConn.Lock()
	lostCalls := fConn.setAllPeersLostCalls
	fConn.Unlock()
	require.Equal(t, 1, lostCalls)
}
// fakeCoordinatee is a test double for the coordinatee handed to a
// coordination. It records every call made to it. All fields are guarded by the
// embedded mutex, presumably because the coordination under test calls it from
// its own goroutines.
type fakeCoordinatee struct {
	sync.Mutex
	// callback is the most recent function passed to SetNodeCallback.
	callback func(*tailnet.Node)
	// updates holds each batch passed to UpdatePeers, in call order.
	updates [][]*proto.CoordinateResponse_PeerUpdate
	// setAllPeersLostCalls counts calls to SetAllPeersLost.
	setAllPeersLostCalls int
	// tunnelDestinations is the set of IDs passed to SetTunnelDestination. It is
	// allocated on first use.
	tunnelDestinations map[uuid.UUID]struct{}
}

// UpdatePeers records the batch and always succeeds.
func (fc *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
	fc.Lock()
	fc.updates = append(fc.updates, updates)
	fc.Unlock()
	return nil
}

// SetAllPeersLost increments the call counter.
func (fc *fakeCoordinatee) SetAllPeersLost() {
	fc.Lock()
	fc.setAllPeersLostCalls++
	fc.Unlock()
}

// SetTunnelDestination adds id to the recorded destination set, allocating the
// set on first use.
func (fc *fakeCoordinatee) SetTunnelDestination(id uuid.UUID) {
	fc.Lock()
	defer fc.Unlock()
	dests := fc.tunnelDestinations
	if dests == nil {
		dests = make(map[uuid.UUID]struct{})
		fc.tunnelDestinations = dests
	}
	dests[id] = struct{}{}
}

// SetNodeCallback stores callback, replacing any previous one.
func (fc *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
	fc.Lock()
	fc.callback = callback
	fc.Unlock()
}