mirror of
https://github.com/coder/coder.git
synced 2026-06-05 22:18:20 +00:00
e5d911462f
This adds the ability for `TunnelAuth` to also authorize incoming wireguard node IPs, preventing agents from reporting anything other than their static IP generated from the agent ID.
293 lines
8.4 KiB
Go
293 lines
8.4 KiB
Go
package agenttest
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/exp/maps"
|
|
"golang.org/x/exp/slices"
|
|
"golang.org/x/xerrors"
|
|
"google.golang.org/protobuf/types/known/durationpb"
|
|
"storj.io/drpc"
|
|
"storj.io/drpc/drpcmux"
|
|
"storj.io/drpc/drpcserver"
|
|
"tailscale.com/tailcfg"
|
|
|
|
"cdr.dev/slog"
|
|
agentproto "github.com/coder/coder/v2/agent/proto"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/agentsdk"
|
|
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
|
|
"github.com/coder/coder/v2/tailnet"
|
|
"github.com/coder/coder/v2/tailnet/proto"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
const statsInterval = 500 * time.Millisecond
|
|
|
|
func NewClient(t testing.TB,
|
|
logger slog.Logger,
|
|
agentID uuid.UUID,
|
|
manifest agentsdk.Manifest,
|
|
statsChan chan *agentproto.Stats,
|
|
coordinator tailnet.Coordinator,
|
|
) *Client {
|
|
if manifest.AgentID == uuid.Nil {
|
|
manifest.AgentID = agentID
|
|
}
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coordinator)
|
|
mux := drpcmux.New()
|
|
derpMapUpdates := make(chan *tailcfg.DERPMap)
|
|
drpcService := &tailnet.DRPCService{
|
|
CoordPtr: &coordPtr,
|
|
Logger: logger.Named("tailnetsvc"),
|
|
DerpMapUpdateFrequency: time.Microsecond,
|
|
DerpMapFn: func() *tailcfg.DERPMap { return <-derpMapUpdates },
|
|
}
|
|
err := proto.DRPCRegisterTailnet(mux, drpcService)
|
|
require.NoError(t, err)
|
|
mp, err := agentsdk.ProtoFromManifest(manifest)
|
|
require.NoError(t, err)
|
|
fakeAAPI := NewFakeAgentAPI(t, logger, mp, statsChan)
|
|
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
|
|
require.NoError(t, err)
|
|
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
|
Log: func(err error) {
|
|
if xerrors.Is(err, io.EOF) {
|
|
return
|
|
}
|
|
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
|
|
},
|
|
})
|
|
return &Client{
|
|
t: t,
|
|
logger: logger.Named("client"),
|
|
agentID: agentID,
|
|
coordinator: coordinator,
|
|
server: server,
|
|
fakeAgentAPI: fakeAAPI,
|
|
derpMapUpdates: derpMapUpdates,
|
|
}
|
|
}
|
|
|
|
type Client struct {
|
|
t testing.TB
|
|
logger slog.Logger
|
|
agentID uuid.UUID
|
|
coordinator tailnet.Coordinator
|
|
server *drpcserver.Server
|
|
fakeAgentAPI *FakeAgentAPI
|
|
LastWorkspaceAgent func()
|
|
|
|
mu sync.Mutex // Protects following.
|
|
logs []agentsdk.Log
|
|
derpMapUpdates chan *tailcfg.DERPMap
|
|
derpMapOnce sync.Once
|
|
}
|
|
|
|
func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {}
|
|
|
|
func (c *Client) Close() {
|
|
c.derpMapOnce.Do(func() { close(c.derpMapUpdates) })
|
|
}
|
|
|
|
func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
|
|
conn, lis := drpcsdk.MemTransportPipe()
|
|
c.LastWorkspaceAgent = func() {
|
|
_ = conn.Close()
|
|
_ = lis.Close()
|
|
}
|
|
c.t.Cleanup(c.LastWorkspaceAgent)
|
|
serveCtx, cancel := context.WithCancel(ctx)
|
|
c.t.Cleanup(cancel)
|
|
streamID := tailnet.StreamID{
|
|
Name: "agenttest",
|
|
ID: c.agentID,
|
|
Auth: tailnet.AgentCoordinateeAuth{ID: c.agentID},
|
|
}
|
|
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
|
|
go func() {
|
|
_ = c.server.Serve(serveCtx, lis)
|
|
}()
|
|
return conn, nil
|
|
}
|
|
|
|
func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
|
|
return c.fakeAgentAPI.GetLifecycleStates()
|
|
}
|
|
|
|
func (c *Client) GetStartup() <-chan *agentproto.Startup {
|
|
return c.fakeAgentAPI.startupCh
|
|
}
|
|
|
|
func (c *Client) GetMetadata() map[string]agentsdk.Metadata {
|
|
return c.fakeAgentAPI.GetMetadata()
|
|
}
|
|
|
|
func (c *Client) GetStartupLogs() []agentsdk.Log {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.logs
|
|
}
|
|
|
|
func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) {
|
|
c.fakeAgentAPI.SetServiceBannerFunc(f)
|
|
}
|
|
|
|
func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
|
|
timer := time.NewTimer(testutil.WaitShort)
|
|
defer timer.Stop()
|
|
select {
|
|
case c.derpMapUpdates <- update:
|
|
case <-timer.C:
|
|
return xerrors.New("timeout waiting to push derp map update")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) SetLogsChannel(ch chan<- *agentproto.BatchCreateLogsRequest) {
|
|
c.fakeAgentAPI.SetLogsChannel(ch)
|
|
}
|
|
|
|
type FakeAgentAPI struct {
|
|
sync.Mutex
|
|
t testing.TB
|
|
logger slog.Logger
|
|
|
|
manifest *agentproto.Manifest
|
|
startupCh chan *agentproto.Startup
|
|
statsCh chan *agentproto.Stats
|
|
appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
|
|
logsCh chan<- *agentproto.BatchCreateLogsRequest
|
|
lifecycleStates []codersdk.WorkspaceAgentLifecycle
|
|
metadata map[string]agentsdk.Metadata
|
|
|
|
getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
|
|
}
|
|
|
|
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
|
|
return f.manifest, nil
|
|
}
|
|
|
|
func (f *FakeAgentAPI) SetServiceBannerFunc(fn func() (codersdk.ServiceBannerConfig, error)) {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
f.getServiceBannerFunc = fn
|
|
f.logger.Info(context.Background(), "updated ServiceBannerFunc")
|
|
}
|
|
|
|
func (f *FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceBannerRequest) (*agentproto.ServiceBanner, error) {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
if f.getServiceBannerFunc == nil {
|
|
return &agentproto.ServiceBanner{}, nil
|
|
}
|
|
sb, err := f.getServiceBannerFunc()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return agentsdk.ProtoFromServiceBanner(sb), nil
|
|
}
|
|
|
|
func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
|
|
f.logger.Debug(ctx, "update stats called", slog.F("req", req))
|
|
// empty request is sent to get the interval; but our tests don't want empty stats requests
|
|
if req.Stats != nil {
|
|
f.statsCh <- req.Stats
|
|
}
|
|
return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(statsInterval)}, nil
|
|
}
|
|
|
|
func (f *FakeAgentAPI) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
return slices.Clone(f.lifecycleStates)
|
|
}
|
|
|
|
func (f *FakeAgentAPI) UpdateLifecycle(_ context.Context, req *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
s, err := agentsdk.LifecycleStateFromProto(req.GetLifecycle().GetState())
|
|
if assert.NoError(f.t, err) {
|
|
f.lifecycleStates = append(f.lifecycleStates, s)
|
|
}
|
|
return req.GetLifecycle(), nil
|
|
}
|
|
|
|
func (f *FakeAgentAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
|
|
f.logger.Debug(ctx, "batch update app health", slog.F("req", req))
|
|
f.appHealthCh <- req
|
|
return &agentproto.BatchUpdateAppHealthResponse{}, nil
|
|
}
|
|
|
|
func (f *FakeAgentAPI) AppHealthCh() <-chan *agentproto.BatchUpdateAppHealthRequest {
|
|
return f.appHealthCh
|
|
}
|
|
|
|
func (f *FakeAgentAPI) UpdateStartup(_ context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
|
|
f.startupCh <- req.GetStartup()
|
|
return req.GetStartup(), nil
|
|
}
|
|
|
|
func (f *FakeAgentAPI) GetMetadata() map[string]agentsdk.Metadata {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
return maps.Clone(f.metadata)
|
|
}
|
|
|
|
func (f *FakeAgentAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
if f.metadata == nil {
|
|
f.metadata = make(map[string]agentsdk.Metadata)
|
|
}
|
|
for _, md := range req.Metadata {
|
|
smd := agentsdk.MetadataFromProto(md)
|
|
f.metadata[md.Key] = smd
|
|
f.logger.Debug(ctx, "post metadata", slog.F("key", md.Key), slog.F("md", md))
|
|
}
|
|
return &agentproto.BatchUpdateMetadataResponse{}, nil
|
|
}
|
|
|
|
func (f *FakeAgentAPI) SetLogsChannel(ch chan<- *agentproto.BatchCreateLogsRequest) {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
f.logsCh = ch
|
|
}
|
|
|
|
func (f *FakeAgentAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) {
|
|
f.logger.Info(ctx, "batch create logs called", slog.F("req", req))
|
|
f.Lock()
|
|
ch := f.logsCh
|
|
f.Unlock()
|
|
if ch != nil {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case ch <- req:
|
|
// ok
|
|
}
|
|
}
|
|
return &agentproto.BatchCreateLogsResponse{}, nil
|
|
}
|
|
|
|
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest, statsCh chan *agentproto.Stats) *FakeAgentAPI {
|
|
return &FakeAgentAPI{
|
|
t: t,
|
|
logger: logger.Named("FakeAgentAPI"),
|
|
manifest: manifest,
|
|
statsCh: statsCh,
|
|
startupCh: make(chan *agentproto.Startup, 100),
|
|
appHealthCh: make(chan *agentproto.BatchUpdateAppHealthRequest, 100),
|
|
}
|
|
}
|