mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: set peers lost when disconnected from coordinator (#11681)
Adds support to Coordination to call SetAllPeersLost() when it is closed. This ensures that when we disconnect from a Coordinator, we set all peers lost. This covers CoderSDK (CLI client) and Agent. The next PR will cover MultiAgent (notably, `wsproxy`).
This commit is contained in:
+167
-5
@@ -6,19 +6,24 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"nhooyr.io/websocket"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/tailnet/test"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
|
||||
require.True(t, ok)
|
||||
return client, server
|
||||
}
|
||||
|
||||
func TestInMemoryCoordination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
clientID := uuid.UUID{1}
|
||||
agentID := uuid.UUID{2}
|
||||
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
|
||||
fConn := &fakeCoordinatee{}
|
||||
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
|
||||
Times(1).Return(reqs, resps)
|
||||
|
||||
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
|
||||
defer uut.Close()
|
||||
|
||||
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
||||
|
||||
select {
|
||||
case err := <-uut.Error():
|
||||
require.NoError(t, err)
|
||||
default:
|
||||
// OK!
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteCoordination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
clientID := uuid.UUID{1}
|
||||
agentID := uuid.UUID{2}
|
||||
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
|
||||
fConn := &fakeCoordinatee{}
|
||||
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
|
||||
Times(1).Return(reqs, resps)
|
||||
|
||||
var coord tailnet.Coordinator = mCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
svc, err := tailnet.NewClientService(
|
||||
logger.Named("svc"), &coordPtr,
|
||||
time.Hour,
|
||||
func() *tailcfg.DERPMap { panic("not implemented") },
|
||||
)
|
||||
require.NoError(t, err)
|
||||
sC, cC := net.Pipe()
|
||||
|
||||
serveErr := make(chan error, 1)
|
||||
go func() {
|
||||
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
|
||||
serveErr <- err
|
||||
}()
|
||||
|
||||
client, err := tailnet.NewDRPCClient(cC)
|
||||
require.NoError(t, err)
|
||||
protocol, err := client.Coordinate(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
|
||||
defer uut.Close()
|
||||
|
||||
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
||||
|
||||
select {
|
||||
case err := <-uut.Error():
|
||||
require.ErrorContains(t, err, "stream terminated by sending close")
|
||||
default:
|
||||
// OK!
|
||||
}
|
||||
}
|
||||
|
||||
// coordinationTest tests that a coordination behaves correctly
|
||||
func coordinationTest(
|
||||
ctx context.Context, t *testing.T,
|
||||
uut tailnet.Coordination, fConn *fakeCoordinatee,
|
||||
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
|
||||
agentID uuid.UUID,
|
||||
) {
|
||||
// It should add the tunnel, since we configured as a client
|
||||
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
|
||||
|
||||
// when we call the callback, it should send a node update
|
||||
require.NotNil(t, fConn.callback)
|
||||
fConn.callback(&tailnet.Node{PreferredDERP: 1})
|
||||
|
||||
req = testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
|
||||
|
||||
// When we send a peer update, it should update the coordinatee
|
||||
nk, err := key.NewNode().Public().MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
dk, err := key.NewDisco().Public().MarshalText()
|
||||
require.NoError(t, err)
|
||||
updates := []*proto.CoordinateResponse_PeerUpdate{
|
||||
{
|
||||
Id: agentID[:],
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
Node: &proto.Node{
|
||||
Id: 2,
|
||||
Key: nk,
|
||||
Disco: string(dk),
|
||||
},
|
||||
},
|
||||
}
|
||||
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
|
||||
require.Eventually(t, func() bool {
|
||||
fConn.Lock()
|
||||
defer fConn.Unlock()
|
||||
return len(fConn.updates) > 0
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
require.Len(t, fConn.updates[0], 1)
|
||||
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
|
||||
|
||||
err = uut.Close()
|
||||
require.NoError(t, err)
|
||||
uut.Error()
|
||||
|
||||
// When we close, it should gracefully disconnect
|
||||
req = testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.NotNil(t, req.Disconnect)
|
||||
|
||||
// It should set all peers lost on the coordinatee
|
||||
require.Equal(t, 1, fConn.setAllPeersLostCalls)
|
||||
}
|
||||
|
||||
type fakeCoordinatee struct {
|
||||
sync.Mutex
|
||||
callback func(*tailnet.Node)
|
||||
updates [][]*proto.CoordinateResponse_PeerUpdate
|
||||
setAllPeersLostCalls int
|
||||
}
|
||||
|
||||
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
f.updates = append(f.updates, updates)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeCoordinatee) SetAllPeersLost() {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
f.setAllPeersLostCalls++
|
||||
}
|
||||
|
||||
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
f.callback = callback
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user