chore: refactor ServerTailnet to use tailnet.Controllers (#15408)

chore of #14729

Refactors the `ServerTailnet` to use `tailnet.Controller` so that we reuse logic around reconnection and handling control messages, instead of reimplementing.  This unifies our "client" use of the tailscale API across CLI, coderd, and wsproxy.
This commit is contained in:
Spike Curtis
2024-11-08 13:18:56 +04:00
committed by GitHub
parent f7cbf5dd79
commit 8c00ebc6ee
20 changed files with 491 additions and 1240 deletions
+1 -3
View File
@@ -507,7 +507,6 @@ gen: \
examples/examples.gen.json \
tailnet/tailnettest/coordinatormock.go \
tailnet/tailnettest/coordinateemock.go \
tailnet/tailnettest/multiagentmock.go \
coderd/database/pubsub/psmock/psmock.go
.PHONY: gen
@@ -537,7 +536,6 @@ gen/mark-fresh:
examples/examples.gen.json \
tailnet/tailnettest/coordinatormock.go \
tailnet/tailnettest/coordinateemock.go \
tailnet/tailnettest/multiagentmock.go \
coderd/database/pubsub/psmock/psmock.go \
"
@@ -570,7 +568,7 @@ coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.
coderd/database/pubsub/psmock/psmock.go: coderd/database/pubsub/pubsub.go
go generate ./coderd/database/pubsub/psmock
tailnet/tailnettest/coordinatormock.go tailnet/tailnettest/multiagentmock.go tailnet/tailnettest/coordinateemock.go: tailnet/coordinator.go tailnet/multiagent.go
tailnet/tailnettest/coordinatormock.go tailnet/tailnettest/coordinateemock.go: tailnet/coordinator.go
go generate ./tailnet/tailnettest/
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
+4 -2
View File
@@ -1919,7 +1919,8 @@ func TestAgent_UpdatedDERP(t *testing.T) {
t.Cleanup(testCtxCancel)
clientID := uuid.New()
ctrl := tailnet.NewSingleDestController(logger, conn, agentID)
coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, agentID, coordinator))
auth := tailnet.ClientCoordinateeAuth{AgentID: agentID}
coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, coordinator))
t.Cleanup(func() {
t.Logf("closing coordination %s", name)
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
@@ -2408,8 +2409,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
t.Cleanup(testCtxCancel)
clientID := uuid.New()
ctrl := tailnet.NewSingleDestController(logger, conn, metadata.AgentID)
auth := tailnet.ClientCoordinateeAuth{AgentID: metadata.AgentID}
coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(
logger, clientID, metadata.AgentID, coordinator))
logger, clientID, auth, coordinator))
t.Cleanup(func() {
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
defer ccancel()
+7 -4
View File
@@ -627,14 +627,17 @@ func New(options *Options) *API {
api.Auditor.Store(&options.Auditor)
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
dialer := &InmemTailnetDialer{
CoordPtr: &api.TailnetCoordinator,
DERPFn: api.DERPMap,
Logger: options.Logger,
ClientID: uuid.New(),
}
stn, err := NewServerTailnet(api.ctx,
options.Logger,
options.DERPServer,
api.DERPMap,
dialer,
options.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
func(context.Context) (tailnet.MultiAgentConn, error) {
return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil
},
options.DeploymentValues.DERP.Config.BlockDirect.Value(),
api.TracerProvider,
)
+302 -231
View File
@@ -30,7 +30,7 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/site"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/retry"
"github.com/coder/coder/v2/tailnet/proto"
)
var tailnetTransport *http.Transport
@@ -53,9 +53,8 @@ func NewServerTailnet(
ctx context.Context,
logger slog.Logger,
derpServer *derp.Server,
derpMapFn func() *tailcfg.DERPMap,
dialer tailnet.ControlProtocolDialer,
derpForceWebSockets bool,
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
blockEndpoints bool,
traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
@@ -91,46 +90,26 @@ func NewServerTailnet(
})
}
bgRoutines := &sync.WaitGroup{}
originalDerpMap := derpMapFn()
tracer := traceProvider.Tracer(tracing.TracerName)
controller := tailnet.NewController(logger, dialer)
// it's important to set the DERPRegionDialer above _before_ we set the DERP map so that if
// there is an embedded relay, we use the local in-memory dialer.
conn.SetDERPMap(originalDerpMap)
bgRoutines.Add(1)
go func() {
defer bgRoutines.Done()
defer logger.Debug(ctx, "polling DERPMap exited")
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-serverCtx.Done():
return
case <-ticker.C:
}
newDerpMap := derpMapFn()
if !tailnet.CompareDERPMaps(originalDerpMap, newDerpMap) {
conn.SetDERPMap(newDerpMap)
originalDerpMap = newDerpMap
}
}
}()
controller.DERPCtrl = tailnet.NewBasicDERPController(logger, conn)
coordCtrl := NewMultiAgentController(serverCtx, logger, tracer, conn)
controller.CoordCtrl = coordCtrl
// TODO: support controller.TelemetryCtrl
tn := &ServerTailnet{
ctx: serverCtx,
cancel: cancel,
bgRoutines: bgRoutines,
logger: logger,
tracer: traceProvider.Tracer(tracing.TracerName),
conn: conn,
coordinatee: conn,
getMultiAgent: getMultiAgent,
agentConnectionTimes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
ctx: serverCtx,
cancel: cancel,
logger: logger,
tracer: tracer,
conn: conn,
coordinatee: conn,
controller: controller,
coordCtrl: coordCtrl,
transport: tailnetTransport.Clone(),
connsPerAgent: prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: "coder",
Subsystem: "servertailnet",
@@ -146,7 +125,7 @@ func NewServerTailnet(
}
tn.transport.DialContext = tn.dialContext
// These options are mostly just picked at random, and they can likely be
// fine tuned further. Generally, users are running applications in dev mode
// fine-tuned further. Generally, users are running applications in dev mode
// which can generate hundreds of requests per page load, so we increased
// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
// conns.
@@ -164,23 +143,7 @@ func NewServerTailnet(
InsecureSkipVerify: true,
}
agentConn, err := getMultiAgent(ctx)
if err != nil {
return nil, xerrors.Errorf("get initial multi agent: %w", err)
}
tn.agentConn.Store(&agentConn)
// registering the callback also triggers send of the initial node
tn.coordinatee.SetNodeCallback(tn.nodeCallback)
tn.bgRoutines.Add(2)
go func() {
defer tn.bgRoutines.Done()
tn.watchAgentUpdates()
}()
go func() {
defer tn.bgRoutines.Done()
tn.expireOldAgents()
}()
tn.controller.Run(tn.ctx)
return tn, nil
}
@@ -190,18 +153,6 @@ func (s *ServerTailnet) Conn() *tailnet.Conn {
return s.conn
}
func (s *ServerTailnet) nodeCallback(node *tailnet.Node) {
pn, err := tailnet.NodeToProto(node)
if err != nil {
s.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
return
}
err = s.getAgentConn().UpdateSelf(pn)
if err != nil {
s.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
}
}
func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) {
s.connsPerAgent.Describe(descs)
s.totalConns.Describe(descs)
@@ -212,125 +163,9 @@ func (s *ServerTailnet) Collect(metrics chan<- prometheus.Metric) {
s.totalConns.Collect(metrics)
}
func (s *ServerTailnet) expireOldAgents() {
defer s.logger.Debug(s.ctx, "stopped expiring old agents")
const (
tick = 5 * time.Minute
cutoff = 30 * time.Minute
)
ticker := time.NewTicker(tick)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
}
s.doExpireOldAgents(cutoff)
}
}
func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
// TODO: add some attrs to this.
ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
defer span.End()
start := time.Now()
deletedCount := 0
s.nodesMu.Lock()
s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
agentConn := s.getAgentConn()
for agentID, lastConnection := range s.agentConnectionTimes {
// If no one has connected since the cutoff and there are no active
// connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
err := agentConn.UnsubscribeAgent(agentID)
if err != nil {
s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
continue
}
deletedCount++
delete(s.agentConnectionTimes, agentID)
}
}
s.nodesMu.Unlock()
s.logger.Debug(s.ctx, "successfully pruned inactive agents",
slog.F("deleted", deletedCount),
slog.F("took", time.Since(start)),
)
}
func (s *ServerTailnet) watchAgentUpdates() {
defer s.logger.Debug(s.ctx, "stopped watching agent updates")
for {
conn := s.getAgentConn()
resp, ok := conn.NextUpdate(s.ctx)
if !ok {
if conn.IsClosed() && s.ctx.Err() == nil {
s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
s.coordinatee.SetAllPeersLost()
s.reinitCoordinator()
continue
}
return
}
err := s.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
if xerrors.Is(err, tailnet.ErrConnClosed) {
s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err))
return
}
s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
return
}
}
}
func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
return *s.agentConn.Load()
}
func (s *ServerTailnet) reinitCoordinator() {
start := time.Now()
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(s.ctx); {
s.nodesMu.Lock()
agentConn, err := s.getMultiAgent(s.ctx)
if err != nil {
s.nodesMu.Unlock()
s.logger.Error(s.ctx, "reinit multi agent", slog.Error(err))
continue
}
s.agentConn.Store(&agentConn)
// reset the Node callback, which triggers the conn to send the node immediately, and also
// register for updates
s.coordinatee.SetNodeCallback(s.nodeCallback)
// Resubscribe to all of the agents we're tracking.
for agentID := range s.agentConnectionTimes {
err := agentConn.SubscribeAgent(agentID)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
}
}
s.logger.Info(s.ctx, "successfully reinitialized multiagent",
slog.F("agents", len(s.agentConnectionTimes)),
slog.F("took", time.Since(start)),
)
s.nodesMu.Unlock()
return
}
}
type ServerTailnet struct {
ctx context.Context
cancel func()
bgRoutines *sync.WaitGroup
ctx context.Context
cancel func()
logger slog.Logger
tracer trace.Tracer
@@ -340,15 +175,8 @@ type ServerTailnet struct {
conn *tailnet.Conn
coordinatee tailnet.Coordinatee
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
agentConn atomic.Pointer[tailnet.MultiAgentConn]
nodesMu sync.Mutex
// agentConnectionTimes is a map of agent tailnetNodes the server wants to
// keep a connection to. It contains the last time the agent was connected
// to.
agentConnectionTimes map[uuid.UUID]time.Time
// agentTockets holds a map of all open connections to an agent.
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
controller *tailnet.Controller
coordCtrl *MultiAgentController
transport *http.Transport
@@ -446,38 +274,6 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (
}, nil
}
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
s.nodesMu.Lock()
defer s.nodesMu.Unlock()
_, ok := s.agentConnectionTimes[agentID]
// If we don't have the node, subscribe.
if !ok {
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
err := s.getAgentConn().SubscribeAgent(agentID)
if err != nil {
return xerrors.Errorf("subscribe agent: %w", err)
}
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
}
s.agentConnectionTimes[agentID] = time.Now()
return nil
}
func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
id := uuid.New()
s.nodesMu.Lock()
s.agentTickets[agentID][id] = struct{}{}
s.nodesMu.Unlock()
return func() {
s.nodesMu.Lock()
delete(s.agentTickets[agentID], id)
s.nodesMu.Unlock()
}
}
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
var (
conn *workspacesdk.AgentConn
@@ -485,11 +281,11 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*work
)
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
err := s.ensureAgent(agentID)
err := s.coordCtrl.ensureAgent(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
}
ret = s.acquireTicket(agentID)
ret = s.coordCtrl.acquireTicket(agentID)
conn = workspacesdk.NewAgentConn(s.conn, workspacesdk.AgentConnOptions{
AgentID: agentID,
@@ -548,7 +344,8 @@ func (s *ServerTailnet) Close() error {
s.cancel()
_ = s.conn.Close()
s.transport.CloseIdleConnections()
s.bgRoutines.Wait()
s.coordCtrl.Close()
<-s.controller.Closed()
return nil
}
@@ -566,3 +363,277 @@ func (c *instrumentedConn) Close() error {
})
return c.Conn.Close()
}
// MultiAgentController is a tailnet.CoordinationController for connecting to multiple workspace
// agents. It keeps track of connection times to the agents, and removes them on a timer if they
// have no active connections and haven't been used in a while.
type MultiAgentController struct {
*tailnet.BasicCoordinationController
logger slog.Logger
tracer trace.Tracer
mu sync.Mutex
// connectionTimes is a map of agents the server wants to keep a connection to. It
// contains the last time the agent was connected to.
connectionTimes map[uuid.UUID]time.Time
// tickets is a map of destinations to a set of connection tickets, representing open
// connections to the destination
tickets map[uuid.UUID]map[uuid.UUID]struct{}
coordination *tailnet.BasicCoordination
cancel context.CancelFunc
expireOldAgentsDone chan struct{}
}
func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.CloserWaiter {
b := m.BasicCoordinationController.NewCoordination(client)
// resync all destinations
m.mu.Lock()
defer m.mu.Unlock()
m.coordination = b
for agentID := range m.connectionTimes {
err := client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
m.logger.Error(context.Background(), "failed to re-add tunnel", slog.F("agent_id", agentID),
slog.Error(err))
b.SendErr(err)
_ = client.Close()
m.coordination = nil
break
}
}
return b
}
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.connectionTimes[agentID]
// If we don't have the agent, subscribe.
if !ok {
m.logger.Debug(context.Background(),
"subscribing to agent", slog.F("agent_id", agentID))
if m.coordination != nil {
err := m.coordination.Client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
err = xerrors.Errorf("subscribe agent: %w", err)
m.coordination.SendErr(err)
_ = m.coordination.Client.Close()
m.coordination = nil
return err
}
}
m.tickets[agentID] = map[uuid.UUID]struct{}{}
}
m.connectionTimes[agentID] = time.Now()
return nil
}
func (m *MultiAgentController) acquireTicket(agentID uuid.UUID) (release func()) {
id := uuid.New()
m.mu.Lock()
defer m.mu.Unlock()
m.tickets[agentID][id] = struct{}{}
return func() {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.tickets[agentID], id)
}
}
func (m *MultiAgentController) expireOldAgents(ctx context.Context) {
defer close(m.expireOldAgentsDone)
defer m.logger.Debug(context.Background(), "stopped expiring old agents")
const (
tick = 5 * time.Minute
cutoff = 30 * time.Minute
)
ticker := time.NewTicker(tick)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
m.doExpireOldAgents(ctx, cutoff)
}
}
func (m *MultiAgentController) doExpireOldAgents(ctx context.Context, cutoff time.Duration) {
// TODO: add some attrs to this.
ctx, span := m.tracer.Start(ctx, tracing.FuncName())
defer span.End()
start := time.Now()
deletedCount := 0
m.mu.Lock()
defer m.mu.Unlock()
m.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(m.connectionTimes)))
for agentID, lastConnection := range m.connectionTimes {
// If no one has connected since the cutoff and there are no active
// connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(m.tickets[agentID]) == 0 {
if m.coordination != nil {
err := m.coordination.Client.Send(&proto.CoordinateRequest{
RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
m.logger.Debug(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
m.coordination.SendErr(xerrors.Errorf("unsubscribe expired agent: %w", err))
// close the client because we do not want to do a graceful disconnect by
// closing the coordination.
_ = m.coordination.Client.Close()
m.coordination = nil
// Here we continue deleting any inactive agents: there is no point in
// re-establishing tunnels to expired agents when we eventually reconnect.
}
}
deletedCount++
delete(m.connectionTimes, agentID)
}
}
m.logger.Debug(ctx, "pruned inactive agents",
slog.F("deleted", deletedCount),
slog.F("took", time.Since(start)),
)
}
func (m *MultiAgentController) Close() {
m.cancel()
<-m.expireOldAgentsDone
}
func NewMultiAgentController(ctx context.Context, logger slog.Logger, tracer trace.Tracer, coordinatee tailnet.Coordinatee) *MultiAgentController {
m := &MultiAgentController{
BasicCoordinationController: &tailnet.BasicCoordinationController{
Logger: logger,
Coordinatee: coordinatee,
SendAcks: false, // we are a client, connecting to multiple agents
},
logger: logger,
tracer: tracer,
connectionTimes: make(map[uuid.UUID]time.Time),
tickets: make(map[uuid.UUID]map[uuid.UUID]struct{}),
expireOldAgentsDone: make(chan struct{}),
}
ctx, m.cancel = context.WithCancel(ctx)
go m.expireOldAgents(ctx)
return m
}
// InmemTailnetDialer is a tailnet.ControlProtocolDialer that connects to a Coordinator and DERPMap
// service running in the same memory space.
type InmemTailnetDialer struct {
CoordPtr *atomic.Pointer[tailnet.Coordinator]
DERPFn func() *tailcfg.DERPMap
Logger slog.Logger
ClientID uuid.UUID
}
func (a *InmemTailnetDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
coord := a.CoordPtr.Load()
if coord == nil {
return tailnet.ControlProtocolClients{}, xerrors.Errorf("tailnet coordinator not initialized")
}
coordClient := tailnet.NewInMemoryCoordinatorClient(
a.Logger, a.ClientID, tailnet.SingleTailnetCoordinateeAuth{}, *coord)
derpClient := newPollingDERPClient(a.DERPFn, a.Logger)
return tailnet.ControlProtocolClients{
Closer: closeAll{coord: coordClient, derp: derpClient},
Coordinator: coordClient,
DERP: derpClient,
}, nil
}
func newPollingDERPClient(derpFn func() *tailcfg.DERPMap, logger slog.Logger) tailnet.DERPClient {
ctx, cancel := context.WithCancel(context.Background())
a := &pollingDERPClient{
fn: derpFn,
ctx: ctx,
cancel: cancel,
logger: logger,
ch: make(chan *tailcfg.DERPMap),
loopDone: make(chan struct{}),
}
go a.pollDERP()
return a
}
// pollingDERPClient is a DERP client that just calls a function on a polling
// interval
type pollingDERPClient struct {
fn func() *tailcfg.DERPMap
logger slog.Logger
ctx context.Context
cancel context.CancelFunc
loopDone chan struct{}
lastDERPMap *tailcfg.DERPMap
ch chan *tailcfg.DERPMap
}
// Close the DERP client
func (a *pollingDERPClient) Close() error {
a.cancel()
<-a.loopDone
return nil
}
func (a *pollingDERPClient) Recv() (*tailcfg.DERPMap, error) {
select {
case <-a.ctx.Done():
return nil, a.ctx.Err()
case dm := <-a.ch:
return dm, nil
}
}
func (a *pollingDERPClient) pollDERP() {
defer close(a.loopDone)
defer a.logger.Debug(a.ctx, "polling DERPMap exited")
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-a.ctx.Done():
return
case <-ticker.C:
}
newDerpMap := a.fn()
if !tailnet.CompareDERPMaps(a.lastDERPMap, newDerpMap) {
select {
case <-a.ctx.Done():
return
case a.ch <- newDerpMap:
}
}
}
}
type closeAll struct {
coord tailnet.CoordinatorClient
derp tailnet.DERPClient
}
func (c closeAll) Close() error {
cErr := c.coord.Close()
dErr := c.derp.Close()
if cErr != nil {
return cErr
}
return dErr
}
-77
View File
@@ -1,77 +0,0 @@
package coderd
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"go.uber.org/mock/gomock"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
// TestServerTailnet_Reconnect tests that ServerTailnet calls SetAllPeersLost on the Coordinatee
// (tailnet.Conn in production) when it disconnects from the Coordinator (via MultiAgentConn) and
// reconnects.
func TestServerTailnet_Reconnect(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctrl := gomock.NewController(t)
ctx := testutil.Context(t, testutil.WaitShort)
mMultiAgent0 := tailnettest.NewMockMultiAgentConn(ctrl)
mMultiAgent1 := tailnettest.NewMockMultiAgentConn(ctrl)
mac := make(chan tailnet.MultiAgentConn, 2)
mac <- mMultiAgent0
mac <- mMultiAgent1
mCoord := tailnettest.NewMockCoordinatee(ctrl)
uut := &ServerTailnet{
ctx: ctx,
logger: logger,
coordinatee: mCoord,
getMultiAgent: func(ctx context.Context) (tailnet.MultiAgentConn, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case m := <-mac:
return m, nil
}
},
agentConn: atomic.Pointer[tailnet.MultiAgentConn]{},
agentConnectionTimes: make(map[uuid.UUID]time.Time),
}
// reinit the Coordinator once, to load mMultiAgent0
mCoord.EXPECT().SetNodeCallback(gomock.Any()).Times(1)
uut.reinitCoordinator()
mMultiAgent0.EXPECT().NextUpdate(gomock.Any()).
Times(1).
Return(nil, false) // this indicates there are no more updates
closed0 := mMultiAgent0.EXPECT().IsClosed().
Times(1).
Return(true) // this triggers reconnect
setLost := mCoord.EXPECT().SetAllPeersLost().Times(1).After(closed0)
mCoord.EXPECT().SetNodeCallback(gomock.Any()).Times(1).After(closed0)
mMultiAgent1.EXPECT().NextUpdate(gomock.Any()).
Times(1).
After(setLost).
Return(nil, false)
mMultiAgent1.EXPECT().IsClosed().
Times(1).
Return(false) // this causes us to exit and not reconnect
done := make(chan struct{})
go func() {
uut.watchAgentUpdates()
close(done)
}()
testutil.RequireRecvCtx(ctx, t, done)
}
+9 -2
View File
@@ -399,6 +399,8 @@ func setupServerTailnetAgent(t *testing.T, agentNum int, opts ...tailnettest.DER
t.Cleanup(func() {
_ = coord.Close()
})
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
agents := []agentWithID{}
@@ -430,13 +432,18 @@ func setupServerTailnetAgent(t *testing.T, agentNum int, opts ...tailnettest.DER
agents = append(agents, agentWithID{id: manifest.AgentID, Agent: ag})
}
dialer := &coderd.InmemTailnetDialer{
CoordPtr: &coordPtr,
DERPFn: func() *tailcfg.DERPMap { return derpMap },
Logger: logger,
ClientID: uuid.UUID{5},
}
serverTailnet, err := coderd.NewServerTailnet(
context.Background(),
logger,
derpServer,
func() *tailcfg.DERPMap { return derpMap },
dialer,
false,
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
!derpMap.HasSTUN(),
trace.NewNoopTracerProvider(),
)
+52 -53
View File
@@ -10,7 +10,6 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/enterprise/tailnet"
"github.com/coder/coder/v2/tailnet/tailnettest"
agpltest "github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)
@@ -39,19 +38,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
defer agent1.Close(ctx)
agent1.UpdateDERP(5)
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()
ma1 := agpltest.NewPeer(ctx, t, coord1, "client")
defer ma1.Close(ctx)
ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5)
ma1.AddTunnel(agent1.ID)
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
ma1.SendNodeWithDERP(3)
ma1.UpdateDERP(3)
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
ma1.Close()
ma1.Disconnect()
agent1.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
@@ -72,13 +71,13 @@ func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) {
require.NoError(t, err)
defer coord1.Close()
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()
ma1 := agpltest.NewPeer(ctx, t, coord1, "client")
defer ma1.Close(ctx)
err = coord1.Close()
require.NoError(t, err)
ma1.RequireEventuallyClosed(ctx)
ma1.AssertEventuallyResponsesClosed()
}
// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
@@ -106,20 +105,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
defer agent1.Close(ctx)
agent1.UpdateDERP(5)
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()
ma1 := agpltest.NewPeer(ctx, t, coord1, "client")
defer ma1.Close(ctx)
ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5)
ma1.AddTunnel(agent1.ID)
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
ma1.SendNodeWithDERP(3)
ma1.UpdateDERP(3)
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
ma1.RequireUnsubscribeAgent(agent1.ID)
ma1.Close()
ma1.RemoveTunnel(agent1.ID)
ma1.Close(ctx)
agent1.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
@@ -151,35 +150,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
defer agent1.Close(ctx)
agent1.UpdateDERP(5)
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()
ma1 := agpltest.NewPeer(ctx, t, coord1, "client")
defer ma1.Close(ctx)
ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5)
ma1.AddTunnel(agent1.ID)
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
ma1.SendNodeWithDERP(3)
ma1.UpdateDERP(3)
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
ma1.RequireUnsubscribeAgent(agent1.ID)
ma1.RemoveTunnel(agent1.ID)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel()
ma1.SendNodeWithDERP(9)
ma1.UpdateDERP(9)
agent1.AssertNeverHasDERPs(ctx, ma1.ID, 9)
}()
func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel()
agent1.UpdateDERP(8)
ma1.RequireNeverHasDERPs(ctx, 8)
ma1.AssertNeverHasDERPs(ctx, agent1.ID, 8)
}()
ma1.Close()
ma1.Disconnect()
agent1.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
@@ -216,19 +215,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
defer agent1.Close(ctx)
agent1.UpdateDERP(5)
ma1 := tailnettest.NewTestMultiAgent(t, coord2)
defer ma1.Close()
ma1 := agpltest.NewPeer(ctx, t, coord2, "client")
defer ma1.Close(ctx)
ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5)
ma1.AddTunnel(agent1.ID)
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
ma1.SendNodeWithDERP(3)
ma1.UpdateDERP(3)
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
ma1.Close()
ma1.Disconnect()
agent1.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
@@ -266,19 +265,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
defer agent1.Close(ctx)
agent1.UpdateDERP(5)
ma1 := tailnettest.NewTestMultiAgent(t, coord2)
defer ma1.Close()
ma1 := agpltest.NewPeer(ctx, t, coord2, "client")
defer ma1.Close(ctx)
ma1.SendNodeWithDERP(3)
ma1.UpdateDERP(3)
ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5)
ma1.AddTunnel(agent1.ID)
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
ma1.Close()
ma1.Disconnect()
agent1.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
@@ -325,26 +324,26 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
defer agent2.Close(ctx)
agent2.UpdateDERP(6)
ma1 := tailnettest.NewTestMultiAgent(t, coord3)
defer ma1.Close()
ma1 := agpltest.NewPeer(ctx, t, coord3, "client")
defer ma1.Close(ctx)
ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5)
ma1.AddTunnel(agent1.ID)
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
ma1.RequireSubscribeAgent(agent2.ID)
ma1.RequireEventuallyHasDERPs(ctx, 6)
ma1.AddTunnel(agent2.ID)
ma1.AssertEventuallyHasDERP(agent2.ID, 6)
agent2.UpdateDERP(2)
ma1.RequireEventuallyHasDERPs(ctx, 2)
ma1.AssertEventuallyHasDERP(agent2.ID, 2)
ma1.SendNodeWithDERP(3)
ma1.UpdateDERP(3)
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
agent2.AssertEventuallyHasDERP(ma1.ID, 3)
ma1.Close()
ma1.Disconnect()
agent1.UngracefulDisconnect(ctx)
agent2.UngracefulDisconnect(ctx)
-4
View File
@@ -165,10 +165,6 @@ func newPGCoordInternal(
return c, nil
}
func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
return agpl.ServeMultiAgent(c, c.logger, id)
}
func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
// We're going to directly query the database, since we would only have the mapping stored locally if we had
// a tunnel peer connected, which is not always the case.
+5 -13
View File
@@ -12,7 +12,6 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/go-chi/chi/v5"
@@ -24,7 +23,6 @@ import (
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"cdr.dev/slog"
@@ -144,7 +142,6 @@ type Server struct {
replicaPingSingleflight singleflight.Group
replicaErrMut sync.Mutex
replicaErr string
latestDERPMap atomic.Pointer[tailcfg.DERPMap]
// Used for graceful shutdown. Required for the dialer.
ctx context.Context
@@ -271,14 +268,15 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
return nil, xerrors.Errorf("handle register: %w", err)
}
dialer, err := s.SDKClient.TailnetDialer()
if err != nil {
return nil, xerrors.Errorf("create tailnet dialer: %w", err)
}
agentProvider, err := coderd.NewServerTailnet(ctx,
s.Logger,
nil,
func() *tailcfg.DERPMap {
return s.latestDERPMap.Load()
},
dialer,
regResp.DERPForceWebSockets,
s.DialCoordinator,
opts.BlockDirect,
s.TracerProvider,
)
@@ -481,8 +479,6 @@ func (s *Server) handleRegister(res wsproxysdk.RegisterWorkspaceProxyResponse) e
s.Logger.Debug(s.ctx, "setting DERP mesh sibling addresses", slog.F("addresses", addresses))
s.derpMesh.SetAddresses(addresses, false)
s.latestDERPMap.Store(res.DERPMap)
go s.pingSiblingReplicas(res.SiblingReplicas)
return nil
}
@@ -569,10 +565,6 @@ func (s *Server) handleRegisterFailure(err error) {
)
}
func (s *Server) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, error) {
return s.SDKClient.DialCoordinator(ctx)
}
func (s *Server) buildInfo(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{
ExternalURL: buildinfo.ExternalURL(),
+4 -108
View File
@@ -14,7 +14,6 @@ import (
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/codersdk"
@@ -501,13 +500,11 @@ type CoordinateNodes struct {
Nodes []*agpl.Node
}
func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) {
ctx, cancel := context.WithCancel(ctx)
logger := c.SDKClient.Logger().Named("multiagent")
func (c *Client) TailnetDialer() (*workspacesdk.WebsocketDialer, error) {
logger := c.SDKClient.Logger().Named("tailnet_dialer")
coordinateURL, err := c.SDKClient.URL.Parse("/api/v2/workspaceproxies/me/coordinate")
if err != nil {
cancel()
return nil, xerrors.Errorf("parse url: %w", err)
}
q := coordinateURL.Query()
@@ -520,59 +517,10 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
}
coordinateHeaders.Set(tokenHeader, c.SessionToken())
//nolint:bodyclose
conn, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
return workspacesdk.NewWebsocketDialer(logger, coordinateURL, &websocket.DialOptions{
HTTPClient: c.SDKClient.HTTPClient,
HTTPHeader: coordinateHeaders,
})
if err != nil {
cancel()
return nil, xerrors.Errorf("dial coordinate websocket: %w", err)
}
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
nc := websocket.NetConn(ctx, conn, websocket.MessageBinary)
client, err := agpl.NewDRPCClient(nc, logger)
if err != nil {
logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, "")
return nil, xerrors.Errorf("failed to create DRPCClient: %w", err)
}
protocol, err := client.Coordinate(ctx)
if err != nil {
logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, "")
return nil, xerrors.Errorf("failed to reach the Coordinate endpoint: %w", err)
}
rma := remoteMultiAgentHandler{
sdk: c,
logger: logger,
protocol: protocol,
cancel: cancel,
}
ma := (&agpl.MultiAgent{
ID: uuid.New(),
OnSubscribe: rma.OnSubscribe,
OnUnsubscribe: rma.OnUnsubscribe,
OnNodeUpdate: rma.OnNodeUpdate,
OnRemove: rma.OnRemove,
}).Init()
go func() {
<-ctx.Done()
_ = ma.Close()
_ = client.DRPCConn().Close()
<-client.DRPCConn().Closed()
_ = conn.Close(websocket.StatusGoingAway, "closed")
}()
rma.ma = ma
go rma.respLoop()
return ma, nil
}), nil
}
type CryptoKeysResponse struct {
@@ -595,55 +543,3 @@ func (c *Client) CryptoKeys(ctx context.Context, feature codersdk.CryptoKeyFeatu
var resp CryptoKeysResponse
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
type remoteMultiAgentHandler struct {
sdk *Client
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
ma *agpl.MultiAgent
cancel func()
}
func (a *remoteMultiAgentHandler) respLoop() {
{
defer a.cancel()
for {
resp, err := a.protocol.Recv()
if err != nil {
if xerrors.Is(err, io.EOF) {
a.logger.Info(context.Background(), "remote multiagent connection severed", slog.Error(err))
return
}
a.logger.Error(context.Background(), "error receiving multiagent responses", slog.Error(err))
return
}
err = a.ma.Enqueue(resp)
if err != nil {
a.logger.Error(context.Background(), "enqueue response from coordinator", slog.Error(err))
continue
}
}
}
}
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *proto.Node) error {
return a.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}})
}
func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) error {
return a.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
}
func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error {
return a.protocol.Send(&proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
}
func (a *remoteMultiAgentHandler) OnRemove(_ agpl.Queue) {
err := a.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil {
a.logger.Warn(context.Background(), "failed to gracefully disconnect", slog.Error(err))
}
_ = a.protocol.CloseSend()
}
@@ -1,37 +1,21 @@
package wsproxysdk_test
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/netip"
"net/url"
"sync/atomic"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"google.golang.org/protobuf/types/known/timestamppb"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/enterprise/tailnet"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
@@ -152,142 +136,6 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) {
})
}
func TestDialCoordinator(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
var (
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
agentID = uuid.UUID{33}
proxyID = uuid.UUID{44}
mCoord = tailnettest.NewMockCoordinator(gomock.NewController(t))
coord agpl.Coordinator = mCoord
r = chi.NewRouter()
srv = httptest.NewServer(r)
)
defer cancel()
defer srv.Close()
coordPtr := atomic.Pointer[agpl.Coordinator]{}
coordPtr.Store(&coord)
cSrv, err := tailnet.NewClientService(agpl.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: agpl.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
// buffer the channels here, so we don't need to read and write in goroutines to
// avoid blocking
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetCoordinateeAuth{}).
Times(1).
Return(reqs, resps)
serveMACErr := make(chan error, 1)
r.Get("/api/v2/workspaceproxies/me/coordinate", func(w http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
version := r.URL.Query().Get("version")
if !assert.Equal(t, version, proto.CurrentVersion.String()) {
return
}
nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
err = cSrv.ServeMultiAgentClient(ctx, version, nc, proxyID)
serveMACErr <- err
})
u, err := url.Parse(srv.URL)
require.NoError(t, err)
client := wsproxysdk.New(u)
client.SDKClient.SetLogger(logger)
peerID := uuid.UUID{55}
peerNodeKey, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
peerDiscoKey, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
expected := &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
Id: peerID[:],
Node: &proto.Node{
Id: 55,
AsOf: timestamppb.New(time.Unix(1689653252, 0)),
Key: peerNodeKey,
Disco: string(peerDiscoKey),
PreferredDerp: 0,
DerpLatency: map[string]float64{
"0": 1.0,
},
DerpForcedWebsocket: map[int32]string{},
Addresses: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()},
AllowedIps: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()},
Endpoints: []string{"192.168.1.1:18842"},
},
}}}
rma, err := client.DialCoordinator(ctx)
require.NoError(t, err)
// Subscribe
{
require.NoError(t, rma.SubscribeAgent(agentID))
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
}
// Read updated agent node
{
resps <- expected
resp, ok := rma.NextUpdate(ctx)
assert.True(t, ok)
updates := resp.GetPeerUpdates()
assert.Len(t, updates, 1)
eq, err := updates[0].GetNode().Equal(expected.GetPeerUpdates()[0].GetNode())
assert.NoError(t, err)
assert.True(t, eq)
}
// UpdateSelf
{
require.NoError(t, rma.UpdateSelf(expected.PeerUpdates[0].GetNode()))
req := testutil.RequireRecvCtx(ctx, t, reqs)
eq, err := req.GetUpdateSelf().GetNode().Equal(expected.PeerUpdates[0].GetNode())
require.NoError(t, err)
require.True(t, eq)
}
// Unsubscribe
{
require.NoError(t, rma.UnsubscribeAgent(agentID))
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetRemoveTunnel().GetId())
}
// Close
{
require.NoError(t, rma.Close())
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect)
close(resps)
select {
case <-ctx.Done():
t.Fatal("timeout waiting for req close")
case _, ok := <-reqs:
require.False(t, ok, "didn't close requests")
}
require.Error(t, testutil.RequireRecvCtx(ctx, t, serveMACErr))
}
})
}
type ResponseRecorder struct {
rw *httptest.ResponseRecorder
wasWritten atomic.Bool
+86 -57
View File
@@ -112,33 +112,41 @@ type ControlProtocolDialer interface {
Dial(ctx context.Context, r ResumeTokenController) (ControlProtocolClients, error)
}
// basicCoordinationController handles the basic coordination operations common to all types of
// BasicCoordinationController handles the basic coordination operations common to all types of
// tailnet consumers:
//
// 1. sending local node updates to the Coordinator
// 2. receiving peer node updates and programming them into the Coordinatee (e.g. tailnet.Conn)
// 3. (optionally) sending ReadyToHandshake acknowledgements for peer updates.
type basicCoordinationController struct {
logger slog.Logger
coordinatee Coordinatee
sendAcks bool
//
// It is designed to be used on its own, or composed into more advanced CoordinationControllers.
type BasicCoordinationController struct {
Logger slog.Logger
Coordinatee Coordinatee
SendAcks bool
}
func (c *basicCoordinationController) New(client CoordinatorClient) CloserWaiter {
b := &basicCoordination{
logger: c.logger,
// New satisfies the method on the CoordinationController interface
func (c *BasicCoordinationController) New(client CoordinatorClient) CloserWaiter {
return c.NewCoordination(client)
}
// NewCoordination creates a BasicCoordination
func (c *BasicCoordinationController) NewCoordination(client CoordinatorClient) *BasicCoordination {
b := &BasicCoordination{
logger: c.Logger,
errChan: make(chan error, 1),
coordinatee: c.coordinatee,
client: client,
coordinatee: c.Coordinatee,
Client: client,
respLoopDone: make(chan struct{}),
sendAcks: c.sendAcks,
sendAcks: c.SendAcks,
}
c.coordinatee.SetNodeCallback(func(node *Node) {
c.Coordinatee.SetNodeCallback(func(node *Node) {
pn, err := NodeToProto(node)
if err != nil {
b.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
b.sendErr(err)
b.SendErr(err)
return
}
b.Lock()
@@ -147,9 +155,9 @@ func (c *basicCoordinationController) New(client CoordinatorClient) CloserWaiter
b.logger.Debug(context.Background(), "ignored node update because coordination is closed")
return
}
err = b.client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
err = b.Client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
if err != nil {
b.sendErr(xerrors.Errorf("write: %w", err))
b.SendErr(xerrors.Errorf("write: %w", err))
}
})
go b.respLoop()
@@ -157,18 +165,27 @@ func (c *basicCoordinationController) New(client CoordinatorClient) CloserWaiter
return b
}
type basicCoordination struct {
// BasicCoordination handles:
//
// 1. Sending local node updates to the control plane
// 2. Reading remote updates from the control plane and programming them into the Coordinatee.
//
// It does *not* handle adding any Tunnels, but these can be handled by composing
// BasicCoordinationController with a more advanced controller.
type BasicCoordination struct {
sync.Mutex
closed bool
errChan chan error
coordinatee Coordinatee
logger slog.Logger
client CoordinatorClient
Client CoordinatorClient
respLoopDone chan struct{}
sendAcks bool
}
func (c *basicCoordination) Close(ctx context.Context) (retErr error) {
// Close the coordination gracefully. If the context expires before the remote API server has hung
// up on us, we forcibly close the Client connection.
func (c *BasicCoordination) Close(ctx context.Context) (retErr error) {
c.Lock()
defer c.Unlock()
if c.closed {
@@ -188,13 +205,13 @@ func (c *basicCoordination) Close(ctx context.Context) (retErr error) {
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
}
// forcefully close the stream
protoErr := c.client.Close()
protoErr := c.Client.Close()
<-c.respLoopDone
if retErr == nil {
retErr = protoErr
}
}()
err := c.client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
err := c.Client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil && !xerrors.Is(err, io.EOF) {
// Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
return xerrors.Errorf("send disconnect: %w", err)
@@ -203,20 +220,24 @@ func (c *basicCoordination) Close(ctx context.Context) (retErr error) {
return nil
}
func (c *basicCoordination) Wait() <-chan error {
// Wait for the Coordination to complete
func (c *BasicCoordination) Wait() <-chan error {
return c.errChan
}
func (c *basicCoordination) sendErr(err error) {
// SendErr is not part of the CloserWaiter interface, and is intended to be called internally, or
// by Controllers that use BasicCoordinationController in composition. It triggers Wait() to
// report the error if an error has not already been reported.
func (c *BasicCoordination) SendErr(err error) {
select {
case c.errChan <- err:
default:
}
}
func (c *basicCoordination) respLoop() {
func (c *BasicCoordination) respLoop() {
defer func() {
cErr := c.client.Close()
cErr := c.Client.Close()
if cErr != nil {
c.logger.Debug(context.Background(), "failed to close coordinate client after respLoop exit", slog.Error(cErr))
}
@@ -224,17 +245,17 @@ func (c *basicCoordination) respLoop() {
close(c.respLoopDone)
}()
for {
resp, err := c.client.Recv()
resp, err := c.Client.Recv()
if err != nil {
c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err))
c.sendErr(xerrors.Errorf("read: %w", err))
c.SendErr(xerrors.Errorf("read: %w", err))
return
}
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.logger.Debug(context.Background(), "failed to update peers", slog.Error(err))
c.sendErr(xerrors.Errorf("update peers: %w", err))
c.SendErr(xerrors.Errorf("update peers: %w", err))
return
}
@@ -253,12 +274,12 @@ func (c *basicCoordination) respLoop() {
rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id})
}
if len(rfh) > 0 {
err := c.client.Send(&proto.CoordinateRequest{
err := c.Client.Send(&proto.CoordinateRequest{
ReadyForHandshake: rfh,
})
if err != nil {
c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err))
c.sendErr(xerrors.Errorf("send: %w", err))
c.SendErr(xerrors.Errorf("send: %w", err))
return
}
}
@@ -267,7 +288,7 @@ func (c *basicCoordination) respLoop() {
}
type singleDestController struct {
*basicCoordinationController
*BasicCoordinationController
dest uuid.UUID
}
@@ -276,21 +297,20 @@ type singleDestController struct {
func NewSingleDestController(logger slog.Logger, coordinatee Coordinatee, dest uuid.UUID) CoordinationController {
coordinatee.SetTunnelDestination(dest)
return &singleDestController{
basicCoordinationController: &basicCoordinationController{
logger: logger,
coordinatee: coordinatee,
sendAcks: false,
BasicCoordinationController: &BasicCoordinationController{
Logger: logger,
Coordinatee: coordinatee,
SendAcks: false,
},
dest: dest,
}
}
func (c *singleDestController) New(client CoordinatorClient) CloserWaiter {
// nolint: forcetypeassert
b := c.basicCoordinationController.New(client).(*basicCoordination)
b := c.BasicCoordinationController.NewCoordination(client)
err := client.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: c.dest[:]}})
if err != nil {
b.sendErr(err)
b.SendErr(err)
}
return b
}
@@ -298,10 +318,10 @@ func (c *singleDestController) New(client CoordinatorClient) CloserWaiter {
// NewAgentCoordinationController creates a CoordinationController for Coder Agents, which never
// create tunnels and always send ReadyToHandshake acknowledgements.
func NewAgentCoordinationController(logger slog.Logger, coordinatee Coordinatee) CoordinationController {
return &basicCoordinationController{
logger: logger,
coordinatee: coordinatee,
sendAcks: true,
return &BasicCoordinationController{
Logger: logger,
Coordinatee: coordinatee,
SendAcks: true,
}
}
@@ -360,11 +380,11 @@ func (c *inMemoryCoordClient) Recv() (*proto.CoordinateResponse, error) {
// local Coordinator. (The typical alternative is a DRPC-based client.)
func NewInMemoryCoordinatorClient(
logger slog.Logger,
clientID, agentID uuid.UUID,
clientID uuid.UUID,
auth CoordinateeAuth,
coordinator Coordinator,
) CoordinatorClient {
logger = logger.With(slog.F("agent_id", agentID), slog.F("client_id", clientID))
auth := ClientCoordinateeAuth{AgentID: agentID}
logger = logger.With(slog.F("client_id", clientID))
c := &inMemoryCoordClient{logger: logger}
c.ctx, c.cancel = context.WithCancel(context.Background())
@@ -742,9 +762,9 @@ func (c *Controller) Run(ctx context.Context) {
c.logger.Error(c.ctx, "failed to dial tailnet v2+ API", errF)
continue
}
c.logger.Debug(c.ctx, "obtained tailnet API v2+ client")
c.logger.Info(c.ctx, "obtained tailnet API v2+ client")
c.runControllersOnce(tailnetClients)
c.logger.Debug(c.ctx, "tailnet API v2+ connection lost")
c.logger.Info(c.ctx, "tailnet API v2+ connection lost")
}
}()
}
@@ -754,15 +774,20 @@ func (c *Controller) Run(ctx context.Context) {
// appropriate). We typically multiplex all RPCs over the same websocket, so we want them to share
// the same fate.
func (c *Controller) runControllersOnce(clients ControlProtocolClients) {
defer func() {
closeErr := clients.Closer.Close()
if closeErr != nil &&
!xerrors.Is(closeErr, io.EOF) &&
!xerrors.Is(closeErr, context.Canceled) &&
!xerrors.Is(closeErr, context.DeadlineExceeded) {
c.logger.Error(c.ctx, "error closing DRPC connection", slog.Error(closeErr))
}
}()
// clients.Closer.Close should nominally be idempotent, but let's not press our luck
closeOnce := sync.Once{}
closeClients := func() {
closeOnce.Do(func() {
closeErr := clients.Closer.Close()
if closeErr != nil &&
!xerrors.Is(closeErr, io.EOF) &&
!xerrors.Is(closeErr, context.Canceled) &&
!xerrors.Is(closeErr, context.DeadlineExceeded) {
c.logger.Error(c.ctx, "error closing tailnet clients", slog.Error(closeErr))
}
})
}
defer closeClients()
if c.TelemetryCtrl != nil {
c.TelemetryCtrl.New(clients.Telemetry) // synchronous, doesn't need a goroutine
@@ -775,6 +800,11 @@ func (c *Controller) runControllersOnce(clients ControlProtocolClients) {
go func() {
defer wg.Done()
c.coordinate(clients.Coordinator)
if c.ctx.Err() == nil {
// Main context is still active, but our coordination exited, due to some error.
// Close down all the rest of the clients so we'll exit and retry.
closeClients()
}
}()
}
if c.DERPCtrl != nil {
@@ -788,8 +818,7 @@ func (c *Controller) runControllersOnce(clients ControlProtocolClients) {
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
// close the underlying connection. This will trigger a retry of the control plane in
// run().
_ = clients.Closer.Close()
// Note that derpMap() logs it own errors, we don't bother here.
closeClients()
}
}()
}
+9 -3
View File
@@ -42,11 +42,12 @@ func TestInMemoryCoordination(t *testing.T) {
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
auth := tailnet.ClientCoordinateeAuth{AgentID: agentID}
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), auth).
Times(1).Return(reqs, resps)
ctrl := tailnet.NewSingleDestController(logger, fConn, agentID)
uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, agentID, mCoord))
uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, mCoord))
defer uut.Close(ctx)
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
@@ -724,7 +725,7 @@ func TestController_Disconnects(t *testing.T) {
peersLost := make(chan struct{})
fConn := &fakeTailnetConn{peersLostCh: peersLost}
uut := tailnet.NewController(logger.Named("tac"), dialer,
uut := tailnet.NewController(logger.Named("ctrl"), dialer,
// darwin can be slow sometimes.
tailnet.WithGracefulTimeout(5*time.Second))
uut.CoordCtrl = tailnet.NewAgentCoordinationController(logger.Named("coord_ctrl"), fConn)
@@ -746,6 +747,11 @@ func TestController_Disconnects(t *testing.T) {
// ...and then reconnect
call = testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
// close the coordination call, which should cause a 2nd reconnection
close(call.Resps)
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
call = testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
// canceling the context should trigger the disconnect message
cancel()
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
-77
View File
@@ -45,7 +45,6 @@ type CoordinatorV2 interface {
Node(id uuid.UUID) *Node
Close() error
Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
ServeMultiAgent(id uuid.UUID) MultiAgentConn
}
// Node represents a node in the network.
@@ -174,39 +173,6 @@ func (c *coordinator) Coordinate(
return reqs, resps
}
func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
return ServeMultiAgent(c, c.core.logger, id)
}
func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAgentConn {
logger = logger.With(slog.F("client_id", id)).Named("multiagent")
ctx, cancel := context.WithCancel(context.Background())
reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetCoordinateeAuth{})
m := (&MultiAgent{
ID: id,
OnSubscribe: func(enq Queue, agent uuid.UUID) error {
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}})
return err
},
OnUnsubscribe: func(enq Queue, agent uuid.UUID) error {
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}})
return err
},
OnNodeUpdate: func(id uuid.UUID, node *proto.Node) error {
return SendCtx(ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
Node: node,
}})
},
OnRemove: func(_ Queue) {
_ = SendCtx(ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
cancel()
},
}).Init()
go qRespLoop(ctx, cancel, logger, m, resps)
return m
}
// core is an in-memory structure of peer mappings. Its methods may be called from multiple goroutines;
// it is protected by a mutex to ensure data stay consistent.
type core struct {
@@ -218,27 +184,6 @@ type core struct {
tunnels *tunnelStore
}
type QueueKind int
const (
QueueKindClient QueueKind = 1 + iota
QueueKindAgent
)
type Queue interface {
UniqueID() uuid.UUID
Kind() QueueKind
Enqueue(resp *proto.CoordinateResponse) error
Name() string
Stats() (start, lastWrite int64)
Overwrites() int64
// CoordinatorClose is used by the coordinator when closing a Queue. It
// should skip removing itself from the coordinator.
CoordinatorClose() error
Done() <-chan struct{}
Close() error
}
func newCore(logger slog.Logger) *core {
return &core{
logger: logger,
@@ -671,25 +616,3 @@ func RecvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
return a, io.EOF
}
}
func qRespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
defer func() {
cErr := q.Close()
if cErr != nil {
logger.Info(ctx, "error closing response Queue", slog.Error(cErr))
}
cancel()
}()
for {
resp, err := RecvCtx(ctx, resps)
if err != nil {
logger.Debug(ctx, "qRespLoop done reading responses", slog.Error(err))
return
}
logger.Debug(ctx, "qRespLoop got response", slog.F("resp", resp))
err = q.Enqueue(resp)
if err != nil && !xerrors.Is(err, context.Canceled) {
logger.Error(ctx, "qRespLoop failed to enqueue v1 update", slog.Error(err))
}
}
}
-19
View File
@@ -12,7 +12,6 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)
@@ -245,24 +244,6 @@ func TestCoordinator_Lost(t *testing.T) {
test.LostTest(ctx, t, coordinator)
}
func TestCoordinator_MultiAgent_CoordClose(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
coord1 := tailnet.NewCoordinator(logger.Named("coord1"))
defer coord1.Close()
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()
err := coord1.Close()
require.NoError(t, err)
ma1.RequireEventuallyClosed(ctx)
}
// TestCoordinatorPropogatedPeerContext tests that the context for a specific peer
// is propogated through to the `Authorize“ method of the coordinatee auth
func TestCoordinatorPropogatedPeerContext(t *testing.T) {
-168
View File
@@ -1,168 +0,0 @@
package tailnet
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/tailnet/proto"
)
type MultiAgentConn interface {
UpdateSelf(node *proto.Node) error
SubscribeAgent(agentID uuid.UUID) error
UnsubscribeAgent(agentID uuid.UUID) error
NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool)
Close() error
IsClosed() bool
}
type MultiAgent struct {
mu sync.RWMutex
ID uuid.UUID
OnSubscribe func(enq Queue, agent uuid.UUID) error
OnUnsubscribe func(enq Queue, agent uuid.UUID) error
OnNodeUpdate func(id uuid.UUID, node *proto.Node) error
OnRemove func(enq Queue)
ctx context.Context
ctxCancel func()
closed bool
updates chan *proto.CoordinateResponse
closeOnce sync.Once
start int64
lastWrite int64
// Client nodes normally generate a unique id for each connection so
// overwrites are really not an issue, but is provided for compatibility.
overwrites int64
}
func (m *MultiAgent) Init() *MultiAgent {
m.updates = make(chan *proto.CoordinateResponse, 128)
m.start = time.Now().Unix()
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
return m
}
func (*MultiAgent) Kind() QueueKind {
return QueueKindClient
}
func (m *MultiAgent) UniqueID() uuid.UUID {
return m.ID
}
var ErrMultiAgentClosed = xerrors.New("multiagent is closed")
func (m *MultiAgent) UpdateSelf(node *proto.Node) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return ErrMultiAgentClosed
}
return m.OnNodeUpdate(m.ID, node)
}
func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return ErrMultiAgentClosed
}
err := m.OnSubscribe(m, agentID)
if err != nil {
return err
}
return nil
}
func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return ErrMultiAgentClosed
}
return m.OnUnsubscribe(m, agentID)
}
func (m *MultiAgent) NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool) {
select {
case <-ctx.Done():
return nil, false
case resp, ok := <-m.updates:
return resp, ok
}
}
func (m *MultiAgent) Enqueue(resp *proto.CoordinateResponse) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return nil
}
return m.enqueueLocked(resp)
}
func (m *MultiAgent) enqueueLocked(resp *proto.CoordinateResponse) error {
atomic.StoreInt64(&m.lastWrite, time.Now().Unix())
select {
case m.updates <- resp:
return nil
default:
return ErrWouldBlock
}
}
func (m *MultiAgent) Name() string {
return m.ID.String()
}
func (m *MultiAgent) Stats() (start int64, lastWrite int64) {
return m.start, atomic.LoadInt64(&m.lastWrite)
}
func (m *MultiAgent) Overwrites() int64 {
return m.overwrites
}
func (m *MultiAgent) IsClosed() bool {
m.mu.RLock()
defer m.mu.RUnlock()
return m.closed
}
func (m *MultiAgent) CoordinatorClose() error {
m.mu.Lock()
if !m.closed {
m.closed = true
close(m.updates)
}
m.mu.Unlock()
return nil
}
func (m *MultiAgent) Done() <-chan struct{} {
return m.ctx.Done()
}
func (m *MultiAgent) Close() error {
_ = m.CoordinatorClose()
m.ctxCancel()
m.closeOnce.Do(func() { m.OnRemove(m) })
return nil
}
-14
View File
@@ -97,17 +97,3 @@ func (mr *MockCoordinatorMockRecorder) ServeHTTPDebug(arg0, arg1 any) *gomock.Ca
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeHTTPDebug", reflect.TypeOf((*MockCoordinator)(nil).ServeHTTPDebug), arg0, arg1)
}
// ServeMultiAgent mocks base method.
func (m *MockCoordinator) ServeMultiAgent(arg0 uuid.UUID) tailnet.MultiAgentConn {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ServeMultiAgent", arg0)
ret0, _ := ret[0].(tailnet.MultiAgentConn)
return ret0
}
// ServeMultiAgent indicates an expected call of ServeMultiAgent.
func (mr *MockCoordinatorMockRecorder) ServeMultiAgent(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeMultiAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeMultiAgent), arg0)
}
-127
View File
@@ -1,127 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/tailnet (interfaces: MultiAgentConn)
//
// Generated by this command:
//
// mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
//
// Package tailnettest is a generated GoMock package.
package tailnettest
import (
context "context"
reflect "reflect"
proto "github.com/coder/coder/v2/tailnet/proto"
uuid "github.com/google/uuid"
gomock "go.uber.org/mock/gomock"
)
// MockMultiAgentConn is a mock of MultiAgentConn interface.
type MockMultiAgentConn struct {
ctrl *gomock.Controller
recorder *MockMultiAgentConnMockRecorder
}
// MockMultiAgentConnMockRecorder is the mock recorder for MockMultiAgentConn.
type MockMultiAgentConnMockRecorder struct {
mock *MockMultiAgentConn
}
// NewMockMultiAgentConn creates a new mock instance.
func NewMockMultiAgentConn(ctrl *gomock.Controller) *MockMultiAgentConn {
mock := &MockMultiAgentConn{ctrl: ctrl}
mock.recorder = &MockMultiAgentConnMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMultiAgentConn) EXPECT() *MockMultiAgentConnMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockMultiAgentConn) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockMultiAgentConnMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiAgentConn)(nil).Close))
}
// IsClosed mocks base method.
func (m *MockMultiAgentConn) IsClosed() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClosed")
ret0, _ := ret[0].(bool)
return ret0
}
// IsClosed indicates an expected call of IsClosed.
func (mr *MockMultiAgentConnMockRecorder) IsClosed() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiAgentConn)(nil).IsClosed))
}
// NextUpdate mocks base method.
func (m *MockMultiAgentConn) NextUpdate(arg0 context.Context) (*proto.CoordinateResponse, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextUpdate", arg0)
ret0, _ := ret[0].(*proto.CoordinateResponse)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// NextUpdate indicates an expected call of NextUpdate.
func (mr *MockMultiAgentConnMockRecorder) NextUpdate(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextUpdate", reflect.TypeOf((*MockMultiAgentConn)(nil).NextUpdate), arg0)
}
// SubscribeAgent mocks base method.
func (m *MockMultiAgentConn) SubscribeAgent(arg0 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SubscribeAgent", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SubscribeAgent indicates an expected call of SubscribeAgent.
func (mr *MockMultiAgentConnMockRecorder) SubscribeAgent(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).SubscribeAgent), arg0)
}
// UnsubscribeAgent mocks base method.
func (m *MockMultiAgentConn) UnsubscribeAgent(arg0 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnsubscribeAgent", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// UnsubscribeAgent indicates an expected call of UnsubscribeAgent.
func (mr *MockMultiAgentConnMockRecorder) UnsubscribeAgent(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).UnsubscribeAgent), arg0)
}
// UpdateSelf mocks base method.
func (m *MockMultiAgentConn) UpdateSelf(arg0 *proto.Node) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateSelf", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateSelf indicates an expected call of UpdateSelf.
func (mr *MockMultiAgentConnMockRecorder) UpdateSelf(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSelf", reflect.TypeOf((*MockMultiAgentConn)(nil).UpdateSelf), arg0)
}
-126
View File
@@ -9,11 +9,8 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/net/stun/stuntest"
@@ -25,10 +22,8 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
//go:generate mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator
//go:generate mockgen -destination ./coordinateemock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinatee
@@ -159,123 +154,6 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
}
}
type TestMultiAgent struct {
t testing.TB
ID uuid.UUID
a tailnet.MultiAgentConn
nodeKey []byte
discoKey string
}
func NewTestMultiAgent(t testing.TB, coord tailnet.Coordinator) *TestMultiAgent {
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
m := &TestMultiAgent{t: t, ID: uuid.New(), nodeKey: nk, discoKey: string(dk)}
m.a = coord.ServeMultiAgent(m.ID)
return m
}
func (m *TestMultiAgent) SendNodeWithDERP(d int32) {
m.t.Helper()
err := m.a.UpdateSelf(&proto.Node{
Key: m.nodeKey,
Disco: m.discoKey,
PreferredDerp: d,
})
require.NoError(m.t, err)
}
func (m *TestMultiAgent) Close() {
m.t.Helper()
err := m.a.Close()
require.NoError(m.t, err)
}
func (m *TestMultiAgent) RequireSubscribeAgent(id uuid.UUID) {
m.t.Helper()
err := m.a.SubscribeAgent(id)
require.NoError(m.t, err)
}
func (m *TestMultiAgent) RequireUnsubscribeAgent(id uuid.UUID) {
m.t.Helper()
err := m.a.UnsubscribeAgent(id)
require.NoError(m.t, err)
}
func (m *TestMultiAgent) RequireEventuallyHasDERPs(ctx context.Context, expected ...int) {
m.t.Helper()
for {
resp, ok := m.a.NextUpdate(ctx)
require.True(m.t, ok)
nodes, err := tailnet.OnlyNodeUpdates(resp)
require.NoError(m.t, err)
if len(nodes) != len(expected) {
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
m.t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func (m *TestMultiAgent) RequireNeverHasDERPs(ctx context.Context, expected ...int) {
m.t.Helper()
for {
resp, ok := m.a.NextUpdate(ctx)
if !ok {
return
}
nodes, err := tailnet.OnlyNodeUpdates(resp)
require.NoError(m.t, err)
if len(nodes) != len(expected) {
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
m.t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func (m *TestMultiAgent) RequireEventuallyClosed(ctx context.Context) {
m.t.Helper()
tkr := time.NewTicker(testutil.IntervalFast)
defer tkr.Stop()
for {
select {
case <-ctx.Done():
m.t.Fatal("timeout")
return // unhittable
case <-tkr.C:
if m.a.IsClosed() {
return
}
}
}
}
type FakeCoordinator struct {
CoordinateCalls chan *FakeCoordinate
}
@@ -292,10 +170,6 @@ func (*FakeCoordinator) Close() error {
panic("unimplemented")
}
func (*FakeCoordinator) ServeMultiAgent(uuid.UUID) tailnet.MultiAgentConn {
panic("unimplemented")
}
func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
+12
View File
@@ -101,6 +101,18 @@ func (p *Peer) AddTunnel(other uuid.UUID) {
}
}
func (p *Peer) RemoveTunnel(other uuid.UUID) {
p.t.Helper()
req := &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(other)}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout removing tunnel for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *Peer) UpdateDERP(derp int32) {
p.t.Helper()
node := &proto.Node{PreferredDerp: derp}