From 278527cff4973cd32e402c271d0bcd1eb605bdc7 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 18 Jul 2023 12:17:11 +0100 Subject: [PATCH] feat(scaletest): add option to send traffic over SSH (#8521) - Refactors the metrics logic to avoid needing to pass in a whole prometheus registry - Adds an --ssh option to the workspace-traffic command to send SSH traffic Fixes #8242 --- cli/exp_scaletest.go | 25 +- cli/exp_scaletest_test.go | 1 + scaletest/workspacetraffic/config.go | 16 +- scaletest/workspacetraffic/conn.go | 123 ++++++ scaletest/workspacetraffic/countreadwriter.go | 71 +++ scaletest/workspacetraffic/metrics.go | 40 ++ scaletest/workspacetraffic/run.go | 186 ++++---- scaletest/workspacetraffic/run_test.go | 414 ++++++++++++------ 8 files changed, 619 insertions(+), 257 deletions(-) create mode 100644 scaletest/workspacetraffic/conn.go create mode 100644 scaletest/workspacetraffic/countreadwriter.go diff --git a/cli/exp_scaletest.go b/cli/exp_scaletest.go index 2fa3b2c37d..f74cc49eb6 100644 --- a/cli/exp_scaletest.go +++ b/cli/exp_scaletest.go @@ -848,6 +848,7 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd { var ( tickInterval time.Duration bytesPerTick int64 + ssh bool scaletestPrometheusAddress string scaletestPrometheusWait time.Duration @@ -938,20 +939,19 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd { // Setup our workspace agent connection. config := workspacetraffic.Config{ - AgentID: agentID, - AgentName: agentName, - BytesPerTick: bytesPerTick, - Duration: strategy.timeout, - TickInterval: tickInterval, - WorkspaceName: ws.Name, - WorkspaceOwner: ws.OwnerName, - Registry: reg, + AgentID: agentID, + BytesPerTick: bytesPerTick, + Duration: strategy.timeout, + TickInterval: tickInterval, + ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agentName), + WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agentName), + SSH: ssh, } if err := config.Validate(); err != nil { return xerrors.Errorf("validate config: %w", err) } - var runner harness.Runnable = workspacetraffic.NewRunner(client, config, metrics) + var runner harness.Runnable = workspacetraffic.NewRunner(client, config) if tracingEnabled { runner = &runnableTraceWrapper{ tracer: tracer, @@ -1002,6 +1002,13 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd { Description: "How often to send traffic.", Value: clibase.DurationOf(&tickInterval), }, + { + Flag: "ssh", + Env: "CODER_SCALETEST_WORKSPACE_TRAFFIC_SSH", + Default: "", + Description: "Send traffic over SSH.", + Value: clibase.BoolOf(&ssh), + }, { Flag: "scaletest-prometheus-address", Env: "CODER_SCALETEST_PROMETHEUS_ADDRESS", diff --git a/cli/exp_scaletest_test.go b/cli/exp_scaletest_test.go index 3ff54dc9ab..940ba65eb9 100644 --- a/cli/exp_scaletest_test.go +++ b/cli/exp_scaletest_test.go @@ -69,6 +69,7 @@ func TestScaleTestWorkspaceTraffic(t *testing.T) { "--tick-interval", "100ms", "--scaletest-prometheus-address", "127.0.0.1:0", "--scaletest-prometheus-wait", "0s", + "--ssh", ) clitest.SetupConfig(t, client, root) var stdout, stderr bytes.Buffer diff --git a/scaletest/workspacetraffic/config.go b/scaletest/workspacetraffic/config.go index e6e830a417..cd4dca145f 100644 --- a/scaletest/workspacetraffic/config.go +++ b/scaletest/workspacetraffic/config.go @@ -4,7 +4,6 @@ import ( "time" "github.com/google/uuid" - "github.com/prometheus/client_golang/prometheus" "golang.org/x/xerrors" ) @@ -12,15 +11,6 @@ type Config struct { // AgentID is the workspace agent ID to which to connect. AgentID uuid.UUID `json:"agent_id"` - // AgentName is the name of the agent. Used for metrics. - AgentName string `json:"agent_name"` - - // WorkspaceName is the name of the workspace. Used for metrics. - WorkspaceName string `json:"workspace_name"` - - // WorkspaceOwner is the owner of the workspace. Used for metrics. - WorkspaceOwner string `json:"workspace_owner"` - // BytesPerTick is the number of bytes to send to the agent per tick. BytesPerTick int64 `json:"bytes_per_tick"` @@ -31,8 +21,10 @@ type Config struct { // send data to workspace agents). TickInterval time.Duration `json:"tick_interval"` - // Registry is a prometheus.Registerer for logging metrics - Registry prometheus.Registerer + ReadMetrics ConnMetrics `json:"-"` + WriteMetrics ConnMetrics `json:"-"` + + SSH bool `json:"ssh"` } func (c Config) Validate() error { diff --git a/scaletest/workspacetraffic/conn.go b/scaletest/workspacetraffic/conn.go new file mode 100644 index 0000000000..167164c5ef --- /dev/null +++ b/scaletest/workspacetraffic/conn.go @@ -0,0 +1,123 @@ +package workspacetraffic + +import ( + "context" + "io" + "sync" + + "github.com/coder/coder/codersdk" + + "github.com/google/uuid" + "github.com/hashicorp/go-multierror" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" +) + +func connectPTY(ctx context.Context, client *codersdk.Client, agentID, reconnect uuid.UUID) (*countReadWriteCloser, error) { + conn, err := client.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{ + AgentID: agentID, + Reconnect: reconnect, + Height: 25, + Width: 80, + Command: "/bin/sh", + }) + if err != nil { + return nil, xerrors.Errorf("connect pty: %w", err) + } + + // Wrap the conn in a countReadWriteCloser so we can monitor bytes sent/rcvd. + crw := countReadWriteCloser{ctx: ctx, rwc: conn} + return &crw, nil +} + +func connectSSH(ctx context.Context, client *codersdk.Client, agentID uuid.UUID) (*countReadWriteCloser, error) { + agentConn, err := client.DialWorkspaceAgent(ctx, agentID, &codersdk.DialWorkspaceAgentOptions{}) + if err != nil { + return nil, xerrors.Errorf("dial workspace agent: %w", err) + } + agentConn.AwaitReachable(ctx) + sshClient, err := agentConn.SSHClient(ctx) + if err != nil { + return nil, xerrors.Errorf("get ssh client: %w", err) + } + sshSession, err := sshClient.NewSession() + if err != nil { + _ = agentConn.Close() + return nil, xerrors.Errorf("new ssh session: %w", err) + } + wrappedConn := &wrappedSSHConn{ctx: ctx} + // Do some plumbing to hook up the wrappedConn + pr1, pw1 := io.Pipe() + wrappedConn.stdout = pr1 + sshSession.Stdout = pw1 + pr2, pw2 := io.Pipe() + sshSession.Stdin = pr2 + wrappedConn.stdin = pw2 + err = sshSession.RequestPty("xterm", 25, 80, gossh.TerminalModes{}) + if err != nil { + _ = pr1.Close() + _ = pr2.Close() + _ = pw1.Close() + _ = pw2.Close() + _ = sshSession.Close() + _ = agentConn.Close() + return nil, xerrors.Errorf("request pty: %w", err) + } + err = sshSession.Shell() + if err != nil { + _ = sshSession.Close() + _ = agentConn.Close() + return nil, xerrors.Errorf("shell: %w", err) + } + + closeFn := func() error { + var merr error + if err := sshSession.Close(); err != nil { + merr = multierror.Append(merr, err) + } + if err := agentConn.Close(); err != nil { + merr = multierror.Append(merr, err) + } + return merr + } + wrappedConn.close = closeFn + + crw := &countReadWriteCloser{ctx: ctx, rwc: wrappedConn} + return crw, nil +} + +// wrappedSSHConn wraps an ssh.Session to implement io.ReadWriteCloser. +type wrappedSSHConn struct { + ctx context.Context + stdout io.Reader + stdin io.Writer + closeOnce sync.Once + closeErr error + close func() error +} + +func (w *wrappedSSHConn) Close() error { + w.closeOnce.Do(func() { + _, _ = w.stdin.Write([]byte("exit\n")) + w.closeErr = w.close() + }) + return w.closeErr +} + +func (w *wrappedSSHConn) Read(p []byte) (n int, err error) { + select { + case <-w.ctx.Done(): + return 0, xerrors.Errorf("read: %w", w.ctx.Err()) + default: + return w.stdout.Read(p) + } +} + +func (w *wrappedSSHConn) Write(p []byte) (n int, err error) { + select { + case <-w.ctx.Done(): + return 0, xerrors.Errorf("write: %w", w.ctx.Err()) + default: + return w.stdin.Write(p) + } +} diff --git a/scaletest/workspacetraffic/countreadwriter.go b/scaletest/workspacetraffic/countreadwriter.go new file mode 100644 index 0000000000..5cb15ab175 --- /dev/null +++ b/scaletest/workspacetraffic/countreadwriter.go @@ -0,0 +1,71 @@ +package workspacetraffic + +import ( + "context" + "errors" + "io" + "time" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket" +) + +// countReadWriteCloser wraps an io.ReadWriteCloser and counts the number of bytes read and written. +type countReadWriteCloser struct { + ctx context.Context + rwc io.ReadWriteCloser + readMetrics ConnMetrics + writeMetrics ConnMetrics +} + +func (w *countReadWriteCloser) Close() error { + return w.rwc.Close() +} + +func (w *countReadWriteCloser) Read(p []byte) (int, error) { + start := time.Now() + n, err := w.rwc.Read(p) + took := time.Since(start).Seconds() + if reportableErr(err) { + w.readMetrics.AddError(1) + } + w.readMetrics.ObserveLatency(took) + if n > 0 { + w.readMetrics.AddTotal(float64(n)) + } + return n, err +} + +func (w *countReadWriteCloser) Write(p []byte) (int, error) { + start := time.Now() + n, err := w.rwc.Write(p) + took := time.Since(start).Seconds() + if reportableErr(err) { + w.writeMetrics.AddError(1) + } + w.writeMetrics.ObserveLatency(took) + if n > 0 { + w.writeMetrics.AddTotal(float64(n)) + } + return n, err +} + +// some errors we want to report in metrics; others we want to ignore +// such as websocket.StatusNormalClosure or context.Canceled +func reportableErr(err error) bool { + if err == nil { + return false + } + if xerrors.Is(err, io.EOF) { + return false + } + if xerrors.Is(err, context.Canceled) { + return false + } + var wsErr websocket.CloseError + if errors.As(err, &wsErr) { + return wsErr.Code != websocket.StatusNormalClosure + } + return false +} diff --git a/scaletest/workspacetraffic/metrics.go b/scaletest/workspacetraffic/metrics.go index ce9fdc6caf..8b36f9b3df 100644 --- a/scaletest/workspacetraffic/metrics.go +++ b/scaletest/workspacetraffic/metrics.go @@ -54,3 +54,43 @@ func NewMetrics(reg prometheus.Registerer, labelNames ...string) *Metrics { reg.MustRegister(m.WriteLatencySeconds) return m } + +func (m *Metrics) ReadMetrics(lvs ...string) ConnMetrics { + return &connMetrics{ + addError: m.ReadErrorsTotal.WithLabelValues(lvs...).Add, + observeLatency: m.ReadLatencySeconds.WithLabelValues(lvs...).Observe, + addTotal: m.BytesReadTotal.WithLabelValues(lvs...).Add, + } +} + +func (m *Metrics) WriteMetrics(lvs ...string) ConnMetrics { + return &connMetrics{ + addError: m.WriteErrorsTotal.WithLabelValues(lvs...).Add, + observeLatency: m.WriteLatencySeconds.WithLabelValues(lvs...).Observe, + addTotal: m.BytesWrittenTotal.WithLabelValues(lvs...).Add, + } +} + +type ConnMetrics interface { + AddError(float64) + ObserveLatency(float64) + AddTotal(float64) +} + +type connMetrics struct { + addError func(float64) + observeLatency func(float64) + addTotal func(float64) +} + +func (c *connMetrics) AddError(f float64) { + c.addError(f) +} + +func (c *connMetrics) ObserveLatency(f float64) { + c.observeLatency(f) +} + +func (c *connMetrics) AddTotal(f float64) { + c.addTotal(f) +} diff --git a/scaletest/workspacetraffic/run.go b/scaletest/workspacetraffic/run.go index dff3a0b16d..db263ac909 100644 --- a/scaletest/workspacetraffic/run.go +++ b/scaletest/workspacetraffic/run.go @@ -3,7 +3,6 @@ package workspacetraffic import ( "context" "encoding/json" - "errors" "io" "time" @@ -22,9 +21,8 @@ import ( ) type Runner struct { - client *codersdk.Client - cfg Config - metrics *Metrics + client *codersdk.Client + cfg Config } var ( @@ -32,11 +30,11 @@ var ( _ harness.Cleanable = &Runner{} ) -func NewRunner(client *codersdk.Client, cfg Config, metrics *Metrics) *Runner { +// func NewRunner(client *codersdk.Client, cfg Config, metrics *Metrics) *Runner { +func NewRunner(client *codersdk.Client, cfg Config) *Runner { return &Runner{ - client: client, - cfg: cfg, - metrics: metrics, + client: client, + cfg: cfg, } } @@ -51,13 +49,12 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { // Initialize our metrics eagerly. This is mainly so that we can test for the // presence of a zero-valued metric as opposed to the absence of a metric. - lvs := []string{r.cfg.WorkspaceOwner, r.cfg.WorkspaceName, r.cfg.AgentName} - r.metrics.BytesReadTotal.WithLabelValues(lvs...).Add(0) - r.metrics.BytesWrittenTotal.WithLabelValues(lvs...).Add(0) - r.metrics.ReadErrorsTotal.WithLabelValues(lvs...).Add(0) - r.metrics.WriteErrorsTotal.WithLabelValues(lvs...).Add(0) - r.metrics.ReadLatencySeconds.WithLabelValues(lvs...).Observe(0) - r.metrics.WriteLatencySeconds.WithLabelValues(lvs...).Observe(0) + r.cfg.ReadMetrics.AddError(0) + r.cfg.ReadMetrics.AddTotal(0) + r.cfg.ReadMetrics.ObserveLatency(0) + r.cfg.WriteMetrics.AddError(0) + r.cfg.WriteMetrics.AddTotal(0) + r.cfg.WriteMetrics.ObserveLatency(0) var ( agentID = r.cfg.AgentID @@ -83,17 +80,25 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { defer cancel() logger.Debug(ctx, "connect to workspace agent", slog.F("agent_id", agentID)) - conn, err := r.client.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{ - AgentID: agentID, - Reconnect: reconnect, - Height: height, - Width: width, - Command: "/bin/sh", - }) - if err != nil { - logger.Error(ctx, "connect to workspace agent", slog.F("agent_id", agentID), slog.Error(err)) - return xerrors.Errorf("connect to workspace: %w", err) + var conn *countReadWriteCloser + var err error + if r.cfg.SSH { + logger.Info(ctx, "connecting to workspace agent", slog.F("agent_id", agentID), slog.F("method", "ssh")) + conn, err = connectSSH(ctx, r.client, agentID) + if err != nil { + logger.Error(ctx, "connect to workspace agent via ssh", slog.F("agent_id", agentID), slog.Error(err)) + return xerrors.Errorf("connect to workspace via ssh: %w", err) + } + } else { + logger.Info(ctx, "connecting to workspace agent", slog.F("agent_id", agentID), slog.F("method", "reconnectingpty")) + conn, err = connectPTY(ctx, r.client, agentID, reconnect) + if err != nil { + logger.Error(ctx, "connect to workspace agent via reconnectingpty", slog.F("agent_id", agentID), slog.Error(err)) + return xerrors.Errorf("connect to workspace via reconnectingpty: %w", err) + } } + conn.readMetrics = r.cfg.ReadMetrics + conn.writeMetrics = r.cfg.WriteMetrics go func() { <-deadlineCtx.Done() @@ -101,45 +106,63 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { _ = conn.Close() }() - // Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd. - crw := countReadWriter{ReadWriter: conn, metrics: r.metrics, labels: lvs} - - // Create a ticker for sending data to the PTY. + // Create a ticker for sending data to the conn. tick := time.NewTicker(tickInterval) defer tick.Stop() - // Now we begin writing random data to the pty. + // Now we begin writing random data to the conn. rch := make(chan error, 1) wch := make(chan error, 1) go func() { <-deadlineCtx.Done() logger.Debug(ctx, "closing agent connection") - conn.Close() + _ = conn.Close() }() // Read forever in the background. go func() { - logger.Debug(ctx, "reading from agent", slog.F("agent_id", agentID)) - rch <- drain(&crw) - logger.Debug(ctx, "done reading from agent", slog.F("agent_id", agentID)) - close(rch) + select { + case <-ctx.Done(): + logger.Debug(ctx, "done reading from agent", slog.F("agent_id", agentID)) + default: + logger.Debug(ctx, "reading from agent", slog.F("agent_id", agentID)) + rch <- drain(conn) + close(rch) + } }() - // Write random data to the PTY every tick. + // To avoid hanging, close the conn when ctx is done + go func() { + <-ctx.Done() + _ = conn.Close() + }() + + // Write random data to the conn every tick. go func() { logger.Debug(ctx, "writing to agent", slog.F("agent_id", agentID)) - wch <- writeRandomData(&crw, bytesPerTick, tick.C) + if r.cfg.SSH { + wch <- writeRandomDataSSH(conn, bytesPerTick, tick.C) + } else { + wch <- writeRandomDataPTY(conn, bytesPerTick, tick.C) + } logger.Debug(ctx, "done writing to agent", slog.F("agent_id", agentID)) close(wch) }() // Write until the context is canceled. if wErr := <-wch; wErr != nil { - return xerrors.Errorf("write to pty: %w", wErr) + return xerrors.Errorf("write to agent: %w", wErr) } - if rErr := <-rch; rErr != nil { - return xerrors.Errorf("read from pty: %w", rErr) + + select { + case <-ctx.Done(): + logger.Warn(ctx, "timed out reading from agent", slog.F("agent_id", agentID)) + case rErr := <-rch: + logger.Debug(ctx, "done reading from agent", slog.F("agent_id", agentID)) + if rErr != nil { + return xerrors.Errorf("read from agent: %w", rErr) + } } return nil @@ -153,6 +176,9 @@ func (*Runner) Cleanup(context.Context, string) error { // drain drains from src until it returns io.EOF or ctx times out. func drain(src io.Reader) error { if _, err := io.Copy(io.Discard, src); err != nil { + if xerrors.Is(err, context.Canceled) { + return nil + } if xerrors.Is(err, context.DeadlineExceeded) { return nil } @@ -164,15 +190,17 @@ func drain(src io.Reader) error { return nil } -func writeRandomData(dst io.Writer, size int64, tick <-chan time.Time) error { +func writeRandomDataPTY(dst io.Writer, size int64, tick <-chan time.Time) error { var ( enc = json.NewEncoder(dst) ptyReq = codersdk.ReconnectingPTYRequest{} ) for range tick { - payload := "#" + mustRandStr(size-1) - ptyReq.Data = payload + ptyReq.Data = mustRandomComment(size - 1) if err := enc.Encode(ptyReq); err != nil { + if xerrors.Is(err, context.Canceled) { + return nil + } if xerrors.Is(err, context.DeadlineExceeded) { return nil } @@ -185,40 +213,29 @@ func writeRandomData(dst io.Writer, size int64, tick <-chan time.Time) error { return nil } -// countReadWriter wraps an io.ReadWriter and counts the number of bytes read and written. -type countReadWriter struct { - io.ReadWriter - metrics *Metrics - labels []string +func writeRandomDataSSH(dst io.Writer, size int64, tick <-chan time.Time) error { + for range tick { + payload := mustRandomComment(size - 1) + if _, err := dst.Write([]byte(payload + "\r\n")); err != nil { + if xerrors.Is(err, context.Canceled) { + return nil + } + if xerrors.Is(err, context.DeadlineExceeded) { + return nil + } + if xerrors.As(err, &websocket.CloseError{}) { + return nil + } + return err + } + } + return nil } -func (w *countReadWriter) Read(p []byte) (int, error) { - start := time.Now() - n, err := w.ReadWriter.Read(p) - if reportableErr(err) { - w.metrics.ReadErrorsTotal.WithLabelValues(w.labels...).Inc() - } - w.metrics.ReadLatencySeconds.WithLabelValues(w.labels...).Observe(time.Since(start).Seconds()) - if n > 0 { - w.metrics.BytesReadTotal.WithLabelValues(w.labels...).Add(float64(n)) - } - return n, err -} - -func (w *countReadWriter) Write(p []byte) (int, error) { - start := time.Now() - n, err := w.ReadWriter.Write(p) - if reportableErr(err) { - w.metrics.WriteErrorsTotal.WithLabelValues(w.labels...).Inc() - } - w.metrics.WriteLatencySeconds.WithLabelValues(w.labels...).Observe(time.Since(start).Seconds()) - if n > 0 { - w.metrics.BytesWrittenTotal.WithLabelValues(w.labels...).Add(float64(n)) - } - return n, err -} - -func mustRandStr(l int64) string { +// mustRandomComment returns a random string prefixed by a #. +// This allows us to send data both to and from a workspace agent +// while placing minimal load upon the workspace itself. +func mustRandomComment(l int64) string { if l < 1 { l = 1 } @@ -226,21 +243,6 @@ func mustRandStr(l int64) string { if err != nil { panic(err) } - return randStr -} - -// some errors we want to report in metrics; others we want to ignore -// such as websocket.StatusNormalClosure or context.Canceled -func reportableErr(err error) bool { - if err == nil { - return false - } - if xerrors.Is(err, context.Canceled) { - return false - } - var wsErr websocket.CloseError - if errors.As(err, &wsErr) { - return wsErr.Code != websocket.StatusNormalClosure - } - return false + // THIS IS A LOAD-BEARING OCTOTHORPE. DO NOT REMOVE. + return "#" + randStr } diff --git a/scaletest/workspacetraffic/run_test.go b/scaletest/workspacetraffic/run_test.go index c070a906be..acea009cea 100644 --- a/scaletest/workspacetraffic/run_test.go +++ b/scaletest/workspacetraffic/run_test.go @@ -2,7 +2,9 @@ package workspacetraffic_test import ( "context" + "runtime" "strings" + "sync" "testing" "time" @@ -16,163 +18,287 @@ import ( "github.com/coder/coder/testutil" "github.com/google/uuid" - "github.com/prometheus/client_golang/prometheus" - dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRun(t *testing.T) { t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("Test not supported on windows.") + } + if testutil.RaceEnabled() { + t.Skip("Race detector enabled, skipping time-sensitive test.") + } - // We need to stand up an in-memory coderd and run a fake workspace. - var ( - client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) - firstUser = coderdtest.CreateFirstUser(t, client) - authToken = uuid.NewString() - agentName = "agent" - version = coderdtest.CreateTemplateVersion(t, client, firstUser.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionPlan: echo.ProvisionComplete, - ProvisionApply: []*proto.Provision_Response{{ - Type: &proto.Provision_Response_Complete{ - Complete: &proto.Provision_Complete{ - Resources: []*proto.Resource{{ - Name: "example", - Type: "aws_instance", - Agents: []*proto.Agent{{ - // Agent ID gets generated no matter what we say ¯\_(ツ)_/¯ - Name: agentName, - Auth: &proto.Agent_Token{ - Token: authToken, - }, - Apps: []*proto.App{}, + t.Run("PTY", func(t *testing.T) { + t.Parallel() + // We need to stand up an in-memory coderd and run a fake workspace. + var ( + client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + firstUser = coderdtest.CreateFirstUser(t, client) + authToken = uuid.NewString() + agentName = "agent" + version = coderdtest.CreateTemplateVersion(t, client, firstUser.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: echo.ProvisionComplete, + ProvisionApply: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + // Agent ID gets generated no matter what we say ¯\_(ツ)_/¯ + Name: agentName, + Auth: &proto.Agent_Token{ + Token: authToken, + }, + Apps: []*proto.App{}, + }}, }}, - }}, + }, }, - }, - }}, - }) - template = coderdtest.CreateTemplate(t, client, firstUser.OrganizationID, version.ID) - _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) - // In order to be picked up as a scaletest workspace, the workspace must be named specifically - ws = coderdtest.CreateWorkspace(t, client, firstUser.OrganizationID, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { - cwr.Name = "scaletest-test" - }) - _ = coderdtest.AwaitWorkspaceBuildJob(t, client, ws.LatestBuild.ID) - ) + }}, + }) + template = coderdtest.CreateTemplate(t, client, firstUser.OrganizationID, version.ID) + _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + // In order to be picked up as a scaletest workspace, the workspace must be named specifically + ws = coderdtest.CreateWorkspace(t, client, firstUser.OrganizationID, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.Name = "scaletest-test" + }) + _ = coderdtest.AwaitWorkspaceBuildJob(t, client, ws.LatestBuild.ID) + ) - // We also need a running agent to run this test. - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) - agentCloser := agent.New(agent.Options{ - Client: agentClient, + // We also need a running agent to run this test. + agentClient := agentsdk.New(client.URL) + agentClient.SetSessionToken(authToken) + agentCloser := agent.New(agent.Options{ + Client: agentClient, + }) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + t.Cleanup(func() { + _ = agentCloser.Close() + }) + + // Make sure the agent is connected before we go any further. + resources := coderdtest.AwaitWorkspaceAgents(t, client, ws.ID) + var agentID uuid.UUID + for _, res := range resources { + for _, agt := range res.Agents { + agentID = agt.ID + } + } + require.NotEqual(t, uuid.Nil, agentID, "did not expect agentID to be nil") + + // Now we can start the runner. + var ( + bytesPerTick = 1024 + tickInterval = 1000 * time.Millisecond + cancelAfter = 1500 * time.Millisecond + fudgeWrite = 12 // The ReconnectingPTY payload incurs some overhead + readMetrics = &testMetrics{} + writeMetrics = &testMetrics{} + ) + runner := workspacetraffic.NewRunner(client, workspacetraffic.Config{ + AgentID: agentID, + BytesPerTick: int64(bytesPerTick), + TickInterval: tickInterval, + Duration: testutil.WaitLong, + ReadMetrics: readMetrics, + WriteMetrics: writeMetrics, + SSH: false, + }) + + var logs strings.Builder + // Stop the test after one 'tick'. This will cause an EOF. + go func() { + <-time.After(cancelAfter) + cancel() + }() + require.NoError(t, runner.Run(ctx, "", &logs), "unexpected error calling Run()") + + t.Logf("read errors: %.0f\n", readMetrics.Errors()) + t.Logf("write errors: %.0f\n", writeMetrics.Errors()) + t.Logf("bytes read total: %.0f\n", readMetrics.Total()) + t.Logf("bytes written total: %.0f\n", writeMetrics.Total()) + + // We want to ensure the metrics are somewhat accurate. + assert.InDelta(t, bytesPerTick+fudgeWrite, writeMetrics.Total(), 0.1) + // Read is highly variable, depending on how far we read before stopping. + // Just ensure it's not zero. + assert.NotZero(t, readMetrics.Total()) + // Latency should report non-zero values. + assert.NotEmpty(t, readMetrics.Latencies()) + for _, l := range readMetrics.Latencies()[1:] { // skip the first one, which is always zero + assert.NotZero(t, l) + } + for _, l := range writeMetrics.Latencies()[1:] { // skip the first one, which is always zero + assert.NotZero(t, l) + } + assert.NotEmpty(t, writeMetrics.Latencies()) + // Should not report any errors! + assert.Zero(t, readMetrics.Errors()) + assert.Zero(t, writeMetrics.Errors()) }) - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - t.Cleanup(func() { - _ = agentCloser.Close() + t.Run("SSH", func(t *testing.T) { + t.Parallel() + // We need to stand up an in-memory coderd and run a fake workspace. + var ( + client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + firstUser = coderdtest.CreateFirstUser(t, client) + authToken = uuid.NewString() + agentName = "agent" + version = coderdtest.CreateTemplateVersion(t, client, firstUser.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: echo.ProvisionComplete, + ProvisionApply: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + // Agent ID gets generated no matter what we say ¯\_(ツ)_/¯ + Name: agentName, + Auth: &proto.Agent_Token{ + Token: authToken, + }, + Apps: []*proto.App{}, + }}, + }}, + }, + }, + }}, + }) + template = coderdtest.CreateTemplate(t, client, firstUser.OrganizationID, version.ID) + _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + // In order to be picked up as a scaletest workspace, the workspace must be named specifically + ws = coderdtest.CreateWorkspace(t, client, firstUser.OrganizationID, template.ID, func(cwr *codersdk.CreateWorkspaceRequest) { + cwr.Name = "scaletest-test" + }) + _ = coderdtest.AwaitWorkspaceBuildJob(t, client, ws.LatestBuild.ID) + ) + + // We also need a running agent to run this test. + agentClient := agentsdk.New(client.URL) + agentClient.SetSessionToken(authToken) + agentCloser := agent.New(agent.Options{ + Client: agentClient, + }) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(func() { + cancel() + _ = agentCloser.Close() + }) + + // Make sure the agent is connected before we go any further. + resources := coderdtest.AwaitWorkspaceAgents(t, client, ws.ID) + var agentID uuid.UUID + for _, res := range resources { + for _, agt := range res.Agents { + agentID = agt.ID + } + } + require.NotEqual(t, uuid.Nil, agentID, "did not expect agentID to be nil") + + // Now we can start the runner. + var ( + bytesPerTick = 1024 + tickInterval = 1000 * time.Millisecond + cancelAfter = 1500 * time.Millisecond + fudgeWrite = 2 // We send \r\n, which is two bytes + readMetrics = &testMetrics{} + writeMetrics = &testMetrics{} + ) + runner := workspacetraffic.NewRunner(client, workspacetraffic.Config{ + AgentID: agentID, + BytesPerTick: int64(bytesPerTick), + TickInterval: tickInterval, + Duration: testutil.WaitLong, + ReadMetrics: readMetrics, + WriteMetrics: writeMetrics, + SSH: true, + }) + + var logs strings.Builder + // Stop the test after one 'tick'. This will cause an EOF. + go func() { + <-time.After(cancelAfter) + cancel() + }() + require.NoError(t, runner.Run(ctx, "", &logs), "unexpected error calling Run()") + + t.Logf("read errors: %.0f\n", readMetrics.Errors()) + t.Logf("write errors: %.0f\n", writeMetrics.Errors()) + t.Logf("bytes read total: %.0f\n", readMetrics.Total()) + t.Logf("bytes written total: %.0f\n", writeMetrics.Total()) + + // We want to ensure the metrics are somewhat accurate. + assert.InDelta(t, bytesPerTick+fudgeWrite, writeMetrics.Total(), 0.1) + // Read is highly variable, depending on how far we read before stopping. + // Just ensure it's not zero. + assert.NotZero(t, readMetrics.Total()) + // Latency should report non-zero values. + assert.NotEmpty(t, readMetrics.Latencies()) + for _, l := range readMetrics.Latencies()[1:] { // skip the first one, which is always zero + assert.NotZero(t, l) + } + for _, l := range writeMetrics.Latencies()[1:] { // skip the first one, which is always zero + assert.NotZero(t, l) + } + assert.NotEmpty(t, writeMetrics.Latencies()) + // Should not report any errors! + assert.Zero(t, readMetrics.Errors()) + assert.Zero(t, writeMetrics.Errors()) }) - // We actually need to know the full user and not just the UserID / OrgID - user, err := client.User(ctx, firstUser.UserID.String()) - require.NoError(t, err, "get first user") - - // Make sure the agent is connected before we go any further. - resources := coderdtest.AwaitWorkspaceAgents(t, client, ws.ID) - var agentID uuid.UUID - for _, res := range resources { - for _, agt := range res.Agents { - agentID = agt.ID - } - } - require.NotEqual(t, uuid.Nil, agentID, "did not expect agentID to be nil") - - // Now we can start the runner. - var ( - bytesPerTick = 1024 - tickInterval = 1000 * time.Millisecond - cancelAfter = 1500 * time.Millisecond - fudgeWrite = 12 // The ReconnectingPTY payload incurs some overhead - ) - reg := prometheus.NewRegistry() - metrics := workspacetraffic.NewMetrics(reg, "username", "workspace_name", "agent_name") - runner := workspacetraffic.NewRunner(client, workspacetraffic.Config{ - AgentID: agentID, - AgentName: agentName, - WorkspaceName: ws.Name, - WorkspaceOwner: ws.OwnerName, - BytesPerTick: int64(bytesPerTick), - TickInterval: tickInterval, - Duration: testutil.WaitLong, - Registry: reg, - }, metrics) - - var logs strings.Builder - // Stop the test after one 'tick'. This will cause an EOF. - go func() { - <-time.After(cancelAfter) - cancel() - }() - require.NoError(t, runner.Run(ctx, "", &logs), "unexpected error calling Run()") - - // We want to ensure the metrics are somewhat accurate. - lvs := []string{user.Username, ws.Name, agentName} - assert.InDelta(t, bytesPerTick+fudgeWrite, toFloat64(t, metrics.BytesWrittenTotal.WithLabelValues(lvs...)), 0.1) - // Read is highly variable, depending on how far we read before stopping. - // Just ensure it's not zero. - assert.NotZero(t, bytesPerTick, toFloat64(t, metrics.BytesReadTotal.WithLabelValues(lvs...))) - // Latency should report non-zero values. - assert.NotZero(t, toFloat64(t, metrics.ReadLatencySeconds)) - assert.NotZero(t, toFloat64(t, metrics.WriteLatencySeconds)) - // Should not report any errors! - assert.Zero(t, toFloat64(t, metrics.ReadErrorsTotal.WithLabelValues(lvs...))) - assert.Zero(t, toFloat64(t, metrics.ReadErrorsTotal.WithLabelValues(lvs...))) } -// toFloat64 version of Prometheus' testutil.ToFloat64 that integrates with -// github.com/stretchr/testify/require and handles histograms (somewhat) -func toFloat64(t testing.TB, c prometheus.Collector) float64 { - var ( - m prometheus.Metric - mCount int - mChan = make(chan prometheus.Metric) - done = make(chan struct{}) - ) - - go func() { - for m = range mChan { - mCount++ - } - close(done) - }() - - c.Collect(mChan) - close(mChan) - <-done - - require.Equal(t, 1, mCount, "expected exactly 1 metric but got %d", mCount) - - pb := &dto.Metric{} - require.NoError(t, m.Write(pb), "unexpected error collecting metrics") - - if pb.Gauge != nil { - return pb.Gauge.GetValue() - } - if pb.Counter != nil { - return pb.Counter.GetValue() - } - if pb.Untyped != nil { - return pb.Untyped.GetValue() - } - if pb.Histogram != nil { - // If no samples, just return zero. - if pb.Histogram.GetSampleCount() == 0 { - return 0 - } - // Average is sufficient for testing purposes. - return pb.Histogram.GetSampleSum() / pb.Histogram.GetSampleCountFloat() - } - require.Fail(t, "collected a non-gauge/counter/untyped/histogram metric: %s", pb) - return 0 +type testMetrics struct { + sync.Mutex + errors float64 + latencies []float64 + total float64 +} + +var _ workspacetraffic.ConnMetrics = (*testMetrics)(nil) + +func (m *testMetrics) AddError(f float64) { + m.Lock() + defer m.Unlock() + m.errors += f +} + +func (m *testMetrics) ObserveLatency(f float64) { + m.Lock() + defer m.Unlock() + m.latencies = append(m.latencies, f) +} + +func (m *testMetrics) AddTotal(f float64) { + m.Lock() + defer m.Unlock() + m.total += f +} + +func (m *testMetrics) Total() float64 { + m.Lock() + defer m.Unlock() + return m.total +} + +func (m *testMetrics) Errors() float64 { + m.Lock() + defer m.Unlock() + return m.errors +} + +func (m *testMetrics) Latencies() []float64 { + m.Lock() + defer m.Unlock() + return m.latencies }