Files
coder/tailnet/coordinator_test.go
T
Spike Curtis 520b12e1a2 fix: close MultiAgentConn when coordinator closes (#11941)
Fixes an issue where a MultiAgentConn isn't closed properly when the coordinator it is connected to is closed.

Since servertailnet checks whether the conn is closed before reinitializing, it is important that we check this, otherwise servertailnet can get stuck if the coordinator closes (e.g. when we switch from AGPL to PGCoordinator after decoding a license).
2024-01-31 00:38:19 +04:00

583 lines
18 KiB
Go

package tailnet_test
import (
"context"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)
// TestCoordinator exercises the coordinator returned by tailnet.NewCoordinator
// through its ServeClient and ServeAgent entry points. The far end of every
// connection is driven by tailnet.ServeCoordinator. Each scenario runs as a
// parallel subtest with its own coordinator.
func TestCoordinator(t *testing.T) {
t.Parallel()
// ClientWithoutAgent: a client is served for a random agent ID that never
// connects. The client's node must still become visible via Node(), and
// both sides must shut down without error once the pipe is closed.
t.Run("ClientWithoutAgent", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
// The callback discards any node updates handed to this side.
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
// closeChan is closed once ServeClient has returned.
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeClient(server, id, uuid.New())
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{})
// Wait until the coordinator reports a node for this client.
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, client.Close())
require.NoError(t, server.Close())
// Both the ServeCoordinator side and the ServeClient goroutine must finish
// before ctx expires.
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
// AgentWithoutClients: an agent is served with no clients attached. Its node
// must become visible via Node(), and closing the agent's end of the pipe
// must end ServeAgent without error.
t.Run("AgentWithoutClients", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
// The callback discards any node updates handed to this side.
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
// closeChan is closed once ServeAgent has returned.
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{})
// Wait until the coordinator reports a node for this agent.
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
// Only the agent's end is closed here; the server end is left to ServeAgent.
err := client.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
// AgentWithClient: an agent and a client exchange nodes through the
// coordinator, keep exchanging after a pause longer than the write timeout,
// and the client's node is delivered again when the agent reconnects.
t.Run("AgentWithClient", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
// in this test we use real websockets to test use of deadlines
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
agentWS, agentServerWS := websocketConn(ctx, t)
defer agentWS.Close()
// agentNodeChan is unbuffered: every node update delivered to the agent
// blocks the callback until the test receives it.
agentNodeChan := make(chan []*tailnet.Node)
sendAgentNode, agentErrChan := tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error {
agentNodeChan <- nodes
return nil
})
agentID := uuid.New()
// closeAgentChan is closed once ServeAgent has returned.
closeAgentChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(agentServerWS, agentID, "")
assert.NoError(t, err)
close(closeAgentChan)
}()
sendAgentNode(&tailnet.Node{PreferredDERP: 1})
// Wait until the coordinator reports a node for the agent before the
// client connects.
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
clientWS, clientServerWS := websocketConn(ctx, t)
defer clientWS.Close()
defer clientServerWS.Close()
// clientNodeChan is unbuffered, like agentNodeChan above.
clientNodeChan := make(chan []*tailnet.Node)
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
clientNodeChan <- nodes
return nil
})
clientID := uuid.New()
// closeClientChan is closed once ServeClient has returned.
closeClientChan := make(chan struct{})
go func() {
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
assert.NoError(t, err)
close(closeClientChan)
}()
// The client should be handed the agent's node without sending anything.
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// The client's node should be forwarded to the agent.
sendClientNode(&tailnet.Node{PreferredDERP: 2})
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan)
require.Len(t, clientNodes, 1)
// wait longer than the internal wait timeout.
// this tests for regression of https://github.com/coder/coder/issues/7428
time.Sleep(tailnet.WriteTimeout * 3 / 2)
// Ensure an update to the agent node reaches the client!
sendAgentNode(&tailnet.Node{PreferredDERP: 3})
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// Close the agent WebSocket so a new one can connect.
err := agentWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
// Create a new agent connection. This is to simulate a reconnect!
// Note that the reconnect uses an in-memory pipe rather than a websocket.
agentWS, agentServerWS = net.Pipe()
defer agentWS.Close()
agentNodeChan = make(chan []*tailnet.Node)
// The reconnected agent never sends a node of its own, so the send function
// is discarded.
_, agentErrChan = tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error {
agentNodeChan <- nodes
return nil
})
closeAgentChan = make(chan struct{})
go func() {
err := coordinator.ServeAgent(agentServerWS, agentID, "")
assert.NoError(t, err)
close(closeAgentChan)
}()
// Ensure the existing listening client sends its node immediately!
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan)
require.Len(t, clientNodes, 1)
// Tear down: agent first, then client, waiting for each side to finish.
err = agentWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
err = clientWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
})
// AgentDoubleConnect: a second connection is served for the same agent ID
// while the first is still open. The new connection must receive the client's
// node and the first connection must be shut down.
// NOTE(review): unlike the subtests above, this one never calls
// coordinator.Close() — confirm whether that is intentional.
t.Run("AgentDoubleConnect", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitLong)
agentWS1, agentServerWS1 := net.Pipe()
defer agentWS1.Close()
// Unbuffered: each update blocks the callback until the test receives it.
agentNodeChan1 := make(chan []*tailnet.Node)
sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
t.Logf("agent1 got node update: %v", nodes)
agentNodeChan1 <- nodes
return nil
})
agentID := uuid.New()
// closeAgentChan1 is closed once the first ServeAgent call has returned.
closeAgentChan1 := make(chan struct{})
go func() {
err := coordinator.ServeAgent(agentServerWS1, agentID, "")
assert.NoError(t, err)
close(closeAgentChan1)
}()
sendAgentNode1(&tailnet.Node{PreferredDERP: 1})
// Wait until the coordinator reports a node for the agent before the
// client connects.
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
clientWS, clientServerWS := net.Pipe()
defer clientWS.Close()
defer clientServerWS.Close()
clientNodeChan := make(chan []*tailnet.Node)
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
t.Logf("client got node update: %v", nodes)
clientNodeChan <- nodes
return nil
})
clientID := uuid.New()
// closeClientChan is closed once ServeClient has returned.
closeClientChan := make(chan struct{})
go func() {
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
assert.NoError(t, err)
close(closeClientChan)
}()
// The client should be handed the agent's node without sending anything.
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// The client's node should be forwarded to the first agent connection.
sendClientNode(&tailnet.Node{PreferredDERP: 2})
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan1)
require.Len(t, clientNodes, 1)
// Ensure an update to the agent node reaches the client!
sendAgentNode1(&tailnet.Node{PreferredDERP: 3})
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// Create a new agent connection without disconnecting the old one.
agentWS2, agentServerWS2 := net.Pipe()
defer agentWS2.Close()
agentNodeChan2 := make(chan []*tailnet.Node)
// The second connection never sends a node of its own, so the send function
// is discarded.
_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
t.Logf("agent2 got node update: %v", nodes)
agentNodeChan2 <- nodes
return nil
})
// closeAgentChan2 is closed once the second ServeAgent call has returned.
closeAgentChan2 := make(chan struct{})
go func() {
err := coordinator.ServeAgent(agentServerWS2, agentID, "")
assert.NoError(t, err)
close(closeAgentChan2)
}()
// Ensure the existing listening client sends its node immediately!
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan2)
require.Len(t, clientNodes, 1)
// This original agent websocket should've been closed forcefully.
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan1)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan1)
// Tear down: second agent connection first, then the client.
err := agentWS2.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan2)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan2)
err = clientWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
})
}
// TestCoordinator_AgentUpdateWhileClientConnects tests for regression on
// https://github.com/coder/coder/issues/7295
//
// The agent publishes a second node update while the coordinator is still
// part-way through writing the first one to a newly connected client. The
// client must receive both updates, in order. The test talks raw JSON over
// net.Pipe connections (instead of tailnet.ServeCoordinator) so that it can
// stall the coordinator's write by reading only a single byte of it.
func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coordinator := tailnet.NewCoordinator(logger)
	agentWS, agentServerWS := net.Pipe()
	defer agentWS.Close()

	agentID := uuid.New()
	go func() {
		err := coordinator.ServeAgent(agentServerWS, agentID, "")
		assert.NoError(t, err)
	}()

	// send an agent update before the client connects so that there is
	// node data available to send right away.
	aNode := tailnet.Node{PreferredDERP: 0}
	aData, err := json.Marshal(&aNode)
	require.NoError(t, err)
	err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
	require.NoError(t, err)
	_, err = agentWS.Write(aData)
	require.NoError(t, err)

	require.Eventually(t, func() bool {
		return coordinator.Node(agentID) != nil
	}, testutil.WaitShort, testutil.IntervalFast)

	// Connect from the client
	clientWS, clientServerWS := net.Pipe()
	defer clientWS.Close()
	clientID := uuid.New()
	go func() {
		err := coordinator.ServeClient(clientServerWS, clientID, agentID)
		assert.NoError(t, err)
	}()

	// peek one byte from the node update, so we know the coordinator is
	// trying to write to the client.
	// buffer needs to be 2 characters longer because return value is a list
	// so, it needs [ and ]
	buf := make([]byte, len(aData)+2)
	err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
	require.NoError(t, err)
	n, err := clientWS.Read(buf[:1])
	require.NoError(t, err)
	require.Equal(t, 1, n)

	// send a second update
	aNode.PreferredDERP = 1
	aData, err = json.Marshal(&aNode)
	require.NoError(t, err)
	err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
	require.NoError(t, err)
	_, err = agentWS.Write(aData)
	require.NoError(t, err)

	// read the rest of the update from the client, should be initial node.
	err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
	require.NoError(t, err)
	n, err = clientWS.Read(buf[1:])
	require.NoError(t, err)
	var cNodes []*tailnet.Node
	// n counts only the bytes read into buf[1:]; add the byte peeked earlier.
	err = json.Unmarshal(buf[:n+1], &cNodes)
	require.NoError(t, err)
	require.Len(t, cNodes, 1)
	require.Equal(t, 0, cNodes[0].PreferredDERP)

	// read second update
	// without a fix for https://github.com/coder/coder/issues/7295 our
	// read would time out here.
	err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
	require.NoError(t, err)
	n, err = clientWS.Read(buf)
	require.NoError(t, err)
	err = json.Unmarshal(buf[:n], &cNodes)
	require.NoError(t, err)
	require.Len(t, cNodes, 1)
	require.Equal(t, 1, cNodes[0].PreferredDERP)
}
func TestCoordinator_BidirectionalTunnels(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitShort)
test.BidirectionalTunnels(ctx, t, coordinator)
}
func TestCoordinator_GracefulDisconnect(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitShort)
test.GracefulDisconnectTest(ctx, t, coordinator)
}
func TestCoordinator_Lost(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitShort)
test.LostTest(ctx, t, coordinator)
}
// TestCoordinator_MultiAgent_CoordClose checks that a test multi-agent
// attached to a coordinator is eventually closed once that coordinator is
// closed.
func TestCoordinator_MultiAgent_CoordClose(t *testing.T) {
	t.Parallel()

	// Errors logged during shutdown must not fail the test.
	lgr := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()

	coord := tailnet.NewCoordinator(lgr.Named("coord1"))
	// Safety net in case the test bails out before the explicit Close below.
	defer coord.Close()

	multiAgent := tailnettest.NewTestMultiAgent(t, coord)
	defer multiAgent.Close()

	// Closing the coordinator is the action under test.
	require.NoError(t, coord.Close())
	multiAgent.RequireEventuallyClosed(ctx)
}
// websocketConn dials a throwaway httptest server over a real WebSocket and
// returns both ends of that connection wrapped as net.Conns: client is the
// dialing side, server is the accepted side.
//
// The HTTP handler keeps the request open until ctx is done, so callers must
// cancel ctx before the test's cleanup runs; the server itself is closed via
// t.Cleanup. The helper supports exactly one connection per call.
func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) {
	t.Helper()
	sc := make(chan net.Conn, 1)
	s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		wss, err := websocket.Accept(rw, r, nil)
		// This handler runs on one of the HTTP server's goroutines, not on the
		// test goroutine, so it must not use require (t.FailNow). Record the
		// failure and close the channel so the receive below sees ok == false.
		if !assert.NoError(t, err) {
			close(sc)
			return
		}
		conn := websocket.NetConn(r.Context(), wss, websocket.MessageBinary)
		sc <- conn
		close(sc) // there can be only one

		// hold open until context canceled
		<-ctx.Done()
	}))
	t.Cleanup(s.Close)
	// nolint: bodyclose
	wsc, _, err := websocket.Dial(ctx, s.URL, nil)
	require.NoError(t, err)
	client = websocket.NetConn(ctx, wsc, websocket.MessageBinary)
	server, ok := <-sc
	require.True(t, ok, "server side of the websocket was never accepted")
	return client, server
}
// TestInMemoryCoordination runs the shared coordinationTest scenario against
// a coordination created by tailnet.NewInMemoryCoordination, backed by a mock
// coordinator and a fakeCoordinatee. Afterwards, any error already queued on
// the coordination's error channel must be nil.
func TestInMemoryCoordination(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	lgr := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

	var (
		clientID = uuid.UUID{1}
		agentID  = uuid.UUID{2}
	)

	ctrl := gomock.NewController(t)
	mockCoord := tailnettest.NewMockCoordinator(ctrl)
	coordinatee := &fakeCoordinatee{}
	reqCh := make(chan *proto.CoordinateRequest, 100)
	respCh := make(chan *proto.CoordinateResponse, 100)

	// The coordination must call Coordinate exactly once, as clientID with a
	// tunnel authorization for agentID; the mock hands back our channels.
	mockCoord.EXPECT().
		Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
		Times(1).
		Return(reqCh, respCh)

	coordination := tailnet.NewInMemoryCoordination(ctx, lgr, clientID, agentID, mockCoord, coordinatee)
	defer coordination.Close()

	coordinationTest(ctx, t, coordination, coordinatee, reqCh, respCh, agentID)

	// Non-blocking check: only inspect an error if one is already available.
	select {
	case err := <-coordination.Error():
		require.NoError(t, err)
	default:
		// OK!
	}
}
// TestRemoteCoordination runs the shared coordinationTest scenario against a
// coordination created by tailnet.NewRemoteCoordination. The coordination
// talks through a DRPC client to a tailnet client service over an in-memory
// pipe; the service in turn is backed by a mock coordinator. Afterwards, any
// error already queued on the coordination's error channel must mention the
// stream being terminated by a close.
func TestRemoteCoordination(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	lgr := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

	var (
		clientID = uuid.UUID{1}
		agentID  = uuid.UUID{2}
	)

	ctrl := gomock.NewController(t)
	mockCoord := tailnettest.NewMockCoordinator(ctrl)
	coordinatee := &fakeCoordinatee{}
	reqCh := make(chan *proto.CoordinateRequest, 100)
	respCh := make(chan *proto.CoordinateResponse, 100)

	// Coordinate must be called exactly once, as clientID with a tunnel
	// authorization for agentID; the mock hands back our channels.
	mockCoord.EXPECT().
		Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
		Times(1).
		Return(reqCh, respCh)

	// The client service receives the coordinator through an atomic pointer.
	var coord tailnet.Coordinator = mockCoord
	var coordPtr atomic.Pointer[tailnet.Coordinator]
	coordPtr.Store(&coord)

	svc, err := tailnet.NewClientService(
		lgr.Named("svc"),
		&coordPtr,
		time.Hour,
		// The DERP map is not expected to be requested in this test.
		func() *tailcfg.DERPMap { panic("not implemented") },
	)
	require.NoError(t, err)

	serverConn, clientConn := net.Pipe()

	// Buffered so the serving goroutine can exit even though nothing in this
	// test receives from the channel.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- svc.ServeClient(ctx, proto.CurrentVersion.String(), serverConn, clientID, agentID)
	}()

	drpcClient, err := tailnet.NewDRPCClient(clientConn, lgr)
	require.NoError(t, err)
	stream, err := drpcClient.Coordinate(ctx)
	require.NoError(t, err)

	coordination := tailnet.NewRemoteCoordination(lgr.Named("coordination"), stream, coordinatee, agentID)
	defer coordination.Close()

	coordinationTest(ctx, t, coordination, coordinatee, reqCh, respCh, agentID)

	// Non-blocking check: only inspect an error if one is already available.
	select {
	case err := <-coordination.Error():
		require.ErrorContains(t, err, "stream terminated by sending close")
	default:
		// OK!
	}
}
// coordinationTest tests that a coordination behaves correctly.
//
// uut is the coordination under test and fConn the fake coordinatee it was
// constructed with. reqs and resps are the channels carrying the requests the
// coordination sends and the responses fed back to it (the callers obtain them
// from a mock coordinator). agentID is the agent the coordination was
// configured to tunnel to.
//
// fConn's fields are written by its methods under its embedded mutex, and
// those methods may be invoked from goroutines owned by the coordination, so
// every read here takes a snapshot while holding the lock and asserts on the
// snapshot only after unlocking. Assertions are never made with the lock held:
// a failing require would exit the goroutine without unlocking.
func coordinationTest(
	ctx context.Context, t *testing.T,
	uut tailnet.Coordination, fConn *fakeCoordinatee,
	reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
	agentID uuid.UUID,
) {
	// It should add the tunnel, since we configured as a client
	req := testutil.RequireRecvCtx(ctx, t, reqs)
	require.Equal(t, agentID[:], req.GetAddTunnel().GetId())

	// when we call the callback, it should send a node update
	fConn.Lock()
	callback := fConn.callback
	fConn.Unlock()
	require.NotNil(t, callback)
	// Invoke the callback outside the lock so it is free to call back into
	// fConn.
	callback(&tailnet.Node{PreferredDERP: 1})

	req = testutil.RequireRecvCtx(ctx, t, reqs)
	require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())

	// When we send a peer update, it should update the coordinatee
	nk, err := key.NewNode().Public().MarshalBinary()
	require.NoError(t, err)
	dk, err := key.NewDisco().Public().MarshalText()
	require.NoError(t, err)
	updates := []*proto.CoordinateResponse_PeerUpdate{
		{
			Id:   agentID[:],
			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
			Node: &proto.Node{
				Id:    2,
				Key:   nk,
				Disco: string(dk),
			},
		},
	}
	testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
	require.Eventually(t, func() bool {
		fConn.Lock()
		defer fConn.Unlock()
		return len(fConn.updates) > 0
	}, testutil.WaitShort, testutil.IntervalFast)
	fConn.Lock()
	gotUpdates := fConn.updates
	fConn.Unlock()
	require.Len(t, gotUpdates[0], 1)
	require.Equal(t, agentID[:], gotUpdates[0][0].Id)

	err = uut.Close()
	require.NoError(t, err)
	// NOTE(review): the channel returned by Error() is discarded here;
	// presumably this only checks that Error() can be called after Close —
	// confirm.
	uut.Error()

	// When we close, it should gracefully disconnect
	req = testutil.RequireRecvCtx(ctx, t, reqs)
	require.NotNil(t, req.Disconnect)

	// It should set all peers lost on the coordinatee
	fConn.Lock()
	lostCalls := fConn.setAllPeersLostCalls
	fConn.Unlock()
	require.Equal(t, 1, lostCalls)
}
// fakeCoordinatee is a test double that records the calls made to its
// UpdatePeers, SetAllPeersLost and SetNodeCallback methods so tests can
// inspect them afterwards.
type fakeCoordinatee struct {
// Embedded mutex; each method of fakeCoordinatee holds it while writing the
// fields below.
sync.Mutex
// callback is the function most recently passed to SetNodeCallback.
callback func(*tailnet.Node)
// updates holds every batch passed to UpdatePeers, in call order.
updates [][]*proto.CoordinateResponse_PeerUpdate
// setAllPeersLostCalls counts the calls to SetAllPeersLost.
setAllPeersLostCalls int
}
// UpdatePeers records the given batch of peer updates at the end of
// f.updates while holding the mutex. It always returns nil.
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
	f.Lock()
	f.updates = append(f.updates, updates)
	f.Unlock()
	return nil
}
// SetAllPeersLost increments the setAllPeersLostCalls counter while holding
// the mutex.
func (f *fakeCoordinatee) SetAllPeersLost() {
	f.Lock()
	f.setAllPeersLostCalls++
	f.Unlock()
}
// SetNodeCallback stores callback on the fake while holding the mutex,
// replacing any callback stored earlier.
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
	f.Lock()
	f.callback = callback
	f.Unlock()
}