fix(tailnet): Skip nodes without DERP, avoid use of RemoveAllPeers (#6320)

* fix(tailnet): Skip nodes without DERP, avoid use of RemoveAllPeers
This commit is contained in:
Mathias Fredriksson
2023-02-24 18:16:29 +02:00
committed by GitHub
parent a414de9e81
commit 677721e4a1
10 changed files with 90 additions and 45 deletions
+3 -1
View File
@@ -601,7 +601,9 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error
} }
defer coordinator.Close() defer coordinator.Close()
a.logger.Info(ctx, "connected to coordination server") a.logger.Info(ctx, "connected to coordination server")
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, network.UpdateNodes) sendNodes, errChan := tailnet.ServeCoordinator(coordinator, func(nodes []*tailnet.Node) error {
return network.UpdateNodes(nodes, false)
})
network.SetNodeCallback(sendNodes) network.SetNodeCallback(sendNodes)
select { select {
case <-ctx.Done(): case <-ctx.Done():
+12 -3
View File
@@ -1179,12 +1179,21 @@ func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Durati
coordinator.ServeClient(serverConn, uuid.New(), agentID) coordinator.ServeClient(serverConn, uuid.New(), agentID)
}() }()
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node) return conn.UpdateNodes(node, false)
}) })
conn.SetNodeCallback(sendNode) conn.SetNodeCallback(sendNode)
return &codersdk.WorkspaceAgentConn{ agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn, Conn: conn,
}, c, statsCh, fs }
t.Cleanup(func() {
_ = agentConn.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
if !agentConn.AwaitReachable(ctx) {
t.Fatal("agent not reachable")
}
return agentConn, c, statsCh, fs
} }
var dialTestPayload = []byte("dean-was-here123") var dialTestPayload = []byte("dean-was-here123")
+2 -1
View File
@@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent" "github.com/coder/coder/agent"
"github.com/coder/coder/cli/clitest" "github.com/coder/coder/cli/clitest"
@@ -28,7 +29,7 @@ func TestSpeedtest(t *testing.T) {
agentClient.SetSessionToken(agentToken) agentClient.SetSessionToken(agentToken)
agentCloser := agent.New(agent.Options{ agentCloser := agent.New(agent.Options{
Client: agentClient, Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"), Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
}) })
defer agentCloser.Close() defer agentCloser.Close()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
+2
View File
@@ -24,6 +24,7 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
gosshagent "golang.org/x/crypto/ssh/agent" gosshagent "golang.org/x/crypto/ssh/agent"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent" "github.com/coder/coder/agent"
@@ -47,6 +48,7 @@ func setupWorkspaceForAgent(t *testing.T, mutate func([]*proto.Agent) []*proto.A
} }
} }
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
client.Logger = slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug)
user := coderdtest.CreateFirstUser(t, client) user := coderdtest.CreateFirstUser(t, client)
agentToken := uuid.NewString() agentToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
+3 -3
View File
@@ -80,16 +80,16 @@ func TestDERP(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
w2Ready := make(chan struct{}, 1) w2Ready := make(chan struct{})
w2ReadyOnce := sync.Once{} w2ReadyOnce := sync.Once{}
w1.SetNodeCallback(func(node *tailnet.Node) { w1.SetNodeCallback(func(node *tailnet.Node) {
w2.UpdateNodes([]*tailnet.Node{node}) w2.UpdateNodes([]*tailnet.Node{node}, false)
w2ReadyOnce.Do(func() { w2ReadyOnce.Do(func() {
close(w2Ready) close(w2Ready)
}) })
}) })
w2.SetNodeCallback(func(node *tailnet.Node) { w2.SetNodeCallback(func(node *tailnet.Node) {
w1.UpdateNodes([]*tailnet.Node{node}) w1.UpdateNodes([]*tailnet.Node{node}, false)
}) })
conn := make(chan struct{}) conn := make(chan struct{})
+16 -15
View File
@@ -404,6 +404,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
} }
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
ctx := r.Context()
clientConn, serverConn := net.Pipe() clientConn, serverConn := net.Pipe()
derpMap := api.DERPMap.Clone() derpMap := api.DERPMap.Clone()
@@ -453,32 +454,32 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
} }
sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
err := conn.RemoveAllPeers() err = conn.UpdateNodes(node, true)
if err != nil {
return xerrors.Errorf("remove all peers: %w", err)
}
err = conn.UpdateNodes(node)
if err != nil { if err != nil {
return xerrors.Errorf("update nodes: %w", err) return xerrors.Errorf("update nodes: %w", err)
} }
return nil return nil
}) })
conn.SetNodeCallback(sendNodes) conn.SetNodeCallback(sendNodes)
go func() { agentConn := &codersdk.WorkspaceAgentConn{
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil {
api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err))
_ = conn.Close()
}
}()
return &codersdk.WorkspaceAgentConn{
Conn: conn, Conn: conn,
CloseFunc: func() { CloseFunc: func() {
_ = clientConn.Close() _ = clientConn.Close()
_ = serverConn.Close() _ = serverConn.Close()
}, },
}, nil }
go func() {
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil {
api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err))
_ = agentConn.Close()
}
}()
if !agentConn.AwaitReachable(ctx) {
_ = agentConn.Close()
return nil, xerrors.Errorf("agent not reachable")
}
return agentConn, nil
} }
// @Summary Get connection info for workspace agent // @Summary Get connection info for workspace agent
+11 -2
View File
@@ -191,12 +191,21 @@ func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Durati
}) })
go coordinator.ServeClient(serverConn, uuid.New(), agentID) go coordinator.ServeClient(serverConn, uuid.New(), agentID)
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node) return conn.UpdateNodes(node, false)
}) })
conn.SetNodeCallback(sendNode) conn.SetNodeCallback(sendNode)
return &codersdk.WorkspaceAgentConn{ agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn, Conn: conn,
} }
t.Cleanup(func() {
_ = agentConn.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
if !agentConn.AwaitReachable(ctx) {
t.Fatal("agent not reachable")
}
return agentConn
} }
type client struct { type client struct {
+15 -7
View File
@@ -100,7 +100,7 @@ type DialWorkspaceAgentOptions struct {
BlockEndpoints bool BlockEndpoints bool
} }
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*WorkspaceAgentConn, error) { func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (agentConn *WorkspaceAgentConn, err error) {
if options == nil { if options == nil {
options = &DialWorkspaceAgentOptions{} options = &DialWorkspaceAgentOptions{}
} }
@@ -128,6 +128,11 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
if err != nil { if err != nil {
return nil, xerrors.Errorf("create tailnet: %w", err) return nil, xerrors.Errorf("create tailnet: %w", err)
} }
defer func() {
if err != nil {
_ = conn.Close()
}
}()
coordinateURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", agentID)) coordinateURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", agentID))
if err != nil { if err != nil {
@@ -145,7 +150,12 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
Jar: jar, Jar: jar,
Transport: c.HTTPClient.Transport, Transport: c.HTTPClient.Transport,
} }
ctx, cancelFunc := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer func() {
if err != nil {
cancel()
}
}()
closed := make(chan struct{}) closed := make(chan struct{})
first := make(chan error) first := make(chan error)
go func() { go func() {
@@ -175,7 +185,7 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
continue continue
} }
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error { sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
return conn.UpdateNodes(node) return conn.UpdateNodes(node, false)
}) })
conn.SetNodeCallback(sendNode) conn.SetNodeCallback(sendNode)
options.Logger.Debug(ctx, "serving coordinator") options.Logger.Debug(ctx, "serving coordinator")
@@ -194,15 +204,13 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
}() }()
err = <-first err = <-first
if err != nil { if err != nil {
cancelFunc()
_ = conn.Close()
return nil, err return nil, err
} }
agentConn := &WorkspaceAgentConn{ agentConn = &WorkspaceAgentConn{
Conn: conn, Conn: conn,
CloseFunc: func() { CloseFunc: func() {
cancelFunc() cancel()
<-closed <-closed
}, },
} }
+22 -9
View File
@@ -130,7 +130,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
}() }()
dialer := &tsdial.Dialer{ dialer := &tsdial.Dialer{
Logf: Logger(options.Logger), Logf: Logger(options.Logger.Named("tsdial")),
} }
wireguardEngine, err := wgengine.NewUserspaceEngine(Logger(options.Logger.Named("wgengine")), wgengine.Config{ wireguardEngine, err := wgengine.NewUserspaceEngine(Logger(options.Logger.Named("wgengine")), wgengine.Config{
LinkMonitor: wireguardMonitor, LinkMonitor: wireguardMonitor,
@@ -179,6 +179,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
wireguardEngine = wgengine.NewWatchdog(wireguardEngine) wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
wireguardEngine.SetDERPMap(options.DERPMap) wireguardEngine.SetDERPMap(options.DERPMap)
netMapCopy := *netMap netMapCopy := *netMap
options.Logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
wireguardEngine.SetNetworkMap(&netMapCopy) wireguardEngine.SetNetworkMap(&netMapCopy)
localIPSet := netipx.IPSetBuilder{} localIPSet := netipx.IPSetBuilder{}
@@ -329,9 +330,11 @@ func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap)) c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap))
c.netMap.DERPMap = derpMap
c.wireguardEngine.SetNetworkMap(c.netMap)
c.wireguardEngine.SetDERPMap(derpMap) c.wireguardEngine.SetDERPMap(derpMap)
c.netMap.DERPMap = derpMap
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
c.wireguardEngine.SetNetworkMap(&netMapCopy)
} }
func (c *Conn) RemoveAllPeers() error { func (c *Conn) RemoveAllPeers() error {
@@ -341,6 +344,7 @@ func (c *Conn) RemoveAllPeers() error {
c.netMap.Peers = []*tailcfg.Node{} c.netMap.Peers = []*tailcfg.Node{}
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{} c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
netMapCopy := *c.netMap netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
c.wireguardEngine.SetNetworkMap(&netMapCopy) c.wireguardEngine.SetNetworkMap(&netMapCopy)
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "") cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "")
if err != nil { if err != nil {
@@ -360,11 +364,18 @@ func (c *Conn) RemoveAllPeers() error {
} }
// UpdateNodes connects with a set of peers. This can be constantly updated, // UpdateNodes connects with a set of peers. This can be constantly updated,
// and peers will continually be reconnected as necessary. // and peers will continually be reconnected as necessary. If replacePeers is
func (c *Conn) UpdateNodes(nodes []*Node) error { // true, all peers will be removed before adding the new ones.
//
//nolint:revive // Complains about replacePeers.
func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
status := c.Status() status := c.Status()
if replacePeers {
c.netMap.Peers = []*tailcfg.Node{}
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
}
for _, peer := range c.netMap.Peers { for _, peer := range c.netMap.Peers {
peerStatus, ok := status.Peer[peer.Key] peerStatus, ok := status.Peer[peer.Key]
if !ok { if !ok {
@@ -384,6 +395,11 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
delete(c.peerMap, peer.ID) delete(c.peerMap, peer.ID)
} }
for _, node := range nodes { for _, node := range nodes {
// If no preferred DERP is provided, we can't reach the node.
if node.PreferredDERP == 0 {
c.logger.Debug(context.Background(), "no preferred DERP, skipping node", slog.F("node", node))
continue
}
c.logger.Debug(context.Background(), "adding node", slog.F("node", node)) c.logger.Debug(context.Background(), "adding node", slog.F("node", node))
peerStatus, ok := status.Peer[node.Key] peerStatus, ok := status.Peer[node.Key]
@@ -402,10 +418,6 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
// reason. TODO: @kylecarbs debug this! // reason. TODO: @kylecarbs debug this!
KeepAlive: ok && peerStatus.Active, KeepAlive: ok && peerStatus.Active,
} }
// If no preferred DERP is provided, don't set an IP!
if node.PreferredDERP == 0 {
peerNode.DERP = ""
}
if c.blockEndpoints { if c.blockEndpoints {
peerNode.Endpoints = nil peerNode.Endpoints = nil
} }
@@ -416,6 +428,7 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
c.netMap.Peers = append(c.netMap.Peers, peer.Clone()) c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
} }
netMapCopy := *c.netMap netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map", slog.F("net_map", netMapCopy))
c.wireguardEngine.SetNetworkMap(&netMapCopy) c.wireguardEngine.SetNetworkMap(&netMapCopy)
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "") cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "")
if err != nil { if err != nil {
+4 -4
View File
@@ -55,12 +55,12 @@ func TestTailnet(t *testing.T) {
_ = w2.Close() _ = w2.Close()
}) })
w1.SetNodeCallback(func(node *tailnet.Node) { w1.SetNodeCallback(func(node *tailnet.Node) {
err := w2.UpdateNodes([]*tailnet.Node{node}) err := w2.UpdateNodes([]*tailnet.Node{node}, false)
require.NoError(t, err) assert.NoError(t, err)
}) })
w2.SetNodeCallback(func(node *tailnet.Node) { w2.SetNodeCallback(func(node *tailnet.Node) {
err := w1.UpdateNodes([]*tailnet.Node{node}) err := w1.UpdateNodes([]*tailnet.Node{node}, false)
require.NoError(t, err) assert.NoError(t, err)
}) })
require.True(t, w2.AwaitReachable(context.Background(), w1IP)) require.True(t, w2.AwaitReachable(context.Background(), w1IP))
conn := make(chan struct{}) conn := make(chan struct{})