diff --git a/cli/root.go b/cli/root.go index 24228114a3..95689241cc 100644 --- a/cli/root.go +++ b/cli/root.go @@ -684,6 +684,7 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod func (r *RootCmd) createHTTPClient(ctx context.Context, serverURL *url.URL, inv *serpent.Invocation) (*http.Client, error) { transport := http.DefaultTransport transport = wrapTransportWithTelemetryHeader(transport, inv) + transport = wrapTransportWithUserAgentHeader(transport, inv) if !r.noVersionCheck { transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) { // Create a new client without any wrapped transport @@ -1497,6 +1498,22 @@ func wrapTransportWithTelemetryHeader(transport http.RoundTripper, inv *serpent. }) } +// wrapTransportWithUserAgentHeader sets a User-Agent header for all CLI requests +// that includes the CLI version, os/arch, and the specific command being run. +func wrapTransportWithUserAgentHeader(transport http.RoundTripper, inv *serpent.Invocation) http.RoundTripper { + var ( + userAgent string + once sync.Once + ) + return roundTripper(func(req *http.Request) (*http.Response, error) { + once.Do(func() { + userAgent = fmt.Sprintf("coder-cli/%s (%s/%s; %s)", buildinfo.Version(), runtime.GOOS, runtime.GOARCH, inv.Command.FullName()) + }) + req.Header.Set("User-Agent", userAgent) + return transport.RoundTrip(req) + }) +} + type roundTripper func(req *http.Request) (*http.Response, error) func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { diff --git a/cli/root_test.go b/cli/root_test.go index 4e4c9c2399..10642d6c99 100644 --- a/cli/root_test.go +++ b/cli/root_test.go @@ -380,3 +380,59 @@ func agentClientCommand(clientRef **agentsdk.Client) *serpent.Command { agentAuth.AttachOptions(cmd, false) return cmd } + +func TestWrapTransportWithUserAgentHeader(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + cmdArgs []string + cmdEnv map[string]string + expectedUserAgentHeader string + }{ + { + name: "top-level command", + cmdArgs: []string{"login"}, + expectedUserAgentHeader: fmt.Sprintf("coder-cli/%s (%s/%s; coder login)", buildinfo.Version(), runtime.GOOS, runtime.GOARCH), + }, + { + name: "nested commands", + cmdArgs: []string{"templates", "list"}, + expectedUserAgentHeader: fmt.Sprintf("coder-cli/%s (%s/%s; coder templates list)", buildinfo.Version(), runtime.GOOS, runtime.GOARCH), + }, + { + name: "does not include positional args, flags, or env", + cmdArgs: []string{"templates", "push", "my-template", "-d", "/path/to/template", "--yes", "--var", "myvar=myvalue"}, + cmdEnv: map[string]string{"SECRET_KEY": "secret_value"}, + expectedUserAgentHeader: fmt.Sprintf("coder-cli/%s (%s/%s; coder templates push)", buildinfo.Version(), runtime.GOOS, runtime.GOARCH), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ch := make(chan string, 1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case ch <- r.Header.Get("User-Agent"): + default: // already sent + } + })) + t.Cleanup(srv.Close) + + args := append([]string{}, tc.cmdArgs...) + inv, _ := clitest.New(t, args...) + inv.Environ.Set("CODER_URL", srv.URL) + for k, v := range tc.cmdEnv { + inv.Environ.Set(k, v) + } + + ctx := testutil.Context(t, testutil.WaitShort) + _ = inv.WithContext(ctx).Run() // Ignore error as we only care about headers. + + actual := testutil.RequireReceive(ctx, t, ch) + require.Equal(t, tc.expectedUserAgentHeader, actual, "User-Agent should match expected format exactly") + }) + } +}