mirror of
https://github.com/coder/coder.git
synced 2026-06-03 21:18:24 +00:00
4c1a32cd7c
Wire DERPTLSConfig through the CLI, SDK, tailnet, VPN client, agent, and health checks to allow custom TLS configuration for DERP connections. The main use case is to be able to set a custom CA and also present client certs (mTLS). See https://github.com/coder/tailscale/pull/105 for related changes. Adds three new global CLI flags: - `--client-tls-ca-file` / `CODER_CLIENT_TLS_CA_FILE` - `--client-tls-cert-file` / `CODER_CLIENT_TLS_CERT_FILE` - `--client-tls-key-file` / `CODER_CLIENT_TLS_KEY_FILE` Based on community PR #22695 by @ibdafna, with autogeneration issues fixed (protobuf version mismatches in .pb.go files, golden file regeneration, lint fixes). > [!NOTE] > This PR was authored by Coder Agents on behalf of a Coder team member. <details> <summary>Relationship to #22695</summary> This is a clean reimplementation of the changes from #22695 on top of current `main`, with the following differences: - **Removed**: Accidental protobuf version changes in `.pb.go` files (contributor had `protoc v6.33.4` vs project's `protoc v4.23.4`) - **Added**: Properly regenerated golden files and docs via `make gen` - **Fixed**: Lint issue (`var-declaration` revive warning on explicit type in `createHTTPClient`) - All meaningful code changes are identical to the original PR </details>
476 lines
13 KiB
Go
476 lines
13 KiB
Go
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"runtime"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
|
|
"github.com/coder/coder/v2/cli/cliui"
|
|
"github.com/coder/coder/v2/cli/telemetry"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/pretty"
|
|
"github.com/coder/serpent"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
if runtime.GOOS == "windows" {
|
|
// Don't run goleak on windows tests, they're super flaky right now.
|
|
// See: https://github.com/coder/coder/issues/8954
|
|
os.Exit(m.Run())
|
|
}
|
|
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
|
|
}
|
|
|
|
func Test_formatExamples(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
examples []Example
|
|
wantMatches []string
|
|
}{
|
|
{
|
|
name: "No examples",
|
|
examples: nil,
|
|
wantMatches: nil,
|
|
},
|
|
{
|
|
name: "Output examples",
|
|
examples: []Example{
|
|
{
|
|
Description: "Hello world.",
|
|
Command: "echo hello",
|
|
},
|
|
{
|
|
Description: "Bye bye.",
|
|
Command: "echo bye",
|
|
},
|
|
},
|
|
wantMatches: []string{
|
|
"Hello world", "echo hello",
|
|
"Bye bye", "echo bye",
|
|
},
|
|
},
|
|
{
|
|
name: "No description outputs commands",
|
|
examples: []Example{
|
|
{
|
|
Command: "echo hello",
|
|
},
|
|
},
|
|
wantMatches: []string{
|
|
"echo hello",
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := FormatExamples(tt.examples...)
|
|
if len(tt.wantMatches) == 0 {
|
|
require.Empty(t, got)
|
|
} else {
|
|
for _, want := range tt.wantMatches {
|
|
require.Contains(t, got, want)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_wrapTransportWithVersionCheck(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("NoOutput", func(t *testing.T) {
|
|
t.Parallel()
|
|
r := &RootCmd{}
|
|
cmd, err := r.Command(nil)
|
|
require.NoError(t, err)
|
|
var buf bytes.Buffer
|
|
inv := cmd.Invoke()
|
|
inv.Stderr = &buf
|
|
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: http.Header{
|
|
// Provider a version that will not match!
|
|
codersdk.BuildVersionHeader: []string{"v2.0.0"},
|
|
},
|
|
Body: io.NopCloser(nil),
|
|
}, nil
|
|
}), inv, "v2.0.0", nil)
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
res, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer res.Body.Close()
|
|
require.Equal(t, "", buf.String())
|
|
})
|
|
|
|
t.Run("CustomUpgradeMessage", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := &RootCmd{}
|
|
|
|
cmd, err := r.Command(nil)
|
|
require.NoError(t, err)
|
|
|
|
var buf bytes.Buffer
|
|
inv := cmd.Invoke()
|
|
inv.Stderr = &buf
|
|
expectedUpgradeMessage := "My custom upgrade message"
|
|
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: http.Header{
|
|
// Provider a version that will not match!
|
|
codersdk.BuildVersionHeader: []string{"v1.0.0"},
|
|
},
|
|
Body: io.NopCloser(nil),
|
|
}, nil
|
|
}), inv, "v2.0.0", func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
|
return codersdk.BuildInfoResponse{
|
|
UpgradeMessage: expectedUpgradeMessage,
|
|
}, nil
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
res, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer res.Body.Close()
|
|
|
|
// Run this twice to ensure the upgrade message is only printed once.
|
|
res, err = rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer res.Body.Close()
|
|
|
|
fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", expectedUpgradeMessage)
|
|
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
|
require.Equal(t, expectedOutput, buf.String())
|
|
})
|
|
|
|
t.Run("ServerStableVersion", func(t *testing.T) {
|
|
t.Parallel()
|
|
r := &RootCmd{}
|
|
cmd, err := r.Command(nil)
|
|
require.NoError(t, err)
|
|
var buf bytes.Buffer
|
|
inv := cmd.Invoke()
|
|
inv.Stderr = &buf
|
|
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: http.Header{
|
|
codersdk.BuildVersionHeader: []string{"v2.31.0"},
|
|
},
|
|
Body: io.NopCloser(nil),
|
|
}, nil
|
|
}), inv, "v2.31.0", nil)
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
res, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer res.Body.Close()
|
|
require.Empty(t, buf.String())
|
|
})
|
|
}
|
|
|
|
func Test_serverVersionMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
name string
|
|
version string
|
|
expected string
|
|
}{
|
|
{"Stable", "v2.31.0", ""},
|
|
{"Dev", "v0.0.0-devel+abc123", "the server is running a development version of Coder (v0.0.0-devel+abc123)"},
|
|
{"RC", "v2.31.0-rc.1", "the server is running a release candidate of Coder (v2.31.0-rc.1)"},
|
|
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", "the server is running a release candidate of Coder (v2.33.0-rc.1-devel+727ec00f7)"},
|
|
{"Empty", "", ""},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, c.expected, serverVersionMessage(c.version))
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
rt := wrapTransportWithTelemetryHeader(roundTripper(func(req *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
Body: io.NopCloser(nil),
|
|
}, nil
|
|
}), &serpent.Invocation{
|
|
Command: &serpent.Command{
|
|
Use: "test",
|
|
Options: serpent.OptionSet{{
|
|
Name: "bananas",
|
|
Description: "hey",
|
|
}},
|
|
},
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
res, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer res.Body.Close()
|
|
resp := req.Header.Get(codersdk.CLITelemetryHeader)
|
|
require.NotEmpty(t, resp)
|
|
data, err := base64.StdEncoding.DecodeString(resp)
|
|
require.NoError(t, err)
|
|
var ti telemetry.Invocation
|
|
err = json.Unmarshal(data, &ti)
|
|
require.NoError(t, err)
|
|
require.Equal(t, ti.Command, "test")
|
|
}
|
|
|
|
//nolint:tparallel,paralleltest // This test modifies environment variables.
|
|
func TestPrintDeprecatedOptions(t *testing.T) {
|
|
newValue := serpent.StringOf(new(string))
|
|
|
|
// Both the "new" option and the deprecated option point at the
|
|
// same Value, mirroring how codersdk/deployment.go wires the
|
|
// CODER_EMAIL_* / CODER_NOTIFICATIONS_EMAIL_* pairs.
|
|
newOpt := serpent.Option{
|
|
Name: "new-option",
|
|
Flag: "new-option",
|
|
Env: "CODER_TEST_NEW_OPTION",
|
|
Value: newValue,
|
|
}
|
|
deprecatedOpt := serpent.Option{
|
|
Name: "old-option",
|
|
Flag: "old-option",
|
|
Env: "CODER_TEST_OLD_OPTION",
|
|
Value: newValue, // same pointer
|
|
UseInstead: serpent.OptionSet{newOpt},
|
|
}
|
|
|
|
makeCmd := func(opts serpent.OptionSet) *serpent.Command {
|
|
return &serpent.Command{
|
|
Use: "test",
|
|
Options: opts,
|
|
Middleware: PrintDeprecatedOptions(),
|
|
Handler: func(_ *serpent.Invocation) error {
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
t.Run("EnvOnlyNew_NoWarning", func(t *testing.T) {
|
|
t.Setenv("CODER_TEST_NEW_OPTION", "val")
|
|
|
|
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
|
|
var stderr bytes.Buffer
|
|
inv := cmd.Invoke()
|
|
inv.Environ = serpent.ParseEnviron(os.Environ(), "")
|
|
inv.Stderr = &stderr
|
|
err := inv.Run()
|
|
require.NoError(t, err)
|
|
require.Empty(t, stderr.String(),
|
|
"setting only the new env var should not produce a deprecation warning")
|
|
})
|
|
|
|
t.Run("EnvOnlyOld_Warning", func(t *testing.T) {
|
|
t.Setenv("CODER_TEST_OLD_OPTION", "val")
|
|
|
|
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
|
|
var stderr bytes.Buffer
|
|
inv := cmd.Invoke()
|
|
inv.Environ = serpent.ParseEnviron(os.Environ(), "")
|
|
inv.Stderr = &stderr
|
|
err := inv.Run()
|
|
require.NoError(t, err)
|
|
require.Contains(t, stderr.String(), "is deprecated",
|
|
"setting the deprecated env var should produce a warning")
|
|
})
|
|
|
|
t.Run("EnvBothSet_Warning", func(t *testing.T) {
|
|
t.Setenv("CODER_TEST_NEW_OPTION", "new")
|
|
t.Setenv("CODER_TEST_OLD_OPTION", "old")
|
|
|
|
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
|
|
var stderr bytes.Buffer
|
|
inv := cmd.Invoke()
|
|
inv.Environ = serpent.ParseEnviron(os.Environ(), "")
|
|
inv.Stderr = &stderr
|
|
err := inv.Run()
|
|
require.NoError(t, err)
|
|
require.Contains(t, stderr.String(), "is deprecated",
|
|
"setting both env vars should still warn about the deprecated one")
|
|
})
|
|
|
|
t.Run("DeprecatedEnvAndNewFlag_Warning", func(t *testing.T) {
|
|
t.Setenv("CODER_TEST_OLD_OPTION", "val")
|
|
|
|
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
|
|
var stderr bytes.Buffer
|
|
inv := cmd.Invoke("--new-option", "val")
|
|
inv.Environ = serpent.ParseEnviron(os.Environ(), "")
|
|
inv.Stderr = &stderr
|
|
err := inv.Run()
|
|
require.NoError(t, err)
|
|
require.Contains(t, stderr.String(), "`CODER_TEST_OLD_OPTION` is deprecated",
|
|
"setting the deprecated env var should still warn even if the replacement flag overrides the value")
|
|
require.NotContains(t, stderr.String(), "`--old-option` is deprecated",
|
|
"the deprecated environment variable should not be misreported as a deprecated flag")
|
|
})
|
|
|
|
t.Run("FlagOnlyNew_NoWarning", func(t *testing.T) {
|
|
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
|
|
var stderr bytes.Buffer
|
|
inv := cmd.Invoke("--new-option", "val")
|
|
inv.Stderr = &stderr
|
|
err := inv.Run()
|
|
require.NoError(t, err)
|
|
require.Empty(t, stderr.String(),
|
|
"passing only the new flag should not produce a deprecation warning")
|
|
})
|
|
|
|
t.Run("FlagOnlyOld_Warning", func(t *testing.T) {
|
|
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
|
|
var stderr bytes.Buffer
|
|
inv := cmd.Invoke("--old-option", "val")
|
|
inv.Stderr = &stderr
|
|
err := inv.Run()
|
|
require.NoError(t, err)
|
|
require.Contains(t, stderr.String(), "is deprecated",
|
|
"passing the deprecated flag should produce a warning")
|
|
})
|
|
|
|
t.Run("CODER_EMAIL_FROM_NoWarning", func(t *testing.T) {
|
|
t.Setenv("CODER_EMAIL_FROM", "noreply@example.com")
|
|
|
|
deploymentValues := new(codersdk.DeploymentValues)
|
|
cmd := makeCmd(deploymentValues.Options())
|
|
var stderr bytes.Buffer
|
|
inv := cmd.Invoke()
|
|
inv.Environ = serpent.ParseEnviron([]string{"CODER_EMAIL_FROM=noreply@example.com"}, "")
|
|
inv.Stderr = &stderr
|
|
err := inv.Run()
|
|
require.NoError(t, err)
|
|
require.NotContains(t, stderr.String(), "is deprecated",
|
|
"setting only CODER_EMAIL_FROM should not produce any deprecation warning")
|
|
})
|
|
|
|
t.Run("NothingSet_NoWarning", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cmd := makeCmd(serpent.OptionSet{newOpt, deprecatedOpt})
|
|
var stderr bytes.Buffer
|
|
inv := cmd.Invoke()
|
|
inv.Stderr = &stderr
|
|
err := inv.Run()
|
|
require.NoError(t, err)
|
|
require.Empty(t, stderr.String(),
|
|
"setting nothing should not produce a deprecation warning")
|
|
})
|
|
}
|
|
|
|
func Test_wrapTransportWithEntitlementsCheck(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
lines := []string{"First Warning", "Second Warning"}
|
|
var buf bytes.Buffer
|
|
rt := wrapTransportWithEntitlementsCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: http.Header{
|
|
codersdk.EntitlementsWarningHeader: lines,
|
|
},
|
|
Body: io.NopCloser(nil),
|
|
}, nil
|
|
}), &buf)
|
|
res, err := rt.RoundTrip(httptest.NewRequest(http.MethodGet, "http://example.com", nil))
|
|
require.NoError(t, err)
|
|
defer res.Body.Close()
|
|
expectedOutput := fmt.Sprintf("%s\n%s\n", pretty.Sprint(cliui.DefaultStyles.Warn, lines[0]),
|
|
pretty.Sprint(cliui.DefaultStyles.Warn, lines[1]))
|
|
require.Equal(t, expectedOutput, buf.String())
|
|
}
|
|
|
|
func Test_ensureTLSConfig(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("NoFilesSpecified", func(t *testing.T) {
|
|
t.Parallel()
|
|
r := &RootCmd{}
|
|
err := r.ensureTLSConfig()
|
|
require.NoError(t, err)
|
|
require.Nil(t, r.tlsConfig)
|
|
})
|
|
|
|
t.Run("OnlyCertFileErrors", func(t *testing.T) {
|
|
t.Parallel()
|
|
r := &RootCmd{
|
|
tlsClientCertFile: "/some/cert.pem",
|
|
}
|
|
err := r.ensureTLSConfig()
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "must be specified together")
|
|
})
|
|
|
|
t.Run("OnlyKeyFileErrors", func(t *testing.T) {
|
|
t.Parallel()
|
|
r := &RootCmd{
|
|
tlsClientKeyFile: "/some/key.pem",
|
|
}
|
|
err := r.ensureTLSConfig()
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "must be specified together")
|
|
})
|
|
|
|
t.Run("InvalidCAFileErrors", func(t *testing.T) {
|
|
t.Parallel()
|
|
r := &RootCmd{
|
|
tlsCAFile: "/nonexistent/ca.pem",
|
|
}
|
|
err := r.ensureTLSConfig()
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "read TLS CA file")
|
|
})
|
|
|
|
t.Run("AlreadySetSkipsLoading", func(t *testing.T) {
|
|
t.Parallel()
|
|
existingConfig := &tls.Config{MinVersion: tls.VersionTLS13}
|
|
r := &RootCmd{
|
|
tlsConfig: existingConfig,
|
|
tlsClientCertFile: "/some/cert.pem",
|
|
}
|
|
err := r.ensureTLSConfig()
|
|
require.NoError(t, err)
|
|
require.Same(t, existingConfig, r.tlsConfig)
|
|
})
|
|
|
|
t.Run("InvalidPEMContentErrors", func(t *testing.T) {
|
|
t.Parallel()
|
|
tmpFile, err := os.CreateTemp("", "invalid-ca-*.pem")
|
|
require.NoError(t, err)
|
|
defer os.Remove(tmpFile.Name())
|
|
_, err = tmpFile.WriteString("this is not valid PEM data")
|
|
require.NoError(t, err)
|
|
require.NoError(t, tmpFile.Close())
|
|
|
|
r := &RootCmd{
|
|
tlsCAFile: tmpFile.Name(),
|
|
}
|
|
err = r.ensureTLSConfig()
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "failed to parse CA certificate")
|
|
})
|
|
}
|