mirror of
https://github.com/coder/coder.git
synced 2026-06-03 21:18:24 +00:00
a6b9a25f82
## Summary
Fixes cross-replica chat relay failing with:
```
failed to open initial relay for chat stream
error= dial relay stream: - failed to WebSocket dial: expected handshake response status code 101 but got 200
failed to open relay for message parts
error= dial relay stream: - failed to WebSocket dial: expected handshake response status code 101 but got 200
```
Subscribers see accurate `status=running` (delivered via pubsub) but
miss all in-progress `message_part` events (delivered only via the relay
WebSocket that never connects).
## Root cause
`redirectToAccessURL` in `cli/server.go` redirects any request whose
`Host` header doesn't match the access URL. The enterprise chat relay
dials another replica directly via its DERP relay address (e.g.
`http://10.0.0.2:8080`), so the `Host` header is the pod IP — not the
access URL.
This triggers a **307 redirect** to the access URL. The WebSocket
library follows the redirect, but the second request is a plain GET —
`Connection: Upgrade` and `Upgrade: websocket` headers are **not carried
over** by HTTP redirect semantics. The load-balanced access URL routes
the plain GET to any replica, which serves the SPA catch-all handler and
returns **HTTP 200 with `index.html`**.
The WebSocket library then fails: `expected handshake response status
code 101 but got 200`.
DERP mesh already has an exemption for this exact scenario
(`isDERPPath`). Chat relay was added later and didn't get one.
## Fix
Bypass `redirectToAccessURL` for requests that carry the
`X-Coder-Relay-Source-Replica` header, which the enterprise relay
already sets on every request (`enterprise/coderd/chatd/chatd.go:573`).
## Sequence diagram
**Before (broken):**
```
Replica A (subscriber)          Replica B (worker)               Load Balancer
           |                             |                             |
           |--- WS dial pod-ip:8080 ---->|                             |
           |                             |-- 307 redirect to LB ------>|
           |                             |                             |
           |<------------ plain GET (no Upgrade headers) ------------->|
           |                             |                             |-- routes to any replica
           |<----------- 200 index.html -------------------------------|
           |                                                           |
           X 'expected 101 but got 200'                                |
```
**After (fixed):**
```
Replica A (subscriber)          Replica B (worker)
           |                             |
           |--- WS dial pod-ip:8080 ---->|
           |    (X-Coder-Relay-Source-   |
           |     Replica header set)     |
           |                             |-- bypass redirect
           |<----------- 101 Upgrade ----|
           |<=== message_part events ====|
```
400 lines
10 KiB
Go
400 lines
10 KiB
Go
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"net/http"
|
|
"testing"
|
|
|
|
"github.com/spf13/pflag"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"cdr.dev/slog/v3/sloggers/sloghuman"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/serpent"
|
|
)
|
|
|
|
func Test_configureServerTLS(t *testing.T) {
|
|
t.Parallel()
|
|
t.Run("DefaultNoInsecureCiphers", func(t *testing.T) {
|
|
t.Parallel()
|
|
logger := testutil.Logger(t)
|
|
cfg, err := configureServerTLS(context.Background(), logger, "tls12", "none", nil, nil, "", nil, false)
|
|
require.NoError(t, err)
|
|
|
|
require.NotEmpty(t, cfg)
|
|
|
|
insecureCiphers := tls.InsecureCipherSuites()
|
|
for _, cipher := range cfg.CipherSuites {
|
|
for _, insecure := range insecureCiphers {
|
|
if cipher == insecure.ID {
|
|
t.Logf("Insecure cipher found by default: %s", insecure.Name)
|
|
t.Fail()
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func Test_configureCipherSuites(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cipherNames := func(ciphers []*tls.CipherSuite) []string {
|
|
var names []string
|
|
for _, c := range ciphers {
|
|
names = append(names, c.Name)
|
|
}
|
|
return names
|
|
}
|
|
|
|
cipherIDs := func(ciphers []*tls.CipherSuite) []uint16 {
|
|
var ids []uint16
|
|
for _, c := range ciphers {
|
|
ids = append(ids, c.ID)
|
|
}
|
|
return ids
|
|
}
|
|
|
|
cipherByName := func(cipher string) *tls.CipherSuite {
|
|
for _, c := range append(tls.CipherSuites(), tls.InsecureCipherSuites()...) {
|
|
if cipher == c.Name {
|
|
return c
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
wantErr string
|
|
wantWarnings []string
|
|
inputCiphers []string
|
|
minTLS uint16
|
|
maxTLS uint16
|
|
allowInsecure bool
|
|
expectCiphers []uint16
|
|
}{
|
|
{
|
|
name: "AllSecure",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: cipherNames(tls.CipherSuites()),
|
|
wantWarnings: []string{},
|
|
expectCiphers: cipherIDs(tls.CipherSuites()),
|
|
},
|
|
{
|
|
name: "AllowInsecure",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: append(cipherNames(tls.CipherSuites()), tls.InsecureCipherSuites()[0].Name),
|
|
allowInsecure: true,
|
|
wantWarnings: []string{
|
|
"insecure tls cipher specified",
|
|
},
|
|
expectCiphers: append(cipherIDs(tls.CipherSuites()), tls.InsecureCipherSuites()[0].ID),
|
|
},
|
|
{
|
|
name: "AllInsecure",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: append(cipherNames(tls.CipherSuites()), cipherNames(tls.InsecureCipherSuites())...),
|
|
allowInsecure: true,
|
|
wantWarnings: []string{
|
|
"insecure tls cipher specified",
|
|
},
|
|
expectCiphers: append(cipherIDs(tls.CipherSuites()), cipherIDs(tls.InsecureCipherSuites())...),
|
|
},
|
|
{
|
|
// Providing ciphers that are not compatible with any tls version
|
|
// enabled should generate a warning.
|
|
name: "ExcessiveCiphers",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS11,
|
|
inputCiphers: []string{
|
|
"TLS_RSA_WITH_AES_128_CBC_SHA",
|
|
// Only for TLS 1.3
|
|
"TLS_AES_128_GCM_SHA256",
|
|
},
|
|
allowInsecure: true,
|
|
wantWarnings: []string{
|
|
"cipher not supported for tls versions",
|
|
},
|
|
expectCiphers: cipherIDs([]*tls.CipherSuite{
|
|
cipherByName("TLS_RSA_WITH_AES_128_CBC_SHA"),
|
|
cipherByName("TLS_AES_128_GCM_SHA256"),
|
|
}),
|
|
},
|
|
// Errors
|
|
{
|
|
name: "NotRealCiphers",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: []string{"RSA-Fake"},
|
|
wantErr: "unsupported tls ciphers",
|
|
},
|
|
{
|
|
name: "NoCiphers",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
wantErr: "no tls ciphers supported",
|
|
},
|
|
{
|
|
name: "InsecureNotAllowed",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: append(cipherNames(tls.CipherSuites()), tls.InsecureCipherSuites()[0].Name),
|
|
wantErr: "insecure tls ciphers specified",
|
|
},
|
|
{
|
|
name: "TLS1.3",
|
|
minTLS: tls.VersionTLS13,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: cipherNames(tls.CipherSuites()),
|
|
wantErr: "'--tls-ciphers' cannot be specified when using minimum tls version 1.3",
|
|
},
|
|
{
|
|
name: "TLSUnsupported",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
// TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 only supports tls 1.2
|
|
inputCiphers: []string{"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"},
|
|
wantErr: "no tls ciphers supported for tls versions",
|
|
},
|
|
{
|
|
name: "Min>Max",
|
|
minTLS: tls.VersionTLS13,
|
|
maxTLS: tls.VersionTLS12,
|
|
wantErr: "minimum tls version (TLS 1.3) cannot be greater than maximum tls version (TLS 1.2)",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
var out bytes.Buffer
|
|
logger := slog.Make(sloghuman.Sink(&out))
|
|
|
|
found, err := configureCipherSuites(ctx, logger, tt.inputCiphers, tt.allowInsecure, tt.minTLS, tt.maxTLS)
|
|
if tt.wantErr != "" {
|
|
require.ErrorContains(t, err, tt.wantErr)
|
|
} else {
|
|
require.NoError(t, err, "no error")
|
|
require.ElementsMatch(t, tt.expectCiphers, found, "expected ciphers")
|
|
if len(tt.wantWarnings) > 0 {
|
|
logger.Sync()
|
|
for _, w := range tt.wantWarnings {
|
|
assert.Contains(t, out.String(), w, "expected warning")
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRedirectHTTPToHTTPSDeprecation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testcases := []struct {
|
|
name string
|
|
environ serpent.Environ
|
|
flags []string
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "AllUnset",
|
|
environ: serpent.Environ{},
|
|
flags: []string{},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "CODER_TLS_REDIRECT_HTTP=true",
|
|
environ: serpent.Environ{{Name: "CODER_TLS_REDIRECT_HTTP", Value: "true"}},
|
|
flags: []string{},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "CODER_TLS_REDIRECT_HTTP_TO_HTTPS=true",
|
|
environ: serpent.Environ{{Name: "CODER_TLS_REDIRECT_HTTP_TO_HTTPS", Value: "true"}},
|
|
flags: []string{},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "CODER_TLS_REDIRECT_HTTP=false",
|
|
environ: serpent.Environ{{Name: "CODER_TLS_REDIRECT_HTTP", Value: "false"}},
|
|
flags: []string{},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "CODER_TLS_REDIRECT_HTTP_TO_HTTPS=false",
|
|
environ: serpent.Environ{{Name: "CODER_TLS_REDIRECT_HTTP_TO_HTTPS", Value: "false"}},
|
|
flags: []string{},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "--tls-redirect-http-to-https",
|
|
environ: serpent.Environ{},
|
|
flags: []string{"--tls-redirect-http-to-https"},
|
|
expected: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testcases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
|
|
_ = flags.Bool("tls-redirect-http-to-https", true, "")
|
|
err := flags.Parse(tc.flags)
|
|
require.NoError(t, err)
|
|
inv := (&serpent.Invocation{Environ: tc.environ}).WithTestParsedFlags(t, flags)
|
|
cfg := &codersdk.DeploymentValues{}
|
|
opts := cfg.Options()
|
|
err = opts.SetDefaults()
|
|
require.NoError(t, err)
|
|
redirectHTTPToHTTPSDeprecation(ctx, logger, inv, cfg)
|
|
require.Equal(t, tc.expected, cfg.RedirectToAccessURL.Value())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsDERPPath(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testcases := []struct {
|
|
path string
|
|
expected bool
|
|
}{
|
|
//{
|
|
// path: "/derp",
|
|
// expected: true,
|
|
// },
|
|
{
|
|
path: "/derp/",
|
|
expected: true,
|
|
},
|
|
{
|
|
path: "/derp/latency-check",
|
|
expected: true,
|
|
},
|
|
{
|
|
path: "/derp/latency-check/",
|
|
expected: true,
|
|
},
|
|
{
|
|
path: "",
|
|
expected: false,
|
|
},
|
|
{
|
|
path: "/",
|
|
expected: false,
|
|
},
|
|
{
|
|
path: "/derptastic",
|
|
expected: false,
|
|
},
|
|
{
|
|
path: "/api/v2/derp",
|
|
expected: false,
|
|
},
|
|
{
|
|
path: "//",
|
|
expected: false,
|
|
},
|
|
}
|
|
for _, tc := range testcases {
|
|
t.Run(tc.path, func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, tc.expected, isDERPPath(tc.path))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsReplicaRelayRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("WithHeader", func(t *testing.T) {
|
|
t.Parallel()
|
|
r, _ := http.NewRequestWithContext(context.Background(), "GET", "/api/experimental/chats/abc/stream", nil)
|
|
r.Header.Set("X-Coder-Relay-Source-Replica", "some-uuid")
|
|
require.True(t, isReplicaRelayRequest(r))
|
|
})
|
|
|
|
t.Run("WithoutHeader", func(t *testing.T) {
|
|
t.Parallel()
|
|
r, _ := http.NewRequestWithContext(context.Background(), "GET", "/api/experimental/chats/abc/stream", nil)
|
|
require.False(t, isReplicaRelayRequest(r))
|
|
})
|
|
|
|
t.Run("EmptyHeader", func(t *testing.T) {
|
|
t.Parallel()
|
|
r, _ := http.NewRequestWithContext(context.Background(), "GET", "/api/experimental/chats/abc/stream", nil)
|
|
r.Header.Set("X-Coder-Relay-Source-Replica", "")
|
|
require.False(t, isReplicaRelayRequest(r))
|
|
})
|
|
}
|
|
|
|
func TestEscapePostgresURLUserInfo(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testcases := []struct {
|
|
input string
|
|
output string
|
|
err error
|
|
}{
|
|
{
|
|
input: "postgres://coder:coder@localhost:5432/coder",
|
|
output: "postgres://coder:coder@localhost:5432/coder",
|
|
err: nil,
|
|
},
|
|
{
|
|
input: "postgres://coder:co{der@localhost:5432/coder",
|
|
output: "postgres://coder:co%7Bder@localhost:5432/coder",
|
|
err: nil,
|
|
},
|
|
{
|
|
input: "postgres://coder:co:der@localhost:5432/coder",
|
|
output: "postgres://coder:co:der@localhost:5432/coder",
|
|
err: nil,
|
|
},
|
|
{
|
|
input: "postgres://coder:co der@localhost:5432/coder",
|
|
output: "postgres://coder:co%20der@localhost:5432/coder",
|
|
err: nil,
|
|
},
|
|
{
|
|
input: "postgres://local host:5432/coder",
|
|
output: "",
|
|
err: xerrors.New("parse postgres url: parse \"postgres://local host:5432/coder\": invalid character \" \" in host name"),
|
|
},
|
|
{
|
|
input: "postgres://coder:co?der@localhost:5432/coder",
|
|
output: "postgres://coder:co%3Fder@localhost:5432/coder",
|
|
err: nil,
|
|
},
|
|
{
|
|
input: "postgres://coder:co#der@localhost:5432/coder",
|
|
output: "postgres://coder:co%23der@localhost:5432/coder",
|
|
err: nil,
|
|
},
|
|
}
|
|
for _, tc := range testcases {
|
|
t.Run(tc.input, func(t *testing.T) {
|
|
t.Parallel()
|
|
o, err := escapePostgresURLUserInfo(tc.input)
|
|
assert.Equal(t, tc.output, o)
|
|
if tc.err != nil {
|
|
require.Error(t, err)
|
|
require.EqualValues(t, tc.err.Error(), err.Error())
|
|
} else {
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|