mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore: refactor ServerTailnet to use tailnet.Controllers (#15408)
Chore of #14729. Refactors `ServerTailnet` to use `tailnet.Controller` so that we reuse the logic around reconnection and handling control messages instead of reimplementing it. This unifies our "client" use of the tailscale API across CLI, coderd, and wsproxy.
This commit is contained in:
@@ -507,7 +507,6 @@ gen: \
|
||||
examples/examples.gen.json \
|
||||
tailnet/tailnettest/coordinatormock.go \
|
||||
tailnet/tailnettest/coordinateemock.go \
|
||||
tailnet/tailnettest/multiagentmock.go \
|
||||
coderd/database/pubsub/psmock/psmock.go
|
||||
.PHONY: gen
|
||||
|
||||
@@ -537,7 +536,6 @@ gen/mark-fresh:
|
||||
examples/examples.gen.json \
|
||||
tailnet/tailnettest/coordinatormock.go \
|
||||
tailnet/tailnettest/coordinateemock.go \
|
||||
tailnet/tailnettest/multiagentmock.go \
|
||||
coderd/database/pubsub/psmock/psmock.go \
|
||||
"
|
||||
|
||||
@@ -570,7 +568,7 @@ coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.
|
||||
coderd/database/pubsub/psmock/psmock.go: coderd/database/pubsub/pubsub.go
|
||||
go generate ./coderd/database/pubsub/psmock
|
||||
|
||||
tailnet/tailnettest/coordinatormock.go tailnet/tailnettest/multiagentmock.go tailnet/tailnettest/coordinateemock.go: tailnet/coordinator.go tailnet/multiagent.go
|
||||
tailnet/tailnettest/coordinatormock.go tailnet/tailnettest/coordinateemock.go: tailnet/coordinator.go
|
||||
go generate ./tailnet/tailnettest/
|
||||
|
||||
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
|
||||
|
||||
+4
-2
@@ -1919,7 +1919,8 @@ func TestAgent_UpdatedDERP(t *testing.T) {
|
||||
t.Cleanup(testCtxCancel)
|
||||
clientID := uuid.New()
|
||||
ctrl := tailnet.NewSingleDestController(logger, conn, agentID)
|
||||
coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, agentID, coordinator))
|
||||
auth := tailnet.ClientCoordinateeAuth{AgentID: agentID}
|
||||
coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, coordinator))
|
||||
t.Cleanup(func() {
|
||||
t.Logf("closing coordination %s", name)
|
||||
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
|
||||
@@ -2408,8 +2409,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
|
||||
t.Cleanup(testCtxCancel)
|
||||
clientID := uuid.New()
|
||||
ctrl := tailnet.NewSingleDestController(logger, conn, metadata.AgentID)
|
||||
auth := tailnet.ClientCoordinateeAuth{AgentID: metadata.AgentID}
|
||||
coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(
|
||||
logger, clientID, metadata.AgentID, coordinator))
|
||||
logger, clientID, auth, coordinator))
|
||||
t.Cleanup(func() {
|
||||
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
|
||||
defer ccancel()
|
||||
|
||||
+7
-4
@@ -627,14 +627,17 @@ func New(options *Options) *API {
|
||||
|
||||
api.Auditor.Store(&options.Auditor)
|
||||
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
|
||||
dialer := &InmemTailnetDialer{
|
||||
CoordPtr: &api.TailnetCoordinator,
|
||||
DERPFn: api.DERPMap,
|
||||
Logger: options.Logger,
|
||||
ClientID: uuid.New(),
|
||||
}
|
||||
stn, err := NewServerTailnet(api.ctx,
|
||||
options.Logger,
|
||||
options.DERPServer,
|
||||
api.DERPMap,
|
||||
dialer,
|
||||
options.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
|
||||
func(context.Context) (tailnet.MultiAgentConn, error) {
|
||||
return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil
|
||||
},
|
||||
options.DeploymentValues.DERP.Config.BlockDirect.Value(),
|
||||
api.TracerProvider,
|
||||
)
|
||||
|
||||
+302
-231
@@ -30,7 +30,7 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/site"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/retry"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
var tailnetTransport *http.Transport
|
||||
@@ -53,9 +53,8 @@ func NewServerTailnet(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
derpServer *derp.Server,
|
||||
derpMapFn func() *tailcfg.DERPMap,
|
||||
dialer tailnet.ControlProtocolDialer,
|
||||
derpForceWebSockets bool,
|
||||
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
|
||||
blockEndpoints bool,
|
||||
traceProvider trace.TracerProvider,
|
||||
) (*ServerTailnet, error) {
|
||||
@@ -91,46 +90,26 @@ func NewServerTailnet(
|
||||
})
|
||||
}
|
||||
|
||||
bgRoutines := &sync.WaitGroup{}
|
||||
originalDerpMap := derpMapFn()
|
||||
tracer := traceProvider.Tracer(tracing.TracerName)
|
||||
|
||||
controller := tailnet.NewController(logger, dialer)
|
||||
// it's important to set the DERPRegionDialer above _before_ we set the DERP map so that if
|
||||
// there is an embedded relay, we use the local in-memory dialer.
|
||||
conn.SetDERPMap(originalDerpMap)
|
||||
bgRoutines.Add(1)
|
||||
go func() {
|
||||
defer bgRoutines.Done()
|
||||
defer logger.Debug(ctx, "polling DERPMap exited")
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-serverCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
newDerpMap := derpMapFn()
|
||||
if !tailnet.CompareDERPMaps(originalDerpMap, newDerpMap) {
|
||||
conn.SetDERPMap(newDerpMap)
|
||||
originalDerpMap = newDerpMap
|
||||
}
|
||||
}
|
||||
}()
|
||||
controller.DERPCtrl = tailnet.NewBasicDERPController(logger, conn)
|
||||
coordCtrl := NewMultiAgentController(serverCtx, logger, tracer, conn)
|
||||
controller.CoordCtrl = coordCtrl
|
||||
// TODO: support controller.TelemetryCtrl
|
||||
|
||||
tn := &ServerTailnet{
|
||||
ctx: serverCtx,
|
||||
cancel: cancel,
|
||||
bgRoutines: bgRoutines,
|
||||
logger: logger,
|
||||
tracer: traceProvider.Tracer(tracing.TracerName),
|
||||
conn: conn,
|
||||
coordinatee: conn,
|
||||
getMultiAgent: getMultiAgent,
|
||||
agentConnectionTimes: map[uuid.UUID]time.Time{},
|
||||
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
|
||||
transport: tailnetTransport.Clone(),
|
||||
ctx: serverCtx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
tracer: tracer,
|
||||
conn: conn,
|
||||
coordinatee: conn,
|
||||
controller: controller,
|
||||
coordCtrl: coordCtrl,
|
||||
transport: tailnetTransport.Clone(),
|
||||
connsPerAgent: prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "servertailnet",
|
||||
@@ -146,7 +125,7 @@ func NewServerTailnet(
|
||||
}
|
||||
tn.transport.DialContext = tn.dialContext
|
||||
// These options are mostly just picked at random, and they can likely be
|
||||
// fine tuned further. Generally, users are running applications in dev mode
|
||||
// fine-tuned further. Generally, users are running applications in dev mode
|
||||
// which can generate hundreds of requests per page load, so we increased
|
||||
// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
|
||||
// conns.
|
||||
@@ -164,23 +143,7 @@ func NewServerTailnet(
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
|
||||
agentConn, err := getMultiAgent(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get initial multi agent: %w", err)
|
||||
}
|
||||
tn.agentConn.Store(&agentConn)
|
||||
// registering the callback also triggers send of the initial node
|
||||
tn.coordinatee.SetNodeCallback(tn.nodeCallback)
|
||||
|
||||
tn.bgRoutines.Add(2)
|
||||
go func() {
|
||||
defer tn.bgRoutines.Done()
|
||||
tn.watchAgentUpdates()
|
||||
}()
|
||||
go func() {
|
||||
defer tn.bgRoutines.Done()
|
||||
tn.expireOldAgents()
|
||||
}()
|
||||
tn.controller.Run(tn.ctx)
|
||||
return tn, nil
|
||||
}
|
||||
|
||||
@@ -190,18 +153,6 @@ func (s *ServerTailnet) Conn() *tailnet.Conn {
|
||||
return s.conn
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) nodeCallback(node *tailnet.Node) {
|
||||
pn, err := tailnet.NodeToProto(node)
|
||||
if err != nil {
|
||||
s.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
|
||||
return
|
||||
}
|
||||
err = s.getAgentConn().UpdateSelf(pn)
|
||||
if err != nil {
|
||||
s.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) {
|
||||
s.connsPerAgent.Describe(descs)
|
||||
s.totalConns.Describe(descs)
|
||||
@@ -212,125 +163,9 @@ func (s *ServerTailnet) Collect(metrics chan<- prometheus.Metric) {
|
||||
s.totalConns.Collect(metrics)
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) expireOldAgents() {
|
||||
defer s.logger.Debug(s.ctx, "stopped expiring old agents")
|
||||
const (
|
||||
tick = 5 * time.Minute
|
||||
cutoff = 30 * time.Minute
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(tick)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
s.doExpireOldAgents(cutoff)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
|
||||
// TODO: add some attrs to this.
|
||||
ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
|
||||
defer span.End()
|
||||
|
||||
start := time.Now()
|
||||
deletedCount := 0
|
||||
|
||||
s.nodesMu.Lock()
|
||||
s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
|
||||
agentConn := s.getAgentConn()
|
||||
for agentID, lastConnection := range s.agentConnectionTimes {
|
||||
// If no one has connected since the cutoff and there are no active
|
||||
// connections, remove the agent.
|
||||
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
|
||||
err := agentConn.UnsubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
continue
|
||||
}
|
||||
deletedCount++
|
||||
delete(s.agentConnectionTimes, agentID)
|
||||
}
|
||||
}
|
||||
s.nodesMu.Unlock()
|
||||
s.logger.Debug(s.ctx, "successfully pruned inactive agents",
|
||||
slog.F("deleted", deletedCount),
|
||||
slog.F("took", time.Since(start)),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) watchAgentUpdates() {
|
||||
defer s.logger.Debug(s.ctx, "stopped watching agent updates")
|
||||
for {
|
||||
conn := s.getAgentConn()
|
||||
resp, ok := conn.NextUpdate(s.ctx)
|
||||
if !ok {
|
||||
if conn.IsClosed() && s.ctx.Err() == nil {
|
||||
s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
|
||||
s.coordinatee.SetAllPeersLost()
|
||||
s.reinitCoordinator()
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err := s.coordinatee.UpdatePeers(resp.GetPeerUpdates())
|
||||
if err != nil {
|
||||
if xerrors.Is(err, tailnet.ErrConnClosed) {
|
||||
s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err))
|
||||
return
|
||||
}
|
||||
s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
|
||||
return *s.agentConn.Load()
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) reinitCoordinator() {
|
||||
start := time.Now()
|
||||
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(s.ctx); {
|
||||
s.nodesMu.Lock()
|
||||
agentConn, err := s.getMultiAgent(s.ctx)
|
||||
if err != nil {
|
||||
s.nodesMu.Unlock()
|
||||
s.logger.Error(s.ctx, "reinit multi agent", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
s.agentConn.Store(&agentConn)
|
||||
// reset the Node callback, which triggers the conn to send the node immediately, and also
|
||||
// register for updates
|
||||
s.coordinatee.SetNodeCallback(s.nodeCallback)
|
||||
|
||||
// Resubscribe to all of the agents we're tracking.
|
||||
for agentID := range s.agentConnectionTimes {
|
||||
err := agentConn.SubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info(s.ctx, "successfully reinitialized multiagent",
|
||||
slog.F("agents", len(s.agentConnectionTimes)),
|
||||
slog.F("took", time.Since(start)),
|
||||
)
|
||||
s.nodesMu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type ServerTailnet struct {
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
bgRoutines *sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
|
||||
logger slog.Logger
|
||||
tracer trace.Tracer
|
||||
@@ -340,15 +175,8 @@ type ServerTailnet struct {
|
||||
conn *tailnet.Conn
|
||||
coordinatee tailnet.Coordinatee
|
||||
|
||||
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
|
||||
agentConn atomic.Pointer[tailnet.MultiAgentConn]
|
||||
nodesMu sync.Mutex
|
||||
// agentConnectionTimes is a map of agent tailnetNodes the server wants to
|
||||
// keep a connection to. It contains the last time the agent was connected
|
||||
// to.
|
||||
agentConnectionTimes map[uuid.UUID]time.Time
|
||||
// agentTockets holds a map of all open connections to an agent.
|
||||
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
controller *tailnet.Controller
|
||||
coordCtrl *MultiAgentController
|
||||
|
||||
transport *http.Transport
|
||||
|
||||
@@ -446,38 +274,6 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
|
||||
s.nodesMu.Lock()
|
||||
defer s.nodesMu.Unlock()
|
||||
|
||||
_, ok := s.agentConnectionTimes[agentID]
|
||||
// If we don't have the node, subscribe.
|
||||
if !ok {
|
||||
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
|
||||
err := s.getAgentConn().SubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("subscribe agent: %w", err)
|
||||
}
|
||||
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
}
|
||||
|
||||
s.agentConnectionTimes[agentID] = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
|
||||
id := uuid.New()
|
||||
s.nodesMu.Lock()
|
||||
s.agentTickets[agentID][id] = struct{}{}
|
||||
s.nodesMu.Unlock()
|
||||
|
||||
return func() {
|
||||
s.nodesMu.Lock()
|
||||
delete(s.agentTickets[agentID], id)
|
||||
s.nodesMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
|
||||
var (
|
||||
conn *workspacesdk.AgentConn
|
||||
@@ -485,11 +281,11 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*work
|
||||
)
|
||||
|
||||
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
|
||||
err := s.ensureAgent(agentID)
|
||||
err := s.coordCtrl.ensureAgent(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
|
||||
}
|
||||
ret = s.acquireTicket(agentID)
|
||||
ret = s.coordCtrl.acquireTicket(agentID)
|
||||
|
||||
conn = workspacesdk.NewAgentConn(s.conn, workspacesdk.AgentConnOptions{
|
||||
AgentID: agentID,
|
||||
@@ -548,7 +344,8 @@ func (s *ServerTailnet) Close() error {
|
||||
s.cancel()
|
||||
_ = s.conn.Close()
|
||||
s.transport.CloseIdleConnections()
|
||||
s.bgRoutines.Wait()
|
||||
s.coordCtrl.Close()
|
||||
<-s.controller.Closed()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -566,3 +363,277 @@ func (c *instrumentedConn) Close() error {
|
||||
})
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// MultiAgentController is a tailnet.CoordinationController for connecting to multiple workspace
|
||||
// agents. It keeps track of connection times to the agents, and removes them on a timer if they
|
||||
// have no active connections and haven't been used in a while.
|
||||
type MultiAgentController struct {
|
||||
*tailnet.BasicCoordinationController
|
||||
|
||||
logger slog.Logger
|
||||
tracer trace.Tracer
|
||||
|
||||
mu sync.Mutex
|
||||
// connectionTimes is a map of agents the server wants to keep a connection to. It
|
||||
// contains the last time the agent was connected to.
|
||||
connectionTimes map[uuid.UUID]time.Time
|
||||
// tickets is a map of destinations to a set of connection tickets, representing open
|
||||
// connections to the destination
|
||||
tickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
coordination *tailnet.BasicCoordination
|
||||
|
||||
cancel context.CancelFunc
|
||||
expireOldAgentsDone chan struct{}
|
||||
}
|
||||
|
||||
func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.CloserWaiter {
|
||||
b := m.BasicCoordinationController.NewCoordination(client)
|
||||
// resync all destinations
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.coordination = b
|
||||
for agentID := range m.connectionTimes {
|
||||
err := client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
m.logger.Error(context.Background(), "failed to re-add tunnel", slog.F("agent_id", agentID),
|
||||
slog.Error(err))
|
||||
b.SendErr(err)
|
||||
_ = client.Close()
|
||||
m.coordination = nil
|
||||
break
|
||||
}
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, ok := m.connectionTimes[agentID]
|
||||
// If we don't have the agent, subscribe.
|
||||
if !ok {
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
if m.coordination != nil {
|
||||
err := m.coordination.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
err = xerrors.Errorf("subscribe agent: %w", err)
|
||||
m.coordination.SendErr(err)
|
||||
_ = m.coordination.Client.Close()
|
||||
m.coordination = nil
|
||||
return err
|
||||
}
|
||||
}
|
||||
m.tickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
}
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MultiAgentController) acquireTicket(agentID uuid.UUID) (release func()) {
|
||||
id := uuid.New()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.tickets[agentID][id] = struct{}{}
|
||||
|
||||
return func() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.tickets[agentID], id)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MultiAgentController) expireOldAgents(ctx context.Context) {
|
||||
defer close(m.expireOldAgentsDone)
|
||||
defer m.logger.Debug(context.Background(), "stopped expiring old agents")
|
||||
const (
|
||||
tick = 5 * time.Minute
|
||||
cutoff = 30 * time.Minute
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(tick)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
m.doExpireOldAgents(ctx, cutoff)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MultiAgentController) doExpireOldAgents(ctx context.Context, cutoff time.Duration) {
|
||||
// TODO: add some attrs to this.
|
||||
ctx, span := m.tracer.Start(ctx, tracing.FuncName())
|
||||
defer span.End()
|
||||
|
||||
start := time.Now()
|
||||
deletedCount := 0
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(m.connectionTimes)))
|
||||
for agentID, lastConnection := range m.connectionTimes {
|
||||
// If no one has connected since the cutoff and there are no active
|
||||
// connections, remove the agent.
|
||||
if time.Since(lastConnection) > cutoff && len(m.tickets[agentID]) == 0 {
|
||||
if m.coordination != nil {
|
||||
err := m.coordination.Client.Send(&proto.CoordinateRequest{
|
||||
RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
m.logger.Debug(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
m.coordination.SendErr(xerrors.Errorf("unsubscribe expired agent: %w", err))
|
||||
// close the client because we do not want to do a graceful disconnect by
|
||||
// closing the coordination.
|
||||
_ = m.coordination.Client.Close()
|
||||
m.coordination = nil
|
||||
// Here we continue deleting any inactive agents: there is no point in
|
||||
// re-establishing tunnels to expired agents when we eventually reconnect.
|
||||
}
|
||||
}
|
||||
deletedCount++
|
||||
delete(m.connectionTimes, agentID)
|
||||
}
|
||||
}
|
||||
m.logger.Debug(ctx, "pruned inactive agents",
|
||||
slog.F("deleted", deletedCount),
|
||||
slog.F("took", time.Since(start)),
|
||||
)
|
||||
}
|
||||
|
||||
func (m *MultiAgentController) Close() {
|
||||
m.cancel()
|
||||
<-m.expireOldAgentsDone
|
||||
}
|
||||
|
||||
func NewMultiAgentController(ctx context.Context, logger slog.Logger, tracer trace.Tracer, coordinatee tailnet.Coordinatee) *MultiAgentController {
|
||||
m := &MultiAgentController{
|
||||
BasicCoordinationController: &tailnet.BasicCoordinationController{
|
||||
Logger: logger,
|
||||
Coordinatee: coordinatee,
|
||||
SendAcks: false, // we are a client, connecting to multiple agents
|
||||
},
|
||||
logger: logger,
|
||||
tracer: tracer,
|
||||
connectionTimes: make(map[uuid.UUID]time.Time),
|
||||
tickets: make(map[uuid.UUID]map[uuid.UUID]struct{}),
|
||||
expireOldAgentsDone: make(chan struct{}),
|
||||
}
|
||||
ctx, m.cancel = context.WithCancel(ctx)
|
||||
go m.expireOldAgents(ctx)
|
||||
return m
|
||||
}
|
||||
|
||||
// InmemTailnetDialer is a tailnet.ControlProtocolDialer that connects to a Coordinator and DERPMap
|
||||
// service running in the same memory space.
|
||||
type InmemTailnetDialer struct {
|
||||
CoordPtr *atomic.Pointer[tailnet.Coordinator]
|
||||
DERPFn func() *tailcfg.DERPMap
|
||||
Logger slog.Logger
|
||||
ClientID uuid.UUID
|
||||
}
|
||||
|
||||
func (a *InmemTailnetDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
|
||||
coord := a.CoordPtr.Load()
|
||||
if coord == nil {
|
||||
return tailnet.ControlProtocolClients{}, xerrors.Errorf("tailnet coordinator not initialized")
|
||||
}
|
||||
coordClient := tailnet.NewInMemoryCoordinatorClient(
|
||||
a.Logger, a.ClientID, tailnet.SingleTailnetCoordinateeAuth{}, *coord)
|
||||
derpClient := newPollingDERPClient(a.DERPFn, a.Logger)
|
||||
return tailnet.ControlProtocolClients{
|
||||
Closer: closeAll{coord: coordClient, derp: derpClient},
|
||||
Coordinator: coordClient,
|
||||
DERP: derpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newPollingDERPClient(derpFn func() *tailcfg.DERPMap, logger slog.Logger) tailnet.DERPClient {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
a := &pollingDERPClient{
|
||||
fn: derpFn,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
ch: make(chan *tailcfg.DERPMap),
|
||||
loopDone: make(chan struct{}),
|
||||
}
|
||||
go a.pollDERP()
|
||||
return a
|
||||
}
|
||||
|
||||
// pollingDERPClient is a DERP client that just calls a function on a polling
|
||||
// interval
|
||||
type pollingDERPClient struct {
|
||||
fn func() *tailcfg.DERPMap
|
||||
logger slog.Logger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
loopDone chan struct{}
|
||||
lastDERPMap *tailcfg.DERPMap
|
||||
ch chan *tailcfg.DERPMap
|
||||
}
|
||||
|
||||
// Close the DERP client
|
||||
func (a *pollingDERPClient) Close() error {
|
||||
a.cancel()
|
||||
<-a.loopDone
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *pollingDERPClient) Recv() (*tailcfg.DERPMap, error) {
|
||||
select {
|
||||
case <-a.ctx.Done():
|
||||
return nil, a.ctx.Err()
|
||||
case dm := <-a.ch:
|
||||
return dm, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *pollingDERPClient) pollDERP() {
|
||||
defer close(a.loopDone)
|
||||
defer a.logger.Debug(a.ctx, "polling DERPMap exited")
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-a.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
newDerpMap := a.fn()
|
||||
if !tailnet.CompareDERPMaps(a.lastDERPMap, newDerpMap) {
|
||||
select {
|
||||
case <-a.ctx.Done():
|
||||
return
|
||||
case a.ch <- newDerpMap:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type closeAll struct {
|
||||
coord tailnet.CoordinatorClient
|
||||
derp tailnet.DERPClient
|
||||
}
|
||||
|
||||
func (c closeAll) Close() error {
|
||||
cErr := c.coord.Close()
|
||||
dErr := c.derp.Close()
|
||||
if cErr != nil {
|
||||
return cErr
|
||||
}
|
||||
return dErr
|
||||
}
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// TestServerTailnet_Reconnect tests that ServerTailnet calls SetAllPeersLost on the Coordinatee
|
||||
// (tailnet.Conn in production) when it disconnects from the Coordinator (via MultiAgentConn) and
|
||||
// reconnects.
|
||||
func TestServerTailnet_Reconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctrl := gomock.NewController(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
mMultiAgent0 := tailnettest.NewMockMultiAgentConn(ctrl)
|
||||
mMultiAgent1 := tailnettest.NewMockMultiAgentConn(ctrl)
|
||||
mac := make(chan tailnet.MultiAgentConn, 2)
|
||||
mac <- mMultiAgent0
|
||||
mac <- mMultiAgent1
|
||||
mCoord := tailnettest.NewMockCoordinatee(ctrl)
|
||||
|
||||
uut := &ServerTailnet{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
coordinatee: mCoord,
|
||||
getMultiAgent: func(ctx context.Context) (tailnet.MultiAgentConn, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case m := <-mac:
|
||||
return m, nil
|
||||
}
|
||||
},
|
||||
agentConn: atomic.Pointer[tailnet.MultiAgentConn]{},
|
||||
agentConnectionTimes: make(map[uuid.UUID]time.Time),
|
||||
}
|
||||
// reinit the Coordinator once, to load mMultiAgent0
|
||||
mCoord.EXPECT().SetNodeCallback(gomock.Any()).Times(1)
|
||||
uut.reinitCoordinator()
|
||||
|
||||
mMultiAgent0.EXPECT().NextUpdate(gomock.Any()).
|
||||
Times(1).
|
||||
Return(nil, false) // this indicates there are no more updates
|
||||
closed0 := mMultiAgent0.EXPECT().IsClosed().
|
||||
Times(1).
|
||||
Return(true) // this triggers reconnect
|
||||
setLost := mCoord.EXPECT().SetAllPeersLost().Times(1).After(closed0)
|
||||
mCoord.EXPECT().SetNodeCallback(gomock.Any()).Times(1).After(closed0)
|
||||
mMultiAgent1.EXPECT().NextUpdate(gomock.Any()).
|
||||
Times(1).
|
||||
After(setLost).
|
||||
Return(nil, false)
|
||||
mMultiAgent1.EXPECT().IsClosed().
|
||||
Times(1).
|
||||
Return(false) // this causes us to exit and not reconnect
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
uut.watchAgentUpdates()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
testutil.RequireRecvCtx(ctx, t, done)
|
||||
}
|
||||
@@ -399,6 +399,8 @@ func setupServerTailnetAgent(t *testing.T, agentNum int, opts ...tailnettest.DER
|
||||
t.Cleanup(func() {
|
||||
_ = coord.Close()
|
||||
})
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
|
||||
agents := []agentWithID{}
|
||||
|
||||
@@ -430,13 +432,18 @@ func setupServerTailnetAgent(t *testing.T, agentNum int, opts ...tailnettest.DER
|
||||
agents = append(agents, agentWithID{id: manifest.AgentID, Agent: ag})
|
||||
}
|
||||
|
||||
dialer := &coderd.InmemTailnetDialer{
|
||||
CoordPtr: &coordPtr,
|
||||
DERPFn: func() *tailcfg.DERPMap { return derpMap },
|
||||
Logger: logger,
|
||||
ClientID: uuid.UUID{5},
|
||||
}
|
||||
serverTailnet, err := coderd.NewServerTailnet(
|
||||
context.Background(),
|
||||
logger,
|
||||
derpServer,
|
||||
func() *tailcfg.DERPMap { return derpMap },
|
||||
dialer,
|
||||
false,
|
||||
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
|
||||
!derpMap.HasSTUN(),
|
||||
trace.NewNoopTracerProvider(),
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/enterprise/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
agpltest "github.com/coder/coder/v2/tailnet/test"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -39,19 +38,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
|
||||
defer agent1.Close(ctx)
|
||||
agent1.UpdateDERP(5)
|
||||
|
||||
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
|
||||
defer ma1.Close()
|
||||
ma1 := agpltest.NewPeer(ctx, t, coord1, "client")
|
||||
defer ma1.Close(ctx)
|
||||
|
||||
ma1.RequireSubscribeAgent(agent1.ID)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 5)
|
||||
ma1.AddTunnel(agent1.ID)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
|
||||
|
||||
agent1.UpdateDERP(1)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 1)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
|
||||
|
||||
ma1.SendNodeWithDERP(3)
|
||||
ma1.UpdateDERP(3)
|
||||
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
|
||||
|
||||
ma1.Close()
|
||||
ma1.Disconnect()
|
||||
agent1.UngracefulDisconnect(ctx)
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
|
||||
@@ -72,13 +71,13 @@ func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
||||
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
|
||||
defer ma1.Close()
|
||||
ma1 := agpltest.NewPeer(ctx, t, coord1, "client")
|
||||
defer ma1.Close(ctx)
|
||||
|
||||
err = coord1.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ma1.RequireEventuallyClosed(ctx)
|
||||
ma1.AssertEventuallyResponsesClosed()
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
|
||||
@@ -106,20 +105,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
|
||||
defer agent1.Close(ctx)
|
||||
agent1.UpdateDERP(5)
|
||||
|
||||
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
|
||||
defer ma1.Close()
|
||||
ma1 := agpltest.NewPeer(ctx, t, coord1, "client")
|
||||
defer ma1.Close(ctx)
|
||||
|
||||
ma1.RequireSubscribeAgent(agent1.ID)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 5)
|
||||
ma1.AddTunnel(agent1.ID)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
|
||||
|
||||
agent1.UpdateDERP(1)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 1)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
|
||||
|
||||
ma1.SendNodeWithDERP(3)
|
||||
ma1.UpdateDERP(3)
|
||||
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
|
||||
|
||||
ma1.RequireUnsubscribeAgent(agent1.ID)
|
||||
ma1.Close()
|
||||
ma1.RemoveTunnel(agent1.ID)
|
||||
ma1.Close(ctx)
|
||||
agent1.UngracefulDisconnect(ctx)
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
|
||||
@@ -151,35 +150,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
|
||||
defer agent1.Close(ctx)
|
||||
agent1.UpdateDERP(5)
|
||||
|
||||
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
|
||||
defer ma1.Close()
|
||||
ma1 := agpltest.NewPeer(ctx, t, coord1, "client")
|
||||
defer ma1.Close(ctx)
|
||||
|
||||
ma1.RequireSubscribeAgent(agent1.ID)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 5)
|
||||
ma1.AddTunnel(agent1.ID)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
|
||||
|
||||
agent1.UpdateDERP(1)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 1)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
|
||||
|
||||
ma1.SendNodeWithDERP(3)
|
||||
ma1.UpdateDERP(3)
|
||||
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
|
||||
|
||||
ma1.RequireUnsubscribeAgent(agent1.ID)
|
||||
ma1.RemoveTunnel(agent1.ID)
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
|
||||
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
|
||||
defer cancel()
|
||||
ma1.SendNodeWithDERP(9)
|
||||
ma1.UpdateDERP(9)
|
||||
agent1.AssertNeverHasDERPs(ctx, ma1.ID, 9)
|
||||
}()
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
|
||||
defer cancel()
|
||||
agent1.UpdateDERP(8)
|
||||
ma1.RequireNeverHasDERPs(ctx, 8)
|
||||
ma1.AssertNeverHasDERPs(ctx, agent1.ID, 8)
|
||||
}()
|
||||
|
||||
ma1.Close()
|
||||
ma1.Disconnect()
|
||||
agent1.UngracefulDisconnect(ctx)
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
|
||||
@@ -216,19 +215,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
|
||||
defer agent1.Close(ctx)
|
||||
agent1.UpdateDERP(5)
|
||||
|
||||
ma1 := tailnettest.NewTestMultiAgent(t, coord2)
|
||||
defer ma1.Close()
|
||||
ma1 := agpltest.NewPeer(ctx, t, coord2, "client")
|
||||
defer ma1.Close(ctx)
|
||||
|
||||
ma1.RequireSubscribeAgent(agent1.ID)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 5)
|
||||
ma1.AddTunnel(agent1.ID)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
|
||||
|
||||
agent1.UpdateDERP(1)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 1)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
|
||||
|
||||
ma1.SendNodeWithDERP(3)
|
||||
ma1.UpdateDERP(3)
|
||||
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
|
||||
|
||||
ma1.Close()
|
||||
ma1.Disconnect()
|
||||
agent1.UngracefulDisconnect(ctx)
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
|
||||
@@ -266,19 +265,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
|
||||
defer agent1.Close(ctx)
|
||||
agent1.UpdateDERP(5)
|
||||
|
||||
ma1 := tailnettest.NewTestMultiAgent(t, coord2)
|
||||
defer ma1.Close()
|
||||
ma1 := agpltest.NewPeer(ctx, t, coord2, "client")
|
||||
defer ma1.Close(ctx)
|
||||
|
||||
ma1.SendNodeWithDERP(3)
|
||||
ma1.UpdateDERP(3)
|
||||
|
||||
ma1.RequireSubscribeAgent(agent1.ID)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 5)
|
||||
ma1.AddTunnel(agent1.ID)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
|
||||
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
|
||||
|
||||
agent1.UpdateDERP(1)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 1)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
|
||||
|
||||
ma1.Close()
|
||||
ma1.Disconnect()
|
||||
agent1.UngracefulDisconnect(ctx)
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
|
||||
@@ -325,26 +324,26 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
|
||||
defer agent2.Close(ctx)
|
||||
agent2.UpdateDERP(6)
|
||||
|
||||
ma1 := tailnettest.NewTestMultiAgent(t, coord3)
|
||||
defer ma1.Close()
|
||||
ma1 := agpltest.NewPeer(ctx, t, coord3, "client")
|
||||
defer ma1.Close(ctx)
|
||||
|
||||
ma1.RequireSubscribeAgent(agent1.ID)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 5)
|
||||
ma1.AddTunnel(agent1.ID)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 5)
|
||||
|
||||
agent1.UpdateDERP(1)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 1)
|
||||
ma1.AssertEventuallyHasDERP(agent1.ID, 1)
|
||||
|
||||
ma1.RequireSubscribeAgent(agent2.ID)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 6)
|
||||
ma1.AddTunnel(agent2.ID)
|
||||
ma1.AssertEventuallyHasDERP(agent2.ID, 6)
|
||||
|
||||
agent2.UpdateDERP(2)
|
||||
ma1.RequireEventuallyHasDERPs(ctx, 2)
|
||||
ma1.AssertEventuallyHasDERP(agent2.ID, 2)
|
||||
|
||||
ma1.SendNodeWithDERP(3)
|
||||
ma1.UpdateDERP(3)
|
||||
agent1.AssertEventuallyHasDERP(ma1.ID, 3)
|
||||
agent2.AssertEventuallyHasDERP(ma1.ID, 3)
|
||||
|
||||
ma1.Close()
|
||||
ma1.Disconnect()
|
||||
agent1.UngracefulDisconnect(ctx)
|
||||
agent2.UngracefulDisconnect(ctx)
|
||||
|
||||
|
||||
@@ -165,10 +165,6 @@ func newPGCoordInternal(
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
|
||||
return agpl.ServeMultiAgent(c, c.logger, id)
|
||||
}
|
||||
|
||||
func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
|
||||
// We're going to directly query the database, since we would only have the mapping stored locally if we had
|
||||
// a tunnel peer connected, which is not always the case.
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -24,7 +23,6 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/derp/derphttp"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"cdr.dev/slog"
|
||||
@@ -144,7 +142,6 @@ type Server struct {
|
||||
replicaPingSingleflight singleflight.Group
|
||||
replicaErrMut sync.Mutex
|
||||
replicaErr string
|
||||
latestDERPMap atomic.Pointer[tailcfg.DERPMap]
|
||||
|
||||
// Used for graceful shutdown. Required for the dialer.
|
||||
ctx context.Context
|
||||
@@ -271,14 +268,15 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
|
||||
return nil, xerrors.Errorf("handle register: %w", err)
|
||||
}
|
||||
|
||||
dialer, err := s.SDKClient.TailnetDialer()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create tailnet dialer: %w", err)
|
||||
}
|
||||
agentProvider, err := coderd.NewServerTailnet(ctx,
|
||||
s.Logger,
|
||||
nil,
|
||||
func() *tailcfg.DERPMap {
|
||||
return s.latestDERPMap.Load()
|
||||
},
|
||||
dialer,
|
||||
regResp.DERPForceWebSockets,
|
||||
s.DialCoordinator,
|
||||
opts.BlockDirect,
|
||||
s.TracerProvider,
|
||||
)
|
||||
@@ -481,8 +479,6 @@ func (s *Server) handleRegister(res wsproxysdk.RegisterWorkspaceProxyResponse) e
|
||||
s.Logger.Debug(s.ctx, "setting DERP mesh sibling addresses", slog.F("addresses", addresses))
|
||||
s.derpMesh.SetAddresses(addresses, false)
|
||||
|
||||
s.latestDERPMap.Store(res.DERPMap)
|
||||
|
||||
go s.pingSiblingReplicas(res.SiblingReplicas)
|
||||
return nil
|
||||
}
|
||||
@@ -569,10 +565,6 @@ func (s *Server) handleRegisterFailure(err error) {
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, error) {
|
||||
return s.SDKClient.DialCoordinator(ctx)
|
||||
}
|
||||
|
||||
func (s *Server) buildInfo(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{
|
||||
ExternalURL: buildinfo.ExternalURL(),
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -501,13 +500,11 @@ type CoordinateNodes struct {
|
||||
Nodes []*agpl.Node
|
||||
}
|
||||
|
||||
func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
logger := c.SDKClient.Logger().Named("multiagent")
|
||||
func (c *Client) TailnetDialer() (*workspacesdk.WebsocketDialer, error) {
|
||||
logger := c.SDKClient.Logger().Named("tailnet_dialer")
|
||||
|
||||
coordinateURL, err := c.SDKClient.URL.Parse("/api/v2/workspaceproxies/me/coordinate")
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
q := coordinateURL.Query()
|
||||
@@ -520,59 +517,10 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
|
||||
}
|
||||
coordinateHeaders.Set(tokenHeader, c.SessionToken())
|
||||
|
||||
//nolint:bodyclose
|
||||
conn, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
|
||||
return workspacesdk.NewWebsocketDialer(logger, coordinateURL, &websocket.DialOptions{
|
||||
HTTPClient: c.SDKClient.HTTPClient,
|
||||
HTTPHeader: coordinateHeaders,
|
||||
})
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, xerrors.Errorf("dial coordinate websocket: %w", err)
|
||||
}
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
nc := websocket.NetConn(ctx, conn, websocket.MessageBinary)
|
||||
client, err := agpl.NewDRPCClient(nc, logger)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "")
|
||||
return nil, xerrors.Errorf("failed to create DRPCClient: %w", err)
|
||||
}
|
||||
protocol, err := client.Coordinate(ctx)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "")
|
||||
return nil, xerrors.Errorf("failed to reach the Coordinate endpoint: %w", err)
|
||||
}
|
||||
|
||||
rma := remoteMultiAgentHandler{
|
||||
sdk: c,
|
||||
logger: logger,
|
||||
protocol: protocol,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
ma := (&agpl.MultiAgent{
|
||||
ID: uuid.New(),
|
||||
OnSubscribe: rma.OnSubscribe,
|
||||
OnUnsubscribe: rma.OnUnsubscribe,
|
||||
OnNodeUpdate: rma.OnNodeUpdate,
|
||||
OnRemove: rma.OnRemove,
|
||||
}).Init()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = ma.Close()
|
||||
_ = client.DRPCConn().Close()
|
||||
<-client.DRPCConn().Closed()
|
||||
_ = conn.Close(websocket.StatusGoingAway, "closed")
|
||||
}()
|
||||
|
||||
rma.ma = ma
|
||||
go rma.respLoop()
|
||||
|
||||
return ma, nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
type CryptoKeysResponse struct {
|
||||
@@ -595,55 +543,3 @@ func (c *Client) CryptoKeys(ctx context.Context, feature codersdk.CryptoKeyFeatu
|
||||
var resp CryptoKeysResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
type remoteMultiAgentHandler struct {
|
||||
sdk *Client
|
||||
logger slog.Logger
|
||||
protocol proto.DRPCTailnet_CoordinateClient
|
||||
ma *agpl.MultiAgent
|
||||
cancel func()
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) respLoop() {
|
||||
{
|
||||
defer a.cancel()
|
||||
for {
|
||||
resp, err := a.protocol.Recv()
|
||||
if err != nil {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
a.logger.Info(context.Background(), "remote multiagent connection severed", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
a.logger.Error(context.Background(), "error receiving multiagent responses", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
err = a.ma.Enqueue(resp)
|
||||
if err != nil {
|
||||
a.logger.Error(context.Background(), "enqueue response from coordinator", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *proto.Node) error {
|
||||
return a.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}})
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) error {
|
||||
return a.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error {
|
||||
return a.protocol.Send(&proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnRemove(_ agpl.Queue) {
|
||||
err := a.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
||||
if err != nil {
|
||||
a.logger.Warn(context.Background(), "failed to gracefully disconnect", slog.Error(err))
|
||||
}
|
||||
_ = a.protocol.CloseSend()
|
||||
}
|
||||
|
||||
@@ -1,37 +1,21 @@
|
||||
package wsproxysdk_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"nhooyr.io/websocket"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/enterprise/tailnet"
|
||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@@ -152,142 +136,6 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDialCoordinator(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
agentID = uuid.UUID{33}
|
||||
proxyID = uuid.UUID{44}
|
||||
mCoord = tailnettest.NewMockCoordinator(gomock.NewController(t))
|
||||
coord agpl.Coordinator = mCoord
|
||||
r = chi.NewRouter()
|
||||
srv = httptest.NewServer(r)
|
||||
)
|
||||
defer cancel()
|
||||
defer srv.Close()
|
||||
|
||||
coordPtr := atomic.Pointer[agpl.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
cSrv, err := tailnet.NewClientService(agpl.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Hour,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
|
||||
ResumeTokenProvider: agpl.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// buffer the channels here, so we don't need to read and write in goroutines to
|
||||
// avoid blocking
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetCoordinateeAuth{}).
|
||||
Times(1).
|
||||
Return(reqs, resps)
|
||||
|
||||
serveMACErr := make(chan error, 1)
|
||||
r.Get("/api/v2/workspaceproxies/me/coordinate", func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
version := r.URL.Query().Get("version")
|
||||
if !assert.Equal(t, version, proto.CurrentVersion.String()) {
|
||||
return
|
||||
}
|
||||
nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
|
||||
err = cSrv.ServeMultiAgentClient(ctx, version, nc, proxyID)
|
||||
serveMACErr <- err
|
||||
})
|
||||
|
||||
u, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
client := wsproxysdk.New(u)
|
||||
client.SDKClient.SetLogger(logger)
|
||||
|
||||
peerID := uuid.UUID{55}
|
||||
peerNodeKey, err := key.NewNode().Public().MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
peerDiscoKey, err := key.NewDisco().Public().MarshalText()
|
||||
require.NoError(t, err)
|
||||
expected := &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
|
||||
Id: peerID[:],
|
||||
Node: &proto.Node{
|
||||
Id: 55,
|
||||
AsOf: timestamppb.New(time.Unix(1689653252, 0)),
|
||||
Key: peerNodeKey,
|
||||
Disco: string(peerDiscoKey),
|
||||
PreferredDerp: 0,
|
||||
DerpLatency: map[string]float64{
|
||||
"0": 1.0,
|
||||
},
|
||||
DerpForcedWebsocket: map[int32]string{},
|
||||
Addresses: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()},
|
||||
AllowedIps: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()},
|
||||
Endpoints: []string{"192.168.1.1:18842"},
|
||||
},
|
||||
}}}
|
||||
|
||||
rma, err := client.DialCoordinator(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Subscribe
|
||||
{
|
||||
require.NoError(t, rma.SubscribeAgent(agentID))
|
||||
|
||||
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
|
||||
}
|
||||
// Read updated agent node
|
||||
{
|
||||
resps <- expected
|
||||
|
||||
resp, ok := rma.NextUpdate(ctx)
|
||||
assert.True(t, ok)
|
||||
updates := resp.GetPeerUpdates()
|
||||
assert.Len(t, updates, 1)
|
||||
eq, err := updates[0].GetNode().Equal(expected.GetPeerUpdates()[0].GetNode())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, eq)
|
||||
}
|
||||
// UpdateSelf
|
||||
{
|
||||
require.NoError(t, rma.UpdateSelf(expected.PeerUpdates[0].GetNode()))
|
||||
|
||||
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
eq, err := req.GetUpdateSelf().GetNode().Equal(expected.PeerUpdates[0].GetNode())
|
||||
require.NoError(t, err)
|
||||
require.True(t, eq)
|
||||
}
|
||||
// Unsubscribe
|
||||
{
|
||||
require.NoError(t, rma.UnsubscribeAgent(agentID))
|
||||
|
||||
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.Equal(t, agentID[:], req.GetRemoveTunnel().GetId())
|
||||
}
|
||||
// Close
|
||||
{
|
||||
require.NoError(t, rma.Close())
|
||||
|
||||
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.NotNil(t, req.Disconnect)
|
||||
close(resps)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for req close")
|
||||
case _, ok := <-reqs:
|
||||
require.False(t, ok, "didn't close requests")
|
||||
}
|
||||
require.Error(t, testutil.RequireRecvCtx(ctx, t, serveMACErr))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type ResponseRecorder struct {
|
||||
rw *httptest.ResponseRecorder
|
||||
wasWritten atomic.Bool
|
||||
|
||||
+86
-57
@@ -112,33 +112,41 @@ type ControlProtocolDialer interface {
|
||||
Dial(ctx context.Context, r ResumeTokenController) (ControlProtocolClients, error)
|
||||
}
|
||||
|
||||
// basicCoordinationController handles the basic coordination operations common to all types of
|
||||
// BasicCoordinationController handles the basic coordination operations common to all types of
|
||||
// tailnet consumers:
|
||||
//
|
||||
// 1. sending local node updates to the Coordinator
|
||||
// 2. receiving peer node updates and programming them into the Coordinatee (e.g. tailnet.Conn)
|
||||
// 3. (optionally) sending ReadyForHandshake acknowledgements for peer updates.
|
||||
type basicCoordinationController struct {
|
||||
logger slog.Logger
|
||||
coordinatee Coordinatee
|
||||
sendAcks bool
|
||||
//
|
||||
// It is designed to be used on its own, or composed into more advanced CoordinationControllers.
|
||||
type BasicCoordinationController struct {
|
||||
Logger slog.Logger
|
||||
Coordinatee Coordinatee
|
||||
SendAcks bool
|
||||
}
|
||||
|
||||
func (c *basicCoordinationController) New(client CoordinatorClient) CloserWaiter {
|
||||
b := &basicCoordination{
|
||||
logger: c.logger,
|
||||
// New satisfies the method on the CoordinationController interface
|
||||
func (c *BasicCoordinationController) New(client CoordinatorClient) CloserWaiter {
|
||||
return c.NewCoordination(client)
|
||||
}
|
||||
|
||||
// NewCoordination creates a BasicCoordination
|
||||
func (c *BasicCoordinationController) NewCoordination(client CoordinatorClient) *BasicCoordination {
|
||||
b := &BasicCoordination{
|
||||
logger: c.Logger,
|
||||
errChan: make(chan error, 1),
|
||||
coordinatee: c.coordinatee,
|
||||
client: client,
|
||||
coordinatee: c.Coordinatee,
|
||||
Client: client,
|
||||
respLoopDone: make(chan struct{}),
|
||||
sendAcks: c.sendAcks,
|
||||
sendAcks: c.SendAcks,
|
||||
}
|
||||
|
||||
c.coordinatee.SetNodeCallback(func(node *Node) {
|
||||
c.Coordinatee.SetNodeCallback(func(node *Node) {
|
||||
pn, err := NodeToProto(node)
|
||||
if err != nil {
|
||||
b.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
|
||||
b.sendErr(err)
|
||||
b.SendErr(err)
|
||||
return
|
||||
}
|
||||
b.Lock()
|
||||
@@ -147,9 +155,9 @@ func (c *basicCoordinationController) New(client CoordinatorClient) CloserWaiter
|
||||
b.logger.Debug(context.Background(), "ignored node update because coordination is closed")
|
||||
return
|
||||
}
|
||||
err = b.client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
|
||||
err = b.Client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
|
||||
if err != nil {
|
||||
b.sendErr(xerrors.Errorf("write: %w", err))
|
||||
b.SendErr(xerrors.Errorf("write: %w", err))
|
||||
}
|
||||
})
|
||||
go b.respLoop()
|
||||
@@ -157,18 +165,27 @@ func (c *basicCoordinationController) New(client CoordinatorClient) CloserWaiter
|
||||
return b
|
||||
}
|
||||
|
||||
type basicCoordination struct {
|
||||
// BasicCoordination handles:
|
||||
//
|
||||
// 1. Sending local node updates to the control plane
|
||||
// 2. Reading remote updates from the control plane and programming them into the Coordinatee.
|
||||
//
|
||||
// It does *not* handle adding any Tunnels, but these can be handled by composing
|
||||
// BasicCoordinationController with a more advanced controller.
|
||||
type BasicCoordination struct {
|
||||
sync.Mutex
|
||||
closed bool
|
||||
errChan chan error
|
||||
coordinatee Coordinatee
|
||||
logger slog.Logger
|
||||
client CoordinatorClient
|
||||
Client CoordinatorClient
|
||||
respLoopDone chan struct{}
|
||||
sendAcks bool
|
||||
}
|
||||
|
||||
func (c *basicCoordination) Close(ctx context.Context) (retErr error) {
|
||||
// Close the coordination gracefully. If the context expires before the remote API server has hung
|
||||
// up on us, we forcibly close the Client connection.
|
||||
func (c *BasicCoordination) Close(ctx context.Context) (retErr error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.closed {
|
||||
@@ -188,13 +205,13 @@ func (c *basicCoordination) Close(ctx context.Context) (retErr error) {
|
||||
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
|
||||
}
|
||||
// forcefully close the stream
|
||||
protoErr := c.client.Close()
|
||||
protoErr := c.Client.Close()
|
||||
<-c.respLoopDone
|
||||
if retErr == nil {
|
||||
retErr = protoErr
|
||||
}
|
||||
}()
|
||||
err := c.client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
||||
err := c.Client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
||||
if err != nil && !xerrors.Is(err, io.EOF) {
|
||||
// Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
|
||||
return xerrors.Errorf("send disconnect: %w", err)
|
||||
@@ -203,20 +220,24 @@ func (c *basicCoordination) Close(ctx context.Context) (retErr error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *basicCoordination) Wait() <-chan error {
|
||||
// Wait for the Coordination to complete
|
||||
func (c *BasicCoordination) Wait() <-chan error {
|
||||
return c.errChan
|
||||
}
|
||||
|
||||
func (c *basicCoordination) sendErr(err error) {
|
||||
// SendErr is not part of the CloserWaiter interface, and is intended to be called internally, or
|
||||
// by Controllers that use BasicCoordinationController in composition. It triggers Wait() to
|
||||
// report the error if an error has not already been reported.
|
||||
func (c *BasicCoordination) SendErr(err error) {
|
||||
select {
|
||||
case c.errChan <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *basicCoordination) respLoop() {
|
||||
func (c *BasicCoordination) respLoop() {
|
||||
defer func() {
|
||||
cErr := c.client.Close()
|
||||
cErr := c.Client.Close()
|
||||
if cErr != nil {
|
||||
c.logger.Debug(context.Background(), "failed to close coordinate client after respLoop exit", slog.Error(cErr))
|
||||
}
|
||||
@@ -224,17 +245,17 @@ func (c *basicCoordination) respLoop() {
|
||||
close(c.respLoopDone)
|
||||
}()
|
||||
for {
|
||||
resp, err := c.client.Recv()
|
||||
resp, err := c.Client.Recv()
|
||||
if err != nil {
|
||||
c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err))
|
||||
c.sendErr(xerrors.Errorf("read: %w", err))
|
||||
c.SendErr(xerrors.Errorf("read: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
|
||||
if err != nil {
|
||||
c.logger.Debug(context.Background(), "failed to update peers", slog.Error(err))
|
||||
c.sendErr(xerrors.Errorf("update peers: %w", err))
|
||||
c.SendErr(xerrors.Errorf("update peers: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -253,12 +274,12 @@ func (c *basicCoordination) respLoop() {
|
||||
rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id})
|
||||
}
|
||||
if len(rfh) > 0 {
|
||||
err := c.client.Send(&proto.CoordinateRequest{
|
||||
err := c.Client.Send(&proto.CoordinateRequest{
|
||||
ReadyForHandshake: rfh,
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err))
|
||||
c.sendErr(xerrors.Errorf("send: %w", err))
|
||||
c.SendErr(xerrors.Errorf("send: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -267,7 +288,7 @@ func (c *basicCoordination) respLoop() {
|
||||
}
|
||||
|
||||
type singleDestController struct {
|
||||
*basicCoordinationController
|
||||
*BasicCoordinationController
|
||||
dest uuid.UUID
|
||||
}
|
||||
|
||||
@@ -276,21 +297,20 @@ type singleDestController struct {
|
||||
func NewSingleDestController(logger slog.Logger, coordinatee Coordinatee, dest uuid.UUID) CoordinationController {
|
||||
coordinatee.SetTunnelDestination(dest)
|
||||
return &singleDestController{
|
||||
basicCoordinationController: &basicCoordinationController{
|
||||
logger: logger,
|
||||
coordinatee: coordinatee,
|
||||
sendAcks: false,
|
||||
BasicCoordinationController: &BasicCoordinationController{
|
||||
Logger: logger,
|
||||
Coordinatee: coordinatee,
|
||||
SendAcks: false,
|
||||
},
|
||||
dest: dest,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *singleDestController) New(client CoordinatorClient) CloserWaiter {
|
||||
// nolint: forcetypeassert
|
||||
b := c.basicCoordinationController.New(client).(*basicCoordination)
|
||||
b := c.BasicCoordinationController.NewCoordination(client)
|
||||
err := client.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: c.dest[:]}})
|
||||
if err != nil {
|
||||
b.sendErr(err)
|
||||
b.SendErr(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -298,10 +318,10 @@ func (c *singleDestController) New(client CoordinatorClient) CloserWaiter {
|
||||
// NewAgentCoordinationController creates a CoordinationController for Coder Agents, which never
|
||||
// create tunnels and always send ReadyForHandshake acknowledgements.
|
||||
func NewAgentCoordinationController(logger slog.Logger, coordinatee Coordinatee) CoordinationController {
|
||||
return &basicCoordinationController{
|
||||
logger: logger,
|
||||
coordinatee: coordinatee,
|
||||
sendAcks: true,
|
||||
return &BasicCoordinationController{
|
||||
Logger: logger,
|
||||
Coordinatee: coordinatee,
|
||||
SendAcks: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,11 +380,11 @@ func (c *inMemoryCoordClient) Recv() (*proto.CoordinateResponse, error) {
|
||||
// local Coordinator. (The typical alternative is a DRPC-based client.)
|
||||
func NewInMemoryCoordinatorClient(
|
||||
logger slog.Logger,
|
||||
clientID, agentID uuid.UUID,
|
||||
clientID uuid.UUID,
|
||||
auth CoordinateeAuth,
|
||||
coordinator Coordinator,
|
||||
) CoordinatorClient {
|
||||
logger = logger.With(slog.F("agent_id", agentID), slog.F("client_id", clientID))
|
||||
auth := ClientCoordinateeAuth{AgentID: agentID}
|
||||
logger = logger.With(slog.F("client_id", clientID))
|
||||
c := &inMemoryCoordClient{logger: logger}
|
||||
c.ctx, c.cancel = context.WithCancel(context.Background())
|
||||
|
||||
@@ -742,9 +762,9 @@ func (c *Controller) Run(ctx context.Context) {
|
||||
c.logger.Error(c.ctx, "failed to dial tailnet v2+ API", errF)
|
||||
continue
|
||||
}
|
||||
c.logger.Debug(c.ctx, "obtained tailnet API v2+ client")
|
||||
c.logger.Info(c.ctx, "obtained tailnet API v2+ client")
|
||||
c.runControllersOnce(tailnetClients)
|
||||
c.logger.Debug(c.ctx, "tailnet API v2+ connection lost")
|
||||
c.logger.Info(c.ctx, "tailnet API v2+ connection lost")
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -754,15 +774,20 @@ func (c *Controller) Run(ctx context.Context) {
|
||||
// appropriate). We typically multiplex all RPCs over the same websocket, so we want them to share
|
||||
// the same fate.
|
||||
func (c *Controller) runControllersOnce(clients ControlProtocolClients) {
|
||||
defer func() {
|
||||
closeErr := clients.Closer.Close()
|
||||
if closeErr != nil &&
|
||||
!xerrors.Is(closeErr, io.EOF) &&
|
||||
!xerrors.Is(closeErr, context.Canceled) &&
|
||||
!xerrors.Is(closeErr, context.DeadlineExceeded) {
|
||||
c.logger.Error(c.ctx, "error closing DRPC connection", slog.Error(closeErr))
|
||||
}
|
||||
}()
|
||||
// clients.Closer.Close should nominally be idempotent, but let's not press our luck
|
||||
closeOnce := sync.Once{}
|
||||
closeClients := func() {
|
||||
closeOnce.Do(func() {
|
||||
closeErr := clients.Closer.Close()
|
||||
if closeErr != nil &&
|
||||
!xerrors.Is(closeErr, io.EOF) &&
|
||||
!xerrors.Is(closeErr, context.Canceled) &&
|
||||
!xerrors.Is(closeErr, context.DeadlineExceeded) {
|
||||
c.logger.Error(c.ctx, "error closing tailnet clients", slog.Error(closeErr))
|
||||
}
|
||||
})
|
||||
}
|
||||
defer closeClients()
|
||||
|
||||
if c.TelemetryCtrl != nil {
|
||||
c.TelemetryCtrl.New(clients.Telemetry) // synchronous, doesn't need a goroutine
|
||||
@@ -775,6 +800,11 @@ func (c *Controller) runControllersOnce(clients ControlProtocolClients) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c.coordinate(clients.Coordinator)
|
||||
if c.ctx.Err() == nil {
|
||||
// Main context is still active, but our coordination exited, due to some error.
|
||||
// Close down all the rest of the clients so we'll exit and retry.
|
||||
closeClients()
|
||||
}
|
||||
}()
|
||||
}
|
||||
if c.DERPCtrl != nil {
|
||||
@@ -788,8 +818,7 @@ func (c *Controller) runControllersOnce(clients ControlProtocolClients) {
|
||||
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
|
||||
// close the underlying connection. This will trigger a retry of the control plane in
|
||||
// Run().
|
||||
_ = clients.Closer.Close()
|
||||
// Note that derpMap() logs its own errors, we don't bother here.
|
||||
closeClients()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -42,11 +42,12 @@ func TestInMemoryCoordination(t *testing.T) {
|
||||
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
|
||||
auth := tailnet.ClientCoordinateeAuth{AgentID: agentID}
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), auth).
|
||||
Times(1).Return(reqs, resps)
|
||||
|
||||
ctrl := tailnet.NewSingleDestController(logger, fConn, agentID)
|
||||
uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, agentID, mCoord))
|
||||
uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, mCoord))
|
||||
defer uut.Close(ctx)
|
||||
|
||||
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
||||
@@ -724,7 +725,7 @@ func TestController_Disconnects(t *testing.T) {
|
||||
peersLost := make(chan struct{})
|
||||
fConn := &fakeTailnetConn{peersLostCh: peersLost}
|
||||
|
||||
uut := tailnet.NewController(logger.Named("tac"), dialer,
|
||||
uut := tailnet.NewController(logger.Named("ctrl"), dialer,
|
||||
// darwin can be slow sometimes.
|
||||
tailnet.WithGracefulTimeout(5*time.Second))
|
||||
uut.CoordCtrl = tailnet.NewAgentCoordinationController(logger.Named("coord_ctrl"), fConn)
|
||||
@@ -746,6 +747,11 @@ func TestController_Disconnects(t *testing.T) {
|
||||
// ...and then reconnect
|
||||
call = testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
|
||||
|
||||
// close the coordination call, which should cause a 2nd reconnection
|
||||
close(call.Resps)
|
||||
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
|
||||
call = testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
|
||||
|
||||
// canceling the context should trigger the disconnect message
|
||||
cancel()
|
||||
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
|
||||
|
||||
@@ -45,7 +45,6 @@ type CoordinatorV2 interface {
|
||||
Node(id uuid.UUID) *Node
|
||||
Close() error
|
||||
Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
|
||||
ServeMultiAgent(id uuid.UUID) MultiAgentConn
|
||||
}
|
||||
|
||||
// Node represents a node in the network.
|
||||
@@ -174,39 +173,6 @@ func (c *coordinator) Coordinate(
|
||||
return reqs, resps
|
||||
}
|
||||
|
||||
func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
|
||||
return ServeMultiAgent(c, c.core.logger, id)
|
||||
}
|
||||
|
||||
func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAgentConn {
|
||||
logger = logger.With(slog.F("client_id", id)).Named("multiagent")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetCoordinateeAuth{})
|
||||
m := (&MultiAgent{
|
||||
ID: id,
|
||||
OnSubscribe: func(enq Queue, agent uuid.UUID) error {
|
||||
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}})
|
||||
return err
|
||||
},
|
||||
OnUnsubscribe: func(enq Queue, agent uuid.UUID) error {
|
||||
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}})
|
||||
return err
|
||||
},
|
||||
OnNodeUpdate: func(id uuid.UUID, node *proto.Node) error {
|
||||
return SendCtx(ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
|
||||
Node: node,
|
||||
}})
|
||||
},
|
||||
OnRemove: func(_ Queue) {
|
||||
_ = SendCtx(ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
||||
cancel()
|
||||
},
|
||||
}).Init()
|
||||
|
||||
go qRespLoop(ctx, cancel, logger, m, resps)
|
||||
return m
|
||||
}
|
||||
|
||||
// core is an in-memory structure of peer mappings. Its methods may be called from multiple goroutines;
|
||||
// it is protected by a mutex to ensure data stay consistent.
|
||||
type core struct {
|
||||
@@ -218,27 +184,6 @@ type core struct {
|
||||
tunnels *tunnelStore
|
||||
}
|
||||
|
||||
type QueueKind int
|
||||
|
||||
const (
|
||||
QueueKindClient QueueKind = 1 + iota
|
||||
QueueKindAgent
|
||||
)
|
||||
|
||||
type Queue interface {
|
||||
UniqueID() uuid.UUID
|
||||
Kind() QueueKind
|
||||
Enqueue(resp *proto.CoordinateResponse) error
|
||||
Name() string
|
||||
Stats() (start, lastWrite int64)
|
||||
Overwrites() int64
|
||||
// CoordinatorClose is used by the coordinator when closing a Queue. It
|
||||
// should skip removing itself from the coordinator.
|
||||
CoordinatorClose() error
|
||||
Done() <-chan struct{}
|
||||
Close() error
|
||||
}
|
||||
|
||||
func newCore(logger slog.Logger) *core {
|
||||
return &core{
|
||||
logger: logger,
|
||||
@@ -671,25 +616,3 @@ func RecvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
|
||||
return a, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
func qRespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
|
||||
defer func() {
|
||||
cErr := q.Close()
|
||||
if cErr != nil {
|
||||
logger.Info(ctx, "error closing response Queue", slog.Error(cErr))
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
for {
|
||||
resp, err := RecvCtx(ctx, resps)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "qRespLoop done reading responses", slog.Error(err))
|
||||
return
|
||||
}
|
||||
logger.Debug(ctx, "qRespLoop got response", slog.F("resp", resp))
|
||||
err = q.Enqueue(resp)
|
||||
if err != nil && !xerrors.Is(err, context.Canceled) {
|
||||
logger.Error(ctx, "qRespLoop failed to enqueue v1 update", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/tailnet/test"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -245,24 +244,6 @@ func TestCoordinator_Lost(t *testing.T) {
|
||||
test.LostTest(ctx, t, coordinator)
|
||||
}
|
||||
|
||||
func TestCoordinator_MultiAgent_CoordClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
coord1 := tailnet.NewCoordinator(logger.Named("coord1"))
|
||||
defer coord1.Close()
|
||||
|
||||
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
|
||||
defer ma1.Close()
|
||||
|
||||
err := coord1.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ma1.RequireEventuallyClosed(ctx)
|
||||
}
|
||||
|
||||
// TestCoordinatorPropogatedPeerContext tests that the context for a specific peer
|
||||
// is propagated through to the `Authorize` method of the coordinatee auth
|
||||
func TestCoordinatorPropogatedPeerContext(t *testing.T) {
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package tailnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
type MultiAgentConn interface {
|
||||
UpdateSelf(node *proto.Node) error
|
||||
SubscribeAgent(agentID uuid.UUID) error
|
||||
UnsubscribeAgent(agentID uuid.UUID) error
|
||||
NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool)
|
||||
Close() error
|
||||
IsClosed() bool
|
||||
}
|
||||
|
||||
type MultiAgent struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
ID uuid.UUID
|
||||
|
||||
OnSubscribe func(enq Queue, agent uuid.UUID) error
|
||||
OnUnsubscribe func(enq Queue, agent uuid.UUID) error
|
||||
OnNodeUpdate func(id uuid.UUID, node *proto.Node) error
|
||||
OnRemove func(enq Queue)
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel func()
|
||||
closed bool
|
||||
|
||||
updates chan *proto.CoordinateResponse
|
||||
closeOnce sync.Once
|
||||
start int64
|
||||
lastWrite int64
|
||||
// Client nodes normally generate a unique id for each connection so
|
||||
// overwrites are really not an issue, but this counter is provided for compatibility.
|
||||
overwrites int64
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Init() *MultiAgent {
|
||||
m.updates = make(chan *proto.CoordinateResponse, 128)
|
||||
m.start = time.Now().Unix()
|
||||
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
|
||||
return m
|
||||
}
|
||||
|
||||
func (*MultiAgent) Kind() QueueKind {
|
||||
return QueueKindClient
|
||||
}
|
||||
|
||||
func (m *MultiAgent) UniqueID() uuid.UUID {
|
||||
return m.ID
|
||||
}
|
||||
|
||||
var ErrMultiAgentClosed = xerrors.New("multiagent is closed")
|
||||
|
||||
func (m *MultiAgent) UpdateSelf(node *proto.Node) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if m.closed {
|
||||
return ErrMultiAgentClosed
|
||||
}
|
||||
|
||||
return m.OnNodeUpdate(m.ID, node)
|
||||
}
|
||||
|
||||
func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if m.closed {
|
||||
return ErrMultiAgentClosed
|
||||
}
|
||||
|
||||
err := m.OnSubscribe(m, agentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if m.closed {
|
||||
return ErrMultiAgentClosed
|
||||
}
|
||||
|
||||
return m.OnUnsubscribe(m, agentID)
|
||||
}
|
||||
|
||||
func (m *MultiAgent) NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, false
|
||||
|
||||
case resp, ok := <-m.updates:
|
||||
return resp, ok
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Enqueue(resp *proto.CoordinateResponse) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.enqueueLocked(resp)
|
||||
}
|
||||
|
||||
func (m *MultiAgent) enqueueLocked(resp *proto.CoordinateResponse) error {
|
||||
atomic.StoreInt64(&m.lastWrite, time.Now().Unix())
|
||||
|
||||
select {
|
||||
case m.updates <- resp:
|
||||
return nil
|
||||
default:
|
||||
return ErrWouldBlock
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Name() string {
|
||||
return m.ID.String()
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Stats() (start int64, lastWrite int64) {
|
||||
return m.start, atomic.LoadInt64(&m.lastWrite)
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Overwrites() int64 {
|
||||
return m.overwrites
|
||||
}
|
||||
|
||||
func (m *MultiAgent) IsClosed() bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.closed
|
||||
}
|
||||
|
||||
func (m *MultiAgent) CoordinatorClose() error {
|
||||
m.mu.Lock()
|
||||
if !m.closed {
|
||||
m.closed = true
|
||||
close(m.updates)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Done() <-chan struct{} {
|
||||
return m.ctx.Done()
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Close() error {
|
||||
_ = m.CoordinatorClose()
|
||||
m.ctxCancel()
|
||||
m.closeOnce.Do(func() { m.OnRemove(m) })
|
||||
return nil
|
||||
}
|
||||
@@ -97,17 +97,3 @@ func (mr *MockCoordinatorMockRecorder) ServeHTTPDebug(arg0, arg1 any) *gomock.Ca
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeHTTPDebug", reflect.TypeOf((*MockCoordinator)(nil).ServeHTTPDebug), arg0, arg1)
|
||||
}
|
||||
|
||||
// ServeMultiAgent mocks base method.
|
||||
func (m *MockCoordinator) ServeMultiAgent(arg0 uuid.UUID) tailnet.MultiAgentConn {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ServeMultiAgent", arg0)
|
||||
ret0, _ := ret[0].(tailnet.MultiAgentConn)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ServeMultiAgent indicates an expected call of ServeMultiAgent.
|
||||
func (mr *MockCoordinatorMockRecorder) ServeMultiAgent(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeMultiAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeMultiAgent), arg0)
|
||||
}
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/coder/coder/v2/tailnet (interfaces: MultiAgentConn)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
|
||||
//
|
||||
|
||||
// Package tailnettest is a generated GoMock package.
|
||||
package tailnettest
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
proto "github.com/coder/coder/v2/tailnet/proto"
|
||||
uuid "github.com/google/uuid"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockMultiAgentConn is a mock of MultiAgentConn interface.
|
||||
type MockMultiAgentConn struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMultiAgentConnMockRecorder
|
||||
}
|
||||
|
||||
// MockMultiAgentConnMockRecorder is the mock recorder for MockMultiAgentConn.
|
||||
type MockMultiAgentConnMockRecorder struct {
|
||||
mock *MockMultiAgentConn
|
||||
}
|
||||
|
||||
// NewMockMultiAgentConn creates a new mock instance.
|
||||
func NewMockMultiAgentConn(ctrl *gomock.Controller) *MockMultiAgentConn {
|
||||
mock := &MockMultiAgentConn{ctrl: ctrl}
|
||||
mock.recorder = &MockMultiAgentConnMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockMultiAgentConn) EXPECT() *MockMultiAgentConnMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockMultiAgentConn) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockMultiAgentConnMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiAgentConn)(nil).Close))
|
||||
}
|
||||
|
||||
// IsClosed mocks base method.
|
||||
func (m *MockMultiAgentConn) IsClosed() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClosed")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IsClosed indicates an expected call of IsClosed.
|
||||
func (mr *MockMultiAgentConnMockRecorder) IsClosed() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiAgentConn)(nil).IsClosed))
|
||||
}
|
||||
|
||||
// NextUpdate mocks base method.
|
||||
func (m *MockMultiAgentConn) NextUpdate(arg0 context.Context) (*proto.CoordinateResponse, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "NextUpdate", arg0)
|
||||
ret0, _ := ret[0].(*proto.CoordinateResponse)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// NextUpdate indicates an expected call of NextUpdate.
|
||||
func (mr *MockMultiAgentConnMockRecorder) NextUpdate(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextUpdate", reflect.TypeOf((*MockMultiAgentConn)(nil).NextUpdate), arg0)
|
||||
}
|
||||
|
||||
// SubscribeAgent mocks base method.
|
||||
func (m *MockMultiAgentConn) SubscribeAgent(arg0 uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SubscribeAgent", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SubscribeAgent indicates an expected call of SubscribeAgent.
|
||||
func (mr *MockMultiAgentConnMockRecorder) SubscribeAgent(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).SubscribeAgent), arg0)
|
||||
}
|
||||
|
||||
// UnsubscribeAgent mocks base method.
|
||||
func (m *MockMultiAgentConn) UnsubscribeAgent(arg0 uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnsubscribeAgent", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnsubscribeAgent indicates an expected call of UnsubscribeAgent.
|
||||
func (mr *MockMultiAgentConnMockRecorder) UnsubscribeAgent(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).UnsubscribeAgent), arg0)
|
||||
}
|
||||
|
||||
// UpdateSelf mocks base method.
|
||||
func (m *MockMultiAgentConn) UpdateSelf(arg0 *proto.Node) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateSelf", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateSelf indicates an expected call of UpdateSelf.
|
||||
func (mr *MockMultiAgentConnMockRecorder) UpdateSelf(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSelf", reflect.TypeOf((*MockMultiAgentConn)(nil).UpdateSelf), arg0)
|
||||
}
|
||||
@@ -9,11 +9,8 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slices"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/derp/derphttp"
|
||||
"tailscale.com/net/stun/stuntest"
|
||||
@@ -25,10 +22,8 @@ import (
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
|
||||
//go:generate mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator
|
||||
//go:generate mockgen -destination ./coordinateemock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinatee
|
||||
|
||||
@@ -159,123 +154,6 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
|
||||
}
|
||||
}
|
||||
|
||||
type TestMultiAgent struct {
|
||||
t testing.TB
|
||||
ID uuid.UUID
|
||||
a tailnet.MultiAgentConn
|
||||
nodeKey []byte
|
||||
discoKey string
|
||||
}
|
||||
|
||||
func NewTestMultiAgent(t testing.TB, coord tailnet.Coordinator) *TestMultiAgent {
|
||||
nk, err := key.NewNode().Public().MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
dk, err := key.NewDisco().Public().MarshalText()
|
||||
require.NoError(t, err)
|
||||
m := &TestMultiAgent{t: t, ID: uuid.New(), nodeKey: nk, discoKey: string(dk)}
|
||||
m.a = coord.ServeMultiAgent(m.ID)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *TestMultiAgent) SendNodeWithDERP(d int32) {
|
||||
m.t.Helper()
|
||||
err := m.a.UpdateSelf(&proto.Node{
|
||||
Key: m.nodeKey,
|
||||
Disco: m.discoKey,
|
||||
PreferredDerp: d,
|
||||
})
|
||||
require.NoError(m.t, err)
|
||||
}
|
||||
|
||||
func (m *TestMultiAgent) Close() {
|
||||
m.t.Helper()
|
||||
err := m.a.Close()
|
||||
require.NoError(m.t, err)
|
||||
}
|
||||
|
||||
func (m *TestMultiAgent) RequireSubscribeAgent(id uuid.UUID) {
|
||||
m.t.Helper()
|
||||
err := m.a.SubscribeAgent(id)
|
||||
require.NoError(m.t, err)
|
||||
}
|
||||
|
||||
func (m *TestMultiAgent) RequireUnsubscribeAgent(id uuid.UUID) {
|
||||
m.t.Helper()
|
||||
err := m.a.UnsubscribeAgent(id)
|
||||
require.NoError(m.t, err)
|
||||
}
|
||||
|
||||
func (m *TestMultiAgent) RequireEventuallyHasDERPs(ctx context.Context, expected ...int) {
|
||||
m.t.Helper()
|
||||
for {
|
||||
resp, ok := m.a.NextUpdate(ctx)
|
||||
require.True(m.t, ok)
|
||||
nodes, err := tailnet.OnlyNodeUpdates(resp)
|
||||
require.NoError(m.t, err)
|
||||
if len(nodes) != len(expected) {
|
||||
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
|
||||
continue
|
||||
}
|
||||
|
||||
derps := make([]int, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
derps = append(derps, n.PreferredDERP)
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !slices.Contains(derps, e) {
|
||||
m.t.Logf("expected DERP %d to be in %v", e, derps)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TestMultiAgent) RequireNeverHasDERPs(ctx context.Context, expected ...int) {
|
||||
m.t.Helper()
|
||||
for {
|
||||
resp, ok := m.a.NextUpdate(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
nodes, err := tailnet.OnlyNodeUpdates(resp)
|
||||
require.NoError(m.t, err)
|
||||
if len(nodes) != len(expected) {
|
||||
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
|
||||
continue
|
||||
}
|
||||
|
||||
derps := make([]int, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
derps = append(derps, n.PreferredDERP)
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !slices.Contains(derps, e) {
|
||||
m.t.Logf("expected DERP %d to be in %v", e, derps)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TestMultiAgent) RequireEventuallyClosed(ctx context.Context) {
|
||||
m.t.Helper()
|
||||
tkr := time.NewTicker(testutil.IntervalFast)
|
||||
defer tkr.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.t.Fatal("timeout")
|
||||
return // unhittable
|
||||
case <-tkr.C:
|
||||
if m.a.IsClosed() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type FakeCoordinator struct {
|
||||
CoordinateCalls chan *FakeCoordinate
|
||||
}
|
||||
@@ -292,10 +170,6 @@ func (*FakeCoordinator) Close() error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (*FakeCoordinator) ServeMultiAgent(uuid.UUID) tailnet.MultiAgentConn {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
|
||||
@@ -101,6 +101,18 @@ func (p *Peer) AddTunnel(other uuid.UUID) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) RemoveTunnel(other uuid.UUID) {
|
||||
p.t.Helper()
|
||||
req := &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(other)}}
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
p.t.Errorf("timeout removing tunnel for %s", p.name)
|
||||
return
|
||||
case p.reqs <- req:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) UpdateDERP(derp int32) {
|
||||
p.t.Helper()
|
||||
node := &proto.Node{PreferredDerp: derp}
|
||||
|
||||
Reference in New Issue
Block a user