mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: Support --header for CLI commands to support proxies (#4008)
Fixes #3527.
This commit is contained in:
@@ -61,6 +61,18 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string
|
||||
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
|
||||
}
|
||||
|
||||
func StringArray(flagset *pflag.FlagSet, name, shorthand, env string, def []string, usage string) {
|
||||
v, ok := os.LookupEnv(env)
|
||||
if !ok || v == "" {
|
||||
if v == "" {
|
||||
def = []string{}
|
||||
} else {
|
||||
def = strings.Split(v, ",")
|
||||
}
|
||||
}
|
||||
flagset.StringArrayP(name, shorthand, def, fmtUsage(usage, env))
|
||||
}
|
||||
|
||||
func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) {
|
||||
val, ok := os.LookupEnv(env)
|
||||
if ok {
|
||||
|
||||
+4
-1
@@ -66,7 +66,10 @@ func login() *cobra.Command {
|
||||
serverURL.Scheme = "https"
|
||||
}
|
||||
|
||||
client := codersdk.New(serverURL)
|
||||
client, err := createUnauthenticatedClient(cmd, serverURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to check the version of the server prior to logging in.
|
||||
// It may be useful to warn the user if they are trying to login
|
||||
|
||||
+40
-1
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -41,6 +42,7 @@ const (
|
||||
varAgentToken = "agent-token"
|
||||
varAgentURL = "agent-url"
|
||||
varGlobalConfig = "global-config"
|
||||
varHeader = "header"
|
||||
varNoOpen = "no-open"
|
||||
varNoVersionCheck = "no-version-warning"
|
||||
varNoFeatureWarning = "no-feature-warning"
|
||||
@@ -174,6 +176,7 @@ func Root(subcommands []*cobra.Command) *cobra.Command {
|
||||
cliflag.String(cmd.PersistentFlags(), varAgentURL, "", "CODER_AGENT_URL", "", "Specify the URL for an agent to access your deployment.")
|
||||
_ = cmd.PersistentFlags().MarkHidden(varAgentURL)
|
||||
cliflag.String(cmd.PersistentFlags(), varGlobalConfig, "", "CODER_CONFIG_DIR", configdir.LocalConfig("coderv2"), "Specify the path to the global `coder` config directory.")
|
||||
cliflag.StringArray(cmd.PersistentFlags(), varHeader, "", "CODER_HEADER", []string{}, "HTTP headers added to all requests. Provide as \"Key=Value\"")
|
||||
cmd.PersistentFlags().Bool(varForceTty, false, "Force the `coder` command to run as if connected to a TTY.")
|
||||
_ = cmd.PersistentFlags().MarkHidden(varForceTty)
|
||||
cmd.PersistentFlags().Bool(varNoOpen, false, "Block automatically opening URLs in the browser.")
|
||||
@@ -237,8 +240,32 @@ func CreateClient(cmd *cobra.Command) (*codersdk.Client, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
client, err := createUnauthenticatedClient(cmd, serverURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.SessionToken = token
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func createUnauthenticatedClient(cmd *cobra.Command, serverURL *url.URL) (*codersdk.Client, error) {
|
||||
client := codersdk.New(serverURL)
|
||||
client.SessionToken = strings.TrimSpace(token)
|
||||
headers, err := cmd.Flags().GetStringArray(varHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transport := &headerTransport{
|
||||
transport: http.DefaultTransport,
|
||||
headers: map[string]string{},
|
||||
}
|
||||
for _, header := range headers {
|
||||
parts := strings.SplitN(header, "=", 2)
|
||||
if len(parts) < 2 {
|
||||
return nil, xerrors.Errorf("split header %q had less than two parts", header)
|
||||
}
|
||||
transport.headers[parts[0]] = parts[1]
|
||||
}
|
||||
client.HTTPClient.Transport = transport
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -530,3 +557,15 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type headerTransport struct {
|
||||
transport http.RoundTripper
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func (h *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for k, v := range h.headers {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
return h.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,12 @@ package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -129,4 +132,25 @@ func TestRoot(t *testing.T) {
|
||||
require.Contains(t, output, buildinfo.Version(), "has version")
|
||||
require.Contains(t, output, buildinfo.ExternalURL(), "has url")
|
||||
})
|
||||
|
||||
t.Run("Header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
done := make(chan struct{})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "wow", r.Header.Get("X-Testing"))
|
||||
w.WriteHeader(http.StatusGone)
|
||||
select {
|
||||
case <-done:
|
||||
close(done)
|
||||
default:
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
buf := new(bytes.Buffer)
|
||||
cmd, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL)
|
||||
cmd.SetOut(buf)
|
||||
// This won't succeed, because we're using the login cmd to assert requests.
|
||||
_ = cmd.Execute()
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user