Files
coder/cli/root_internal_test.go
Thomas Kosiewski 26a0805dcd fix(cli): isolate root HTTP transports (#25430)
The CLI root client shared `http.DefaultTransport` for normal API
requests and for the version-check build-info request. In parallel
tests, other clients can close idle connections on that process-global
transport, which can fail the Boundary license check before the AGPL 404
handling runs.

`TestBoundaryLicenseVerification/AGPLDeployment` configures a proxy that
returns `404` from `/api/v2/entitlements`, which `verifyLicense()` maps
to the expected AGPL deployment error. However, `clitest.SetupConfig()`
only writes the URL and session token to disk. It does not pass the
test's isolated `proxyClient.HTTPClient` into the CLI invocation, so
`coder boundary` builds a fresh client through `RootCmd.InitClient()`.
Before this change, that fresh client used `http.DefaultTransport`; if
another parallel test closed idle connections on the shared transport
while the entitlement request was in flight, Go returned `http:
CloseIdleConnections called` instead of the proxy's `404`. The command
then failed with `failed to get entitlements`, and the test never
reached the expected AGPL error path.

Clone the default transport for each CLI root HTTP client and for the
unwrapped build-info client, preserving the configured TLS settings when
present. Each CLI invocation now gets its own transport instance, so
cleanup from unrelated parallel tests cannot interrupt its entitlement
or build-info requests.

Closes https://github.com/coder/internal/issues/1538

<details>
<summary>Coder Agents notes</summary>

Generated by Coder Agents for Linear ENG-2705.

Local validation:

- `go test ./cli -run
'TestNewHTTPTransport|Test_ensureTLSConfig|Test_wrapTransportWithVersionCheck'
-count=1`
- `go test ./enterprise/cli -run
TestBoundaryLicenseVerification/AGPLDeployment -count=20 -parallel=16`
- `go test ./cli ./enterprise/cli`
- `make lint`
- `go test ./enterprise/cli -run '^TestBoundaryLicenseVerification$'
-count=50 -parallel=16`
- pre-commit hook during `git commit`

</details>
2026-05-21 16:51:34 +02:00

498 lines
14 KiB
Go

package cli
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"runtime"
"testing"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/cli/telemetry"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
func TestMain(m *testing.M) {
if runtime.GOOS == "windows" {
// Don't run goleak on windows tests, they're super flaky right now.
// See: https://github.com/coder/coder/issues/8954
os.Exit(m.Run())
}
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
}
func Test_formatExamples(t *testing.T) {
t.Parallel()
tests := []struct {
name string
examples []Example
wantMatches []string
}{
{
name: "No examples",
examples: nil,
wantMatches: nil,
},
{
name: "Output examples",
examples: []Example{
{
Description: "Hello world.",
Command: "echo hello",
},
{
Description: "Bye bye.",
Command: "echo bye",
},
},
wantMatches: []string{
"Hello world", "echo hello",
"Bye bye", "echo bye",
},
},
{
name: "No description outputs commands",
examples: []Example{
{
Command: "echo hello",
},
},
wantMatches: []string{
"echo hello",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := FormatExamples(tt.examples...)
if len(tt.wantMatches) == 0 {
require.Empty(t, got)
} else {
for _, want := range tt.wantMatches {
require.Contains(t, got, want)
}
}
})
}
}
func Test_wrapTransportWithVersionCheck(t *testing.T) {
t.Parallel()
t.Run("NoOutput", func(t *testing.T) {
t.Parallel()
r := &RootCmd{}
cmd, err := r.Command(nil)
require.NoError(t, err)
var buf bytes.Buffer
inv := cmd.Invoke()
inv.Stderr = &buf
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
// Provider a version that will not match!
codersdk.BuildVersionHeader: []string{"v2.0.0"},
},
Body: io.NopCloser(nil),
}, nil
}), inv, "v2.0.0", nil)
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, "", buf.String())
})
t.Run("CustomUpgradeMessage", func(t *testing.T) {
t.Parallel()
r := &RootCmd{}
cmd, err := r.Command(nil)
require.NoError(t, err)
var buf bytes.Buffer
inv := cmd.Invoke()
inv.Stderr = &buf
expectedUpgradeMessage := "My custom upgrade message"
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
// Provider a version that will not match!
codersdk.BuildVersionHeader: []string{"v1.0.0"},
},
Body: io.NopCloser(nil),
}, nil
}), inv, "v2.0.0", func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
return codersdk.BuildInfoResponse{
UpgradeMessage: expectedUpgradeMessage,
}, nil
})
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
defer res.Body.Close()
// Run this twice to ensure the upgrade message is only printed once.
res, err = rt.RoundTrip(req)
require.NoError(t, err)
defer res.Body.Close()
fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", expectedUpgradeMessage)
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
require.Equal(t, expectedOutput, buf.String())
})
t.Run("ServerStableVersion", func(t *testing.T) {
t.Parallel()
r := &RootCmd{}
cmd, err := r.Command(nil)
require.NoError(t, err)
var buf bytes.Buffer
inv := cmd.Invoke()
inv.Stderr = &buf
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
codersdk.BuildVersionHeader: []string{"v2.31.0"},
},
Body: io.NopCloser(nil),
}, nil
}), inv, "v2.31.0", nil)
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
defer res.Body.Close()
require.Empty(t, buf.String())
})
}
func Test_serverVersionMessage(t *testing.T) {
t.Parallel()
cases := []struct {
name string
version string
expected string
}{
{"Stable", "v2.31.0", ""},
{"Dev", "v0.0.0-devel+abc123", "the server is running a development version of Coder (v0.0.0-devel+abc123)"},
{"RC", "v2.31.0-rc.1", "the server is running a release candidate of Coder (v2.31.0-rc.1)"},
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", "the server is running a release candidate of Coder (v2.33.0-rc.1-devel+727ec00f7)"},
{"Empty", "", ""},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, c.expected, serverVersionMessage(c.version))
})
}
}
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
t.Parallel()
rt := wrapTransportWithTelemetryHeader(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(nil),
}, nil
}), &serpent.Invocation{
Command: &serpent.Command{
Use: "test",
Options: serpent.OptionSet{{
Name: "bananas",
Description: "hey",
}},
},
})
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
defer res.Body.Close()
resp := req.Header.Get(codersdk.CLITelemetryHeader)
require.NotEmpty(t, resp)
data, err := base64.StdEncoding.DecodeString(resp)
require.NoError(t, err)
var ti telemetry.Invocation
err = json.Unmarshal(data, &ti)
require.NoError(t, err)
require.Equal(t, ti.Command, "test")
}
//nolint:tparallel,paralleltest // This test modifies environment variables.
func TestPrintDeprecatedOptions(t *testing.T) {
newValue := serpent.StringOf(new(string))
// Both the "new" option and the deprecated option point at the
// same Value, mirroring how codersdk/deployment.go wires the
// CODER_EMAIL_* / CODER_NOTIFICATIONS_EMAIL_* pairs.
newOpt := serpent.Option{
Name: "new-option",
Flag: "new-option",
Env: "CODER_TEST_NEW_OPTION",
Value: newValue,
}
deprecatedOpt := serpent.Option{
Name: "old-option",
Flag: "old-option",
Env: "CODER_TEST_OLD_OPTION",
Value: newValue, // same pointer
UseInstead: serpent.OptionSet{newOpt},
}
makeCmd := func(opts serpent.OptionSet) *serpent.Command {
return &serpent.Command{
Use: "test",
Options: opts,
Middleware: PrintDeprecatedOptions(),
Handler: func(_ *serpent.Invocation) error {
return nil
},
}
}
t.Run("EnvOnlyNew_NoWarning", func(t *testing.T) {
t.Setenv("CODER_TEST_NEW_OPTION", "val")
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
var stderr bytes.Buffer
inv := cmd.Invoke()
inv.Environ = serpent.ParseEnviron(os.Environ(), "")
inv.Stderr = &stderr
err := inv.Run()
require.NoError(t, err)
require.Empty(t, stderr.String(),
"setting only the new env var should not produce a deprecation warning")
})
t.Run("EnvOnlyOld_Warning", func(t *testing.T) {
t.Setenv("CODER_TEST_OLD_OPTION", "val")
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
var stderr bytes.Buffer
inv := cmd.Invoke()
inv.Environ = serpent.ParseEnviron(os.Environ(), "")
inv.Stderr = &stderr
err := inv.Run()
require.NoError(t, err)
require.Contains(t, stderr.String(), "is deprecated",
"setting the deprecated env var should produce a warning")
})
t.Run("EnvBothSet_Warning", func(t *testing.T) {
t.Setenv("CODER_TEST_NEW_OPTION", "new")
t.Setenv("CODER_TEST_OLD_OPTION", "old")
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
var stderr bytes.Buffer
inv := cmd.Invoke()
inv.Environ = serpent.ParseEnviron(os.Environ(), "")
inv.Stderr = &stderr
err := inv.Run()
require.NoError(t, err)
require.Contains(t, stderr.String(), "is deprecated",
"setting both env vars should still warn about the deprecated one")
})
t.Run("DeprecatedEnvAndNewFlag_Warning", func(t *testing.T) {
t.Setenv("CODER_TEST_OLD_OPTION", "val")
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
var stderr bytes.Buffer
inv := cmd.Invoke("--new-option", "val")
inv.Environ = serpent.ParseEnviron(os.Environ(), "")
inv.Stderr = &stderr
err := inv.Run()
require.NoError(t, err)
require.Contains(t, stderr.String(), "`CODER_TEST_OLD_OPTION` is deprecated",
"setting the deprecated env var should still warn even if the replacement flag overrides the value")
require.NotContains(t, stderr.String(), "`--old-option` is deprecated",
"the deprecated environment variable should not be misreported as a deprecated flag")
})
t.Run("FlagOnlyNew_NoWarning", func(t *testing.T) {
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
var stderr bytes.Buffer
inv := cmd.Invoke("--new-option", "val")
inv.Stderr = &stderr
err := inv.Run()
require.NoError(t, err)
require.Empty(t, stderr.String(),
"passing only the new flag should not produce a deprecation warning")
})
t.Run("FlagOnlyOld_Warning", func(t *testing.T) {
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
var stderr bytes.Buffer
inv := cmd.Invoke("--old-option", "val")
inv.Stderr = &stderr
err := inv.Run()
require.NoError(t, err)
require.Contains(t, stderr.String(), "is deprecated",
"passing the deprecated flag should produce a warning")
})
t.Run("CODER_EMAIL_FROM_NoWarning", func(t *testing.T) {
t.Setenv("CODER_EMAIL_FROM", "noreply@example.com")
deploymentValues := new(codersdk.DeploymentValues)
cmd := makeCmd(deploymentValues.Options())
var stderr bytes.Buffer
inv := cmd.Invoke()
inv.Environ = serpent.ParseEnviron([]string{"CODER_EMAIL_FROM=noreply@example.com"}, "")
inv.Stderr = &stderr
err := inv.Run()
require.NoError(t, err)
require.NotContains(t, stderr.String(), "is deprecated",
"setting only CODER_EMAIL_FROM should not produce any deprecation warning")
})
t.Run("NothingSet_NoWarning", func(t *testing.T) {
t.Parallel()
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
var stderr bytes.Buffer
inv := cmd.Invoke()
inv.Stderr = &stderr
err := inv.Run()
require.NoError(t, err)
require.Empty(t, stderr.String(),
"setting nothing should not produce a deprecation warning")
})
}
func Test_wrapTransportWithEntitlementsCheck(t *testing.T) {
t.Parallel()
lines := []string{"First Warning", "Second Warning"}
var buf bytes.Buffer
rt := wrapTransportWithEntitlementsCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
codersdk.EntitlementsWarningHeader: lines,
},
Body: io.NopCloser(nil),
}, nil
}), &buf)
res, err := rt.RoundTrip(httptest.NewRequest(http.MethodGet, "http://example.com", nil))
require.NoError(t, err)
defer res.Body.Close()
expectedOutput := fmt.Sprintf("%s\n%s\n", pretty.Sprint(cliui.DefaultStyles.Warn, lines[0]),
pretty.Sprint(cliui.DefaultStyles.Warn, lines[1]))
require.Equal(t, expectedOutput, buf.String())
}
func Test_ensureTLSConfig(t *testing.T) {
t.Parallel()
t.Run("NoFilesSpecified", func(t *testing.T) {
t.Parallel()
r := &RootCmd{}
err := r.ensureTLSConfig()
require.NoError(t, err)
require.Nil(t, r.tlsConfig)
})
t.Run("OnlyCertFileErrors", func(t *testing.T) {
t.Parallel()
r := &RootCmd{
tlsClientCertFile: "/some/cert.pem",
}
err := r.ensureTLSConfig()
require.Error(t, err)
require.Contains(t, err.Error(), "must be specified together")
})
t.Run("OnlyKeyFileErrors", func(t *testing.T) {
t.Parallel()
r := &RootCmd{
tlsClientKeyFile: "/some/key.pem",
}
err := r.ensureTLSConfig()
require.Error(t, err)
require.Contains(t, err.Error(), "must be specified together")
})
t.Run("InvalidCAFileErrors", func(t *testing.T) {
t.Parallel()
r := &RootCmd{
tlsCAFile: "/nonexistent/ca.pem",
}
err := r.ensureTLSConfig()
require.Error(t, err)
require.Contains(t, err.Error(), "read TLS CA file")
})
t.Run("AlreadySetSkipsLoading", func(t *testing.T) {
t.Parallel()
existingConfig := &tls.Config{MinVersion: tls.VersionTLS13}
r := &RootCmd{
tlsConfig: existingConfig,
tlsClientCertFile: "/some/cert.pem",
}
err := r.ensureTLSConfig()
require.NoError(t, err)
require.Same(t, existingConfig, r.tlsConfig)
})
t.Run("InvalidPEMContentErrors", func(t *testing.T) {
t.Parallel()
tmpFile, err := os.CreateTemp("", "invalid-ca-*.pem")
require.NoError(t, err)
defer os.Remove(tmpFile.Name())
_, err = tmpFile.WriteString("this is not valid PEM data")
require.NoError(t, err)
require.NoError(t, tmpFile.Close())
r := &RootCmd{
tlsCAFile: tmpFile.Name(),
}
err = r.ensureTLSConfig()
require.Error(t, err)
require.Contains(t, err.Error(), "failed to parse CA certificate")
})
}
func TestNewHTTPTransportClonesDefaultTransport(t *testing.T) {
t.Parallel()
transport, err := newHTTPTransport(nil)
require.NoError(t, err)
require.NotSame(t, http.DefaultTransport, transport)
require.IsType(t, &http.Transport{}, transport)
}
func TestNewHTTPTransportAppliesTLSConfigToClone(t *testing.T) {
t.Parallel()
tlsConfig := &tls.Config{MinVersion: tls.VersionTLS13}
transport, err := newHTTPTransport(tlsConfig)
require.NoError(t, err)
require.NotSame(t, http.DefaultTransport, transport)
httpTransport, ok := transport.(*http.Transport)
require.True(t, ok)
require.Same(t, tlsConfig, httpTransport.TLSClientConfig)
}