mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
8b1705eb65
## Summary Routes chatd model calls backed by concrete AI Provider rows through the in-process aibridge transport by default, with deployment options to use direct provider routing when AI Gateway is disabled or chat AI Gateway routing is disabled. - Splits model routing into common, direct provider, and AI Gateway paths behind a single deployment-mode entry point. - Builds chatd models through explicit request, route, and options data. Active API key attribution is passed explicitly instead of being hidden inside generic model construction. - For AI Gateway BYOK routes, resolves the user's provider key in chatd, forwards it through provider-specific auth headers, and sets `X-Coder-AI-Governance-Token` to the `delegated` marker so aibridge preserves those headers while still stripping Coder-specific metadata. - Keeps central provider credentials and deployment fallback credentials out of forwarded provider auth headers, so AI Gateway central policy remains authoritative. - Redacts delegated provider auth from default string formatting to avoid accidental plaintext logging of user BYOK credentials. - Covers selected chat models, advisor overrides, title and quickgen paths, subagent overrides, computer use model selection, and an integration-style chat turn through the aibridge transport path. - Persists initiating API key IDs on chat and queued user messages, including subagent child messages, and fails closed for AI Gateway-routed model builds without an active key. - Removes unused `api_key_id` indexes while keeping the persistence columns and foreign keys. - Keeps the deployment option available through config and env parsing, but hides it from CLI help and generated docs. - Stabilizes the subagent poll fallback test so background CreateChat processing cannot win the state transition under slower CI environments. 
## Tests - `go test ./coderd/x/chatd -run 'TestAIGatewayProviderAuthForUser|TestAIGatewayProviderAuthRedactsFormatting|TestResolveModelRouteForConfigAIGatewayProviderAuth|TestAIGatewayModelForwardsProviderAuth|TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey|TestAwaitSubagentCompletion' -count=1` - `go test ./coderd/aibridged -run 'TestServeHTTP_DelegatedAPIKey|TestServeHTTP_StripCoderToken' -count=1` - `git diff --check HEAD~1..HEAD` - `make lint` > Mux working on behalf of Mike.
1326 lines
39 KiB
Go
1326 lines
39 KiB
Go
package codersdk_test
|
|
|
|
import (
|
|
"bytes"
|
|
"embed"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"gopkg.in/yaml.v3"
|
|
|
|
"github.com/coder/coder/v2/coderd/util/ptr"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/serpent"
|
|
)
|
|
|
|
// exclusion lists the configuration surfaces an option is exempt from.
// A true field means the option is expected NOT to define that name; the
// zero value exempts nothing.
type exclusion struct {
	flag, env, yaml bool
}
|
|
|
|
func TestDeploymentValues_HighlyConfigurable(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// This test ensures that every deployment option has
|
|
// a corresponding Flag, Env, and YAML name, unless explicitly excluded.
|
|
|
|
excludes := map[string]exclusion{
|
|
// These are used to configure YAML support itself, so
|
|
// they make no sense within the YAML file.
|
|
"Config Path": {
|
|
yaml: true,
|
|
},
|
|
"Write Config": {
|
|
yaml: true,
|
|
env: true,
|
|
},
|
|
// Dangerous values? Not sure we should help users
|
|
// persistent their configuration.
|
|
"DANGEROUS: Allow Path App Sharing": {
|
|
yaml: true,
|
|
},
|
|
"DANGEROUS: Allow Site Owners to Access Path Apps": {
|
|
yaml: true,
|
|
},
|
|
// Secrets
|
|
"Trace Honeycomb API Key": {
|
|
yaml: true,
|
|
},
|
|
"OAuth2 GitHub Client Secret": {
|
|
yaml: true,
|
|
},
|
|
"OIDC Client Secret": {
|
|
yaml: true,
|
|
},
|
|
"Postgres Connection URL": {
|
|
yaml: true,
|
|
},
|
|
"SCIM API Key": {
|
|
yaml: true,
|
|
},
|
|
"External Token Encryption Keys": {
|
|
yaml: true,
|
|
},
|
|
"External Auth Providers": {
|
|
// Technically External Auth Providers can be provided through the env,
|
|
// but bypassing serpent. See cli.ReadExternalAuthProvidersFromEnv.
|
|
flag: true,
|
|
env: true,
|
|
},
|
|
"Provisioner Daemon Pre-shared Key (PSK)": {
|
|
yaml: true,
|
|
},
|
|
"Email Auth: Password": {
|
|
yaml: true,
|
|
},
|
|
"Notifications: Email Auth: Password": {
|
|
yaml: true,
|
|
},
|
|
// We don't want these to be configurable via YAML because they are secrets.
|
|
// However, we do want to allow them to be shown in documentation.
|
|
"AI Gateway OpenAI Key": {
|
|
yaml: true,
|
|
},
|
|
"AI Gateway Anthropic Key": {
|
|
yaml: true,
|
|
},
|
|
"AI Gateway Bedrock Access Key": {
|
|
yaml: true,
|
|
},
|
|
"AI Gateway Bedrock Access Key Secret": {
|
|
yaml: true,
|
|
},
|
|
}
|
|
|
|
set := (&codersdk.DeploymentValues{}).Options()
|
|
for _, opt := range set {
|
|
// These are generally for development, so their configurability is
|
|
// not relevant.
|
|
if opt.Hidden {
|
|
delete(excludes, opt.Name)
|
|
continue
|
|
}
|
|
|
|
if codersdk.IsSecretDeploymentOption(opt) && opt.YAML != "" {
|
|
// Secrets should not be written to YAML and instead should continue
|
|
// to be read from the environment.
|
|
//
|
|
// Unfortunately, secrets are still accepted through flags for
|
|
// legacy purposes. Eventually, we should prevent that.
|
|
t.Errorf("Option %q is a secret but has a YAML name", opt.Name)
|
|
}
|
|
|
|
excluded := excludes[opt.Name]
|
|
switch {
|
|
case opt.YAML == "" && !excluded.yaml:
|
|
t.Errorf("Option %q should have a YAML name", opt.Name)
|
|
case opt.YAML != "" && excluded.yaml:
|
|
t.Errorf("Option %q is excluded but has a YAML name", opt.Name)
|
|
case opt.Flag == "" && !excluded.flag:
|
|
t.Errorf("Option %q should have a flag name", opt.Name)
|
|
case opt.Flag != "" && excluded.flag:
|
|
t.Errorf("Option %q is excluded but has a flag name", opt.Name)
|
|
case opt.Env == "" && !excluded.env:
|
|
t.Errorf("Option %q should have an env name", opt.Name)
|
|
case opt.Env != "" && excluded.env:
|
|
t.Errorf("Option %q is excluded but has an env name", opt.Name)
|
|
}
|
|
|
|
// Also check all env vars are prefixed with CODER_
|
|
const prefix = "CODER_"
|
|
if opt.Env != "" && !strings.HasPrefix(opt.Env, prefix) {
|
|
t.Errorf("Option %q has an env name (%q) that is not prefixed with %s", opt.Name, opt.Env, prefix)
|
|
}
|
|
|
|
delete(excludes, opt.Name)
|
|
}
|
|
|
|
for opt := range excludes {
|
|
t.Errorf("Excluded option %q is not in the deployment config. Remove it?", opt)
|
|
}
|
|
}
|
|
|
|
func TestSSHConfig_ParseOptions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
Name string
|
|
ConfigOptions serpent.StringArray
|
|
ExpectError bool
|
|
Expect map[string]string
|
|
}{
|
|
{
|
|
Name: "Empty",
|
|
ConfigOptions: []string{},
|
|
Expect: map[string]string{},
|
|
},
|
|
{
|
|
Name: "Whitespace",
|
|
ConfigOptions: []string{
|
|
"test value",
|
|
},
|
|
Expect: map[string]string{
|
|
"test": "value",
|
|
},
|
|
},
|
|
{
|
|
Name: "SimpleValueEqual",
|
|
ConfigOptions: []string{
|
|
"test=value",
|
|
},
|
|
Expect: map[string]string{
|
|
"test": "value",
|
|
},
|
|
},
|
|
{
|
|
Name: "SimpleValues",
|
|
ConfigOptions: []string{
|
|
"test=value",
|
|
"foo=bar",
|
|
},
|
|
Expect: map[string]string{
|
|
"test": "value",
|
|
"foo": "bar",
|
|
},
|
|
},
|
|
{
|
|
Name: "ValueWithQuote",
|
|
ConfigOptions: []string{
|
|
"bar=buzz=bazz",
|
|
},
|
|
Expect: map[string]string{
|
|
"bar": "buzz=bazz",
|
|
},
|
|
},
|
|
{
|
|
Name: "NoEquals",
|
|
ConfigOptions: []string{
|
|
"foobar",
|
|
},
|
|
ExpectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range testCases {
|
|
t.Run(tt.Name, func(t *testing.T) {
|
|
t.Parallel()
|
|
c := codersdk.SSHConfig{
|
|
SSHConfigOptions: tt.ConfigOptions,
|
|
}
|
|
got, err := c.ParseOptions()
|
|
if tt.ExpectError {
|
|
require.Error(t, err, tt.ConfigOptions.String())
|
|
} else {
|
|
require.NoError(t, err, tt.ConfigOptions.String())
|
|
require.Equalf(t, tt.Expect, got, tt.ConfigOptions.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTimezoneOffsets(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
Name string
|
|
Now time.Time
|
|
Loc *time.Location
|
|
ExpectedOffset int
|
|
}{
|
|
{
|
|
Name: "UTC",
|
|
Loc: time.UTC,
|
|
ExpectedOffset: 0,
|
|
},
|
|
|
|
{
|
|
Name: "Eastern",
|
|
Now: time.Date(2021, 2, 1, 0, 0, 0, 0, time.UTC),
|
|
Loc: must(time.LoadLocation("America/New_York")),
|
|
ExpectedOffset: 5,
|
|
},
|
|
{
|
|
// Daylight savings is on the 14th of March to Nov 7 in 2021
|
|
Name: "EasternDaylightSavings",
|
|
Now: time.Date(2021, 3, 16, 0, 0, 0, 0, time.UTC),
|
|
Loc: must(time.LoadLocation("America/New_York")),
|
|
ExpectedOffset: 4,
|
|
},
|
|
{
|
|
Name: "Central",
|
|
Now: time.Date(2021, 2, 1, 0, 0, 0, 0, time.UTC),
|
|
Loc: must(time.LoadLocation("America/Chicago")),
|
|
ExpectedOffset: 6,
|
|
},
|
|
{
|
|
Name: "CentralDaylightSavings",
|
|
Now: time.Date(2021, 3, 16, 0, 0, 0, 0, time.UTC),
|
|
Loc: must(time.LoadLocation("America/Chicago")),
|
|
ExpectedOffset: 5,
|
|
},
|
|
{
|
|
Name: "Ireland",
|
|
Now: time.Date(2021, 2, 1, 0, 0, 0, 0, time.UTC),
|
|
Loc: must(time.LoadLocation("Europe/Dublin")),
|
|
ExpectedOffset: 0,
|
|
},
|
|
{
|
|
Name: "IrelandDaylightSavings",
|
|
Now: time.Date(2021, 4, 3, 0, 0, 0, 0, time.UTC),
|
|
Loc: must(time.LoadLocation("Europe/Dublin")),
|
|
ExpectedOffset: -1,
|
|
},
|
|
{
|
|
Name: "HalfHourTz",
|
|
Now: time.Date(2024, 1, 20, 6, 0, 0, 0, must(time.LoadLocation("Asia/Yangon"))),
|
|
// This timezone is +6:30, but the function rounds to the nearest hour.
|
|
// This is intentional because our DAUs endpoint only covers 1-hour offsets.
|
|
// If the user is in a non-hour timezone, they get the closest hour bucket.
|
|
Loc: must(time.LoadLocation("Asia/Yangon")),
|
|
ExpectedOffset: -6,
|
|
},
|
|
}
|
|
|
|
for _, c := range testCases {
|
|
t.Run(c.Name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
offset := codersdk.TimezoneOffsetHourWithTime(c.Now, c.Loc)
|
|
require.Equal(t, c.ExpectedOffset, offset)
|
|
})
|
|
}
|
|
}
|
|
|
|
// must returns value unchanged when err is nil and panics with err otherwise.
// It lets a fallible (value, error) call be used inline, e.g. in a test table.
func must[T any](value T, err error) T {
	if err == nil {
		return value
	}
	panic(err)
}
|
|
|
|
func TestAIGatewayCompatibilityAliases(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
options := (&codersdk.DeploymentValues{}).Options()
|
|
byFlag := map[string]serpent.Option{}
|
|
for _, opt := range options {
|
|
if opt.Flag != "" {
|
|
byFlag[opt.Flag] = opt
|
|
}
|
|
}
|
|
|
|
type alias struct {
|
|
old serpent.Option
|
|
new serpent.Option
|
|
}
|
|
var aliases []alias
|
|
for _, opt := range options {
|
|
if !strings.HasPrefix(opt.Flag, "aibridge-") {
|
|
continue
|
|
}
|
|
require.True(t, strings.HasPrefix(opt.Description, "Deprecated:"), "aibridge option %s should have a 'Deprecated:' description", opt.Flag)
|
|
require.Len(t, opt.UseInstead, 1, "aibridge option %s should point to a single replacement", opt.Flag)
|
|
|
|
newOpt, ok := byFlag[opt.UseInstead[0].Flag]
|
|
require.True(t, ok, "aibridge option %s points to unknown flag %s", opt.Flag, opt.UseInstead[0].Flag)
|
|
require.NotEqual(t, opt.Flag, newOpt.Flag, "flag %s shares its flag with the new alias option", opt.Flag)
|
|
require.NotEqual(t, opt.Env, newOpt.Env, "flag %s shares its env with the new alias option", opt.Flag)
|
|
if oldYAML := opt.YAMLPath(); oldYAML != "" {
|
|
require.NotEqual(t, oldYAML, newOpt.YAMLPath(), "flag %s shares its YAML path with the new alias option", opt.Flag)
|
|
} else {
|
|
require.Empty(t, newOpt.YAMLPath(), "flag %s has no YAML path but the new alias option %s does", opt.Flag, newOpt.Flag)
|
|
}
|
|
aliases = append(aliases, alias{old: opt, new: newOpt})
|
|
}
|
|
// Update this count when adding or removing aibridge alias options.
|
|
require.Len(t, aliases, 34, "unexpected number of aibridge alias options")
|
|
|
|
sampleVal := func(opt serpent.Option) any {
|
|
switch opt.Value.Type() {
|
|
case "bool":
|
|
return opt.Default != "true"
|
|
case "int":
|
|
return 7
|
|
case "duration":
|
|
return "2h"
|
|
case "string-array":
|
|
return []string{"10.0.0.0/8", "172.16.0.0/12"}
|
|
default:
|
|
return "alias-value"
|
|
}
|
|
}
|
|
sampleArg := func(opt serpent.Option) string {
|
|
v := sampleVal(opt)
|
|
if arr, ok := v.([]string); ok {
|
|
return strings.Join(arr, ",")
|
|
}
|
|
return fmt.Sprint(v)
|
|
}
|
|
|
|
aiConfFromOpts := func(t *testing.T, apply func(opts serpent.OptionSet) error) codersdk.AIConfig {
|
|
t.Helper()
|
|
dv := &codersdk.DeploymentValues{}
|
|
opts := dv.Options()
|
|
require.NoError(t, opts.SetDefaults())
|
|
require.NoError(t, apply(opts))
|
|
return dv.AI
|
|
}
|
|
|
|
t.Run("FlagParity", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var oldArgs, newArgs []string
|
|
for _, a := range aliases {
|
|
value := sampleArg(a.old)
|
|
oldArgs = append(oldArgs, "--"+a.old.Flag, value)
|
|
newArgs = append(newArgs, "--"+a.new.Flag, value)
|
|
}
|
|
oldAI := aiConfFromOpts(t, func(opts serpent.OptionSet) error {
|
|
return opts.FlagSet().Parse(oldArgs)
|
|
})
|
|
newAI := aiConfFromOpts(t, func(opts serpent.OptionSet) error {
|
|
return opts.FlagSet().Parse(newArgs)
|
|
})
|
|
require.Equal(t, newAI, oldAI)
|
|
})
|
|
|
|
t.Run("EnvParity", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var oldEnv, newEnv []serpent.EnvVar
|
|
for _, a := range aliases {
|
|
value := sampleArg(a.old)
|
|
oldEnv = append(oldEnv, serpent.EnvVar{Name: a.old.Env, Value: value})
|
|
newEnv = append(newEnv, serpent.EnvVar{Name: a.new.Env, Value: value})
|
|
}
|
|
oldAI := aiConfFromOpts(t, func(opts serpent.OptionSet) error {
|
|
return opts.ParseEnv(oldEnv)
|
|
})
|
|
newAI := aiConfFromOpts(t, func(opts serpent.OptionSet) error {
|
|
return opts.ParseEnv(newEnv)
|
|
})
|
|
require.Equal(t, newAI, oldAI)
|
|
})
|
|
|
|
t.Run("YAMLParity", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
setPath := func(doc map[string]any, path string, value any) {
|
|
parts := strings.Split(path, ".")
|
|
for _, field := range parts[:len(parts)-1] {
|
|
next, ok := doc[field].(map[string]any)
|
|
if !ok {
|
|
next = map[string]any{}
|
|
doc[field] = next
|
|
}
|
|
doc = next
|
|
}
|
|
doc[parts[len(parts)-1]] = value
|
|
}
|
|
|
|
oldYAML := map[string]any{}
|
|
newYAML := map[string]any{}
|
|
for _, a := range aliases {
|
|
oldPath := a.old.YAMLPath()
|
|
newPath := a.new.YAMLPath()
|
|
if oldPath == "" {
|
|
require.Empty(t, newPath)
|
|
continue
|
|
}
|
|
require.NotEmpty(t, newPath, "new flag %s has no YAML path", a.old.Flag)
|
|
|
|
value := sampleVal(a.old)
|
|
setPath(oldYAML, oldPath, value)
|
|
setPath(newYAML, newPath, value)
|
|
}
|
|
|
|
parse := func(doc map[string]any) codersdk.AIConfig {
|
|
var node yaml.Node
|
|
require.NoError(t, node.Encode(doc))
|
|
return aiConfFromOpts(t, func(opts serpent.OptionSet) error {
|
|
return opts.UnmarshalYAML(&node)
|
|
})
|
|
}
|
|
|
|
require.Equal(t, parse(newYAML), parse(oldYAML))
|
|
})
|
|
}
|
|
|
|
func TestDeploymentValues_Validate_RefreshLifetime(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mk := func(access, refresh time.Duration) *codersdk.DeploymentValues {
|
|
dv := &codersdk.DeploymentValues{}
|
|
dv.Sessions.DefaultDuration = serpent.Duration(access)
|
|
dv.Sessions.RefreshDefaultDuration = serpent.Duration(refresh)
|
|
return dv
|
|
}
|
|
|
|
t.Run("EqualDurations_Error", func(t *testing.T) {
|
|
t.Parallel()
|
|
dv := mk(1*time.Hour, 1*time.Hour)
|
|
err := dv.Validate()
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "must be strictly greater")
|
|
})
|
|
|
|
t.Run("RefreshShorter_Error", func(t *testing.T) {
|
|
t.Parallel()
|
|
dv := mk(2*time.Hour, 1*time.Hour)
|
|
err := dv.Validate()
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "must be strictly greater")
|
|
})
|
|
|
|
t.Run("RefreshZero_Error", func(t *testing.T) {
|
|
t.Parallel()
|
|
dv := mk(1*time.Hour, 0)
|
|
err := dv.Validate()
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "must be strictly greater")
|
|
})
|
|
|
|
t.Run("AccessUninitialized_Error", func(t *testing.T) {
|
|
t.Parallel()
|
|
// Access duration is zero (uninitialized); refresh is valid.
|
|
dv := mk(0, 48*time.Hour)
|
|
err := dv.Validate()
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "developer error: sessions configuration appears uninitialized")
|
|
})
|
|
|
|
t.Run("RefreshLonger_OK", func(t *testing.T) {
|
|
t.Parallel()
|
|
dv := mk(1*time.Hour, 48*time.Hour)
|
|
err := dv.Validate()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestDeploymentValues_DurationFormatNanoseconds(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
set := (&codersdk.DeploymentValues{}).Options()
|
|
for _, s := range set {
|
|
if s.Value.Type() != "duration" {
|
|
continue
|
|
}
|
|
// Just make sure the annotation is set.
|
|
// If someone wants to not format a duration, they can
|
|
// explicitly set the annotation to false.
|
|
if s.Annotations.IsSet("format_duration") {
|
|
continue
|
|
}
|
|
t.Logf("Option %q is a duration but does not have the format_duration annotation.", s.Name)
|
|
t.Log("To fix this, add the following to the option declaration:")
|
|
t.Log(`Annotations: serpent.Annotations{}.Mark(annotationFormatDurationNS, "true"),`)
|
|
t.FailNow()
|
|
}
|
|
}
|
|
|
|
// testData is an embedded file system holding everything matched by the
// testdata/* pattern; tests read their fixture files from it.
//go:embed testdata/*
|
|
var testData embed.FS
|
|
|
|
func TestExternalAuthYAMLConfig(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if runtime.GOOS == "windows" {
|
|
// The windows marshal function uses different line endings.
|
|
// Not worth the effort getting this to work on windows.
|
|
t.SkipNow()
|
|
}
|
|
|
|
file := func(t *testing.T, name string) string {
|
|
data, err := testData.ReadFile(fmt.Sprintf("testdata/%s", name))
|
|
require.NoError(t, err, "read testdata file %q", name)
|
|
return string(data)
|
|
}
|
|
githubCfg := codersdk.ExternalAuthConfig{
|
|
Type: "github",
|
|
ClientID: "client_id",
|
|
ClientSecret: "client_secret",
|
|
ID: "id",
|
|
AuthURL: "https://example.com/auth",
|
|
TokenURL: "https://example.com/token",
|
|
ValidateURL: "https://example.com/validate",
|
|
RevokeURL: "https://example.com/revoke",
|
|
AppInstallURL: "https://example.com/install",
|
|
AppInstallationsURL: "https://example.com/installations",
|
|
NoRefresh: true,
|
|
Scopes: []string{"user:email", "read:org"},
|
|
ExtraTokenKeys: []string{"extra", "token"},
|
|
DeviceFlow: true,
|
|
DeviceCodeURL: "https://example.com/device",
|
|
Regex: "^https://example.com/.*$",
|
|
DisplayName: "GitHub",
|
|
DisplayIcon: "/static/icons/github.svg",
|
|
MCPURL: "https://api.githubcopilot.com/mcp/",
|
|
MCPToolAllowRegex: ".*",
|
|
MCPToolDenyRegex: "create_gist",
|
|
CodeChallengeMethodsSupported: []string{"S256"},
|
|
}
|
|
|
|
// Input the github section twice for testing a slice of configs.
|
|
inputYAML := func() string {
|
|
f := file(t, "githubcfg.yaml")
|
|
lines := strings.SplitN(f, "\n", 2)
|
|
// Append github config twice
|
|
return f + lines[1]
|
|
}()
|
|
|
|
expected := []codersdk.ExternalAuthConfig{
|
|
githubCfg, githubCfg,
|
|
}
|
|
|
|
dv := codersdk.DeploymentValues{}
|
|
opts := dv.Options()
|
|
// replace any tabs with the proper space indentation
|
|
inputYAML = strings.ReplaceAll(inputYAML, "\t", " ")
|
|
|
|
// This is the order things are done in the cli, so just
|
|
// keep it the same.
|
|
var n yaml.Node
|
|
err := yaml.Unmarshal([]byte(inputYAML), &n)
|
|
require.NoError(t, err)
|
|
|
|
err = n.Decode(&opts)
|
|
require.NoError(t, err)
|
|
require.ElementsMatchf(t, expected, dv.ExternalAuthConfigs.Value, "from yaml")
|
|
|
|
var out bytes.Buffer
|
|
enc := yaml.NewEncoder(&out)
|
|
enc.SetIndent(2)
|
|
err = enc.Encode(dv.ExternalAuthConfigs)
|
|
require.NoError(t, err)
|
|
|
|
// Because we only marshal the 1 section, the correct section name is not applied.
|
|
output := strings.Replace(out.String(), "value:", "externalAuthProviders:", 1)
|
|
require.Equal(t, inputYAML, output, "re-marshaled is the same as input")
|
|
}
|
|
|
|
func TestFeatureComparison(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
Name string
|
|
A codersdk.Feature
|
|
B codersdk.Feature
|
|
Expected int
|
|
}{
|
|
{
|
|
Name: "Empty",
|
|
Expected: 0,
|
|
},
|
|
// Entitlement check
|
|
// Entitled
|
|
{
|
|
Name: "EntitledVsGracePeriod",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementGracePeriod},
|
|
Expected: 1,
|
|
},
|
|
{
|
|
Name: "EntitledVsGracePeriodLimits",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled},
|
|
// Entitled should still win here
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementGracePeriod, Limit: ptr.Ref[int64](100), Actual: ptr.Ref[int64](50)},
|
|
Expected: 1,
|
|
},
|
|
{
|
|
Name: "EntitledVsNotEntitled",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementNotEntitled},
|
|
Expected: 3,
|
|
},
|
|
{
|
|
Name: "EntitledVsUnknown",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled},
|
|
B: codersdk.Feature{Entitlement: ""},
|
|
Expected: 4,
|
|
},
|
|
// GracePeriod
|
|
{
|
|
Name: "GracefulVsNotEntitled",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementGracePeriod},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementNotEntitled},
|
|
Expected: 2,
|
|
},
|
|
{
|
|
Name: "GracefulVsUnknown",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementGracePeriod},
|
|
B: codersdk.Feature{Entitlement: ""},
|
|
Expected: 3,
|
|
},
|
|
// NotEntitled
|
|
{
|
|
Name: "NotEntitledVsUnknown",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementNotEntitled},
|
|
B: codersdk.Feature{Entitlement: ""},
|
|
Expected: 1,
|
|
},
|
|
// --
|
|
{
|
|
Name: "EntitledVsGracePeriodCapable",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref[int64](100), Actual: ptr.Ref[int64](200)},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementGracePeriod, Limit: ptr.Ref[int64](300), Actual: ptr.Ref[int64](200)},
|
|
Expected: -1,
|
|
},
|
|
// UserLimits
|
|
{
|
|
// Tests an exceeded limit that is entitled vs a graceful limit that
|
|
// is not exceeded. This is the edge case that we should use the graceful period
|
|
// instead of the entitled.
|
|
Name: "UserLimitExceeded",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(100)), Actual: ptr.Ref(int64(200))},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementGracePeriod, Limit: ptr.Ref(int64(300)), Actual: ptr.Ref(int64(200))},
|
|
Expected: -1,
|
|
},
|
|
{
|
|
Name: "UserLimitExceededNoEntitled",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(100)), Actual: ptr.Ref(int64(200))},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementNotEntitled, Limit: ptr.Ref(int64(300)), Actual: ptr.Ref(int64(200))},
|
|
Expected: 3,
|
|
},
|
|
{
|
|
Name: "HigherLimit",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(110)), Actual: ptr.Ref(int64(200))},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(100)), Actual: ptr.Ref(int64(200))},
|
|
Expected: 10, // Diff in the limit #
|
|
},
|
|
{
|
|
Name: "HigherActual",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(100)), Actual: ptr.Ref(int64(300))},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(100)), Actual: ptr.Ref(int64(200))},
|
|
Expected: 100, // Diff in the actual #
|
|
},
|
|
{
|
|
Name: "LimitExists",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(100)), Actual: ptr.Ref(int64(50))},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: nil, Actual: ptr.Ref(int64(200))},
|
|
Expected: 1,
|
|
},
|
|
{
|
|
Name: "LimitExistsGrace",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementGracePeriod, Limit: ptr.Ref(int64(100)), Actual: ptr.Ref(int64(50))},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementGracePeriod, Limit: nil, Actual: ptr.Ref(int64(200))},
|
|
Expected: 1,
|
|
},
|
|
{
|
|
Name: "ActualExists",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(100)), Actual: ptr.Ref(int64(50))},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(100)), Actual: nil},
|
|
Expected: 1,
|
|
},
|
|
{
|
|
Name: "NotNils",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(100)), Actual: ptr.Ref(int64(50))},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: nil, Actual: nil},
|
|
Expected: 1,
|
|
},
|
|
{
|
|
Name: "EnabledVsDisabled",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Enabled: true, Limit: ptr.Ref(int64(300)), Actual: ptr.Ref(int64(200))},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(300)), Actual: ptr.Ref(int64(200))},
|
|
Expected: 1,
|
|
},
|
|
{
|
|
Name: "NotNils",
|
|
A: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: ptr.Ref(int64(100)), Actual: ptr.Ref(int64(50))},
|
|
B: codersdk.Feature{Entitlement: codersdk.EntitlementEntitled, Limit: nil, Actual: nil},
|
|
Expected: 1,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.Name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := tc.A.Compare(tc.B)
|
|
logIt := !assert.Equal(t, tc.Expected, r)
|
|
|
|
// Comparisons should be like addition. A - B = -1 * (B - A)
|
|
r = tc.B.Compare(tc.A)
|
|
logIt = logIt || !assert.Equalf(t, tc.Expected*-1, r, "the inverse comparison should also be true")
|
|
if logIt {
|
|
ad, _ := json.Marshal(tc.A)
|
|
bd, _ := json.Marshal(tc.B)
|
|
t.Logf("a = %s\nb = %s", ad, bd)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestPremiumSuperSet tests that the "premium" feature set is a superset of the
|
|
// "enterprise" feature set.
|
|
func TestPremiumSuperSet(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
enterprise := codersdk.FeatureSetEnterprise
|
|
premium := codersdk.FeatureSetPremium
|
|
|
|
// Premium > Enterprise
|
|
require.Greater(t, len(premium.Features()), len(enterprise.Features()), "premium should have more features than enterprise")
|
|
|
|
// Premium ⊃ Enterprise
|
|
require.Subset(t, premium.Features(), enterprise.Features(), "premium should be a superset of enterprise. If this fails, update the premium feature set to include all enterprise features.")
|
|
|
|
// Premium = All Features EXCEPT limit-based features.
|
|
// TODO: In future release, also exclude addon features (f.IsAddonFeature()).
|
|
expectedPremiumFeatures := []codersdk.FeatureName{}
|
|
for _, feature := range codersdk.FeatureNames {
|
|
if feature.UsesLimit() {
|
|
continue
|
|
}
|
|
expectedPremiumFeatures = append(expectedPremiumFeatures, feature)
|
|
}
|
|
require.NotEmpty(t, expectedPremiumFeatures, "expectedPremiumFeatures should not be empty")
|
|
require.ElementsMatch(t, premium.Features(), expectedPremiumFeatures, "premium should contain all features except usage limit features")
|
|
|
|
// This check exists because if you misuse the slices.Delete, you can end up
|
|
// with zero'd values.
|
|
require.NotContains(t, enterprise.Features(), "", "enterprise should not contain empty string")
|
|
require.NotContains(t, premium.Features(), "", "premium should not contain empty string")
|
|
}
|
|
|
|
func TestNotificationsCanBeDisabled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
expectNotificationsEnabled bool
|
|
environment []serpent.EnvVar
|
|
}{
|
|
{
|
|
name: "NoDeliveryMethodSet",
|
|
environment: []serpent.EnvVar{},
|
|
expectNotificationsEnabled: false,
|
|
},
|
|
{
|
|
name: "SMTP_DeliveryMethodSet",
|
|
environment: []serpent.EnvVar{
|
|
{
|
|
Name: "CODER_EMAIL_SMARTHOST",
|
|
Value: "localhost:587",
|
|
},
|
|
},
|
|
expectNotificationsEnabled: true,
|
|
},
|
|
{
|
|
name: "Webhook_DeliveryMethodSet",
|
|
environment: []serpent.EnvVar{
|
|
{
|
|
Name: "CODER_NOTIFICATIONS_WEBHOOK_ENDPOINT",
|
|
Value: "https://example.com/webhook",
|
|
},
|
|
},
|
|
expectNotificationsEnabled: true,
|
|
},
|
|
{
|
|
name: "WebhookAndSMTP_DeliveryMethodSet",
|
|
environment: []serpent.EnvVar{
|
|
{
|
|
Name: "CODER_NOTIFICATIONS_WEBHOOK_ENDPOINT",
|
|
Value: "https://example.com/webhook",
|
|
},
|
|
{
|
|
Name: "CODER_EMAIL_SMARTHOST",
|
|
Value: "localhost:587",
|
|
},
|
|
},
|
|
expectNotificationsEnabled: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dv := codersdk.DeploymentValues{}
|
|
opts := dv.Options()
|
|
|
|
err := opts.ParseEnv(tt.environment)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, tt.expectNotificationsEnabled, dv.Notifications.Enabled())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRetentionConfigParsing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
environment []serpent.EnvVar
|
|
expectedAuditLogs time.Duration
|
|
expectedConnectionLogs time.Duration
|
|
expectedAPIKeys time.Duration
|
|
}{
|
|
{
|
|
name: "Defaults",
|
|
environment: []serpent.EnvVar{},
|
|
expectedAuditLogs: 0,
|
|
expectedConnectionLogs: 0,
|
|
expectedAPIKeys: 7 * 24 * time.Hour, // 7 days default
|
|
},
|
|
{
|
|
name: "IndividualRetentionSet",
|
|
environment: []serpent.EnvVar{
|
|
{Name: "CODER_AUDIT_LOGS_RETENTION", Value: "30d"},
|
|
{Name: "CODER_CONNECTION_LOGS_RETENTION", Value: "60d"},
|
|
{Name: "CODER_API_KEYS_RETENTION", Value: "14d"},
|
|
},
|
|
expectedAuditLogs: 30 * 24 * time.Hour,
|
|
expectedConnectionLogs: 60 * 24 * time.Hour,
|
|
expectedAPIKeys: 14 * 24 * time.Hour,
|
|
},
|
|
{
|
|
name: "AllRetentionSet",
|
|
environment: []serpent.EnvVar{
|
|
{Name: "CODER_AUDIT_LOGS_RETENTION", Value: "365d"},
|
|
{Name: "CODER_CONNECTION_LOGS_RETENTION", Value: "30d"},
|
|
{Name: "CODER_API_KEYS_RETENTION", Value: "0"},
|
|
},
|
|
expectedAuditLogs: 365 * 24 * time.Hour,
|
|
expectedConnectionLogs: 30 * 24 * time.Hour,
|
|
expectedAPIKeys: 0, // Explicitly disabled
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dv := codersdk.DeploymentValues{}
|
|
opts := dv.Options()
|
|
|
|
err := opts.SetDefaults()
|
|
require.NoError(t, err)
|
|
|
|
err = opts.ParseEnv(tt.environment)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, tt.expectedAuditLogs, dv.Retention.AuditLogs.Value(), "audit logs retention mismatch")
|
|
assert.Equal(t, tt.expectedConnectionLogs, dv.Retention.ConnectionLogs.Value(), "connection logs retention mismatch")
|
|
assert.Equal(t, tt.expectedAPIKeys, dv.Retention.APIKeys.Value(), "api keys retention mismatch")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestChatAIGatewayRoutingEnabledDefault(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dv := codersdk.DeploymentValues{}
|
|
opts := dv.Options()
|
|
require.NoError(t, opts.SetDefaults())
|
|
require.True(t, dv.AI.Chat.AIGatewayRoutingEnabled.Value())
|
|
}
|
|
|
|
func TestAIBudgetConfigParsing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Defaults", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dv := codersdk.DeploymentValues{}
|
|
opts := dv.Options()
|
|
|
|
require.NoError(t, opts.SetDefaults())
|
|
|
|
assert.Equal(t, string(codersdk.AIBudgetPolicyHighest), dv.AI.BridgeConfig.BudgetPolicy)
|
|
assert.Equal(t, string(codersdk.AIBudgetPeriodMonth), dv.AI.BridgeConfig.BudgetPeriod)
|
|
})
|
|
|
|
t.Run("AcceptsSupportedValues", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dv := codersdk.DeploymentValues{}
|
|
opts := dv.Options()
|
|
|
|
require.NoError(t, opts.SetDefaults())
|
|
require.NoError(t, opts.ParseEnv([]serpent.EnvVar{
|
|
{Name: "CODER_AI_BUDGET_POLICY", Value: string(codersdk.AIBudgetPolicyHighest)},
|
|
{Name: "CODER_AI_BUDGET_PERIOD", Value: string(codersdk.AIBudgetPeriodMonth)},
|
|
}))
|
|
|
|
assert.Equal(t, string(codersdk.AIBudgetPolicyHighest), dv.AI.BridgeConfig.BudgetPolicy)
|
|
assert.Equal(t, string(codersdk.AIBudgetPeriodMonth), dv.AI.BridgeConfig.BudgetPeriod)
|
|
})
|
|
|
|
t.Run("RejectsUnsupportedPolicy", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dv := codersdk.DeploymentValues{}
|
|
opts := dv.Options()
|
|
|
|
require.NoError(t, opts.SetDefaults())
|
|
err := opts.ParseEnv([]serpent.EnvVar{
|
|
{Name: "CODER_AI_BUDGET_POLICY", Value: "invalid"},
|
|
})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "invalid choice")
|
|
})
|
|
|
|
t.Run("RejectsUnsupportedPeriod", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dv := codersdk.DeploymentValues{}
|
|
opts := dv.Options()
|
|
|
|
require.NoError(t, opts.SetDefaults())
|
|
err := opts.ParseEnv([]serpent.EnvVar{
|
|
{Name: "CODER_AI_BUDGET_PERIOD", Value: "invalid"},
|
|
})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "invalid choice")
|
|
})
|
|
}
|
|
|
|
func TestComputeMaxIdleConns(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
maxOpen int
|
|
configuredIdle string
|
|
expectedIdle int
|
|
expectError bool
|
|
errorContains string
|
|
}{
|
|
{
|
|
name: "auto_default_10_open",
|
|
maxOpen: 10,
|
|
configuredIdle: "auto",
|
|
expectedIdle: 3, // 10/3 = 3
|
|
},
|
|
{
|
|
name: "auto_with_whitespace",
|
|
maxOpen: 10,
|
|
configuredIdle: " auto ",
|
|
expectedIdle: 3, // 10/3 = 3
|
|
},
|
|
{
|
|
name: "auto_30_open",
|
|
maxOpen: 30,
|
|
configuredIdle: "auto",
|
|
expectedIdle: 10, // 30/3 = 10
|
|
},
|
|
{
|
|
name: "auto_minimum_1",
|
|
maxOpen: 1,
|
|
configuredIdle: "auto",
|
|
expectedIdle: 1, // 1/3 = 0, but minimum is 1
|
|
},
|
|
{
|
|
name: "auto_minimum_2_open",
|
|
maxOpen: 2,
|
|
configuredIdle: "auto",
|
|
expectedIdle: 1, // 2/3 = 0, but minimum is 1
|
|
},
|
|
{
|
|
name: "auto_3_open",
|
|
maxOpen: 3,
|
|
configuredIdle: "auto",
|
|
expectedIdle: 1, // 3/3 = 1
|
|
},
|
|
{
|
|
name: "explicit_equal_to_max",
|
|
maxOpen: 10,
|
|
configuredIdle: "10",
|
|
expectedIdle: 10,
|
|
},
|
|
{
|
|
name: "explicit_less_than_max",
|
|
maxOpen: 10,
|
|
configuredIdle: "5",
|
|
expectedIdle: 5,
|
|
},
|
|
{
|
|
name: "explicit_with_whitespace",
|
|
maxOpen: 10,
|
|
configuredIdle: " 5 ",
|
|
expectedIdle: 5,
|
|
},
|
|
{
|
|
name: "explicit_0",
|
|
maxOpen: 10,
|
|
configuredIdle: "0",
|
|
expectedIdle: 0,
|
|
},
|
|
{
|
|
name: "error_exceeds_max",
|
|
maxOpen: 10,
|
|
configuredIdle: "15",
|
|
expectError: true,
|
|
errorContains: "cannot exceed",
|
|
},
|
|
{
|
|
name: "error_exceeds_max_by_1",
|
|
maxOpen: 10,
|
|
configuredIdle: "11",
|
|
expectError: true,
|
|
errorContains: "cannot exceed",
|
|
},
|
|
{
|
|
name: "error_invalid_string",
|
|
maxOpen: 10,
|
|
configuredIdle: "invalid",
|
|
expectError: true,
|
|
errorContains: "must be \"auto\" or >= 0",
|
|
},
|
|
{
|
|
name: "error_negative",
|
|
maxOpen: 10,
|
|
configuredIdle: "-1",
|
|
expectError: true,
|
|
errorContains: "must be \"auto\" or >= 0",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
result, err := codersdk.ComputeMaxIdleConns(tt.maxOpen, tt.configuredIdle)
|
|
if tt.expectError {
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), tt.errorContains)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.Equal(t, tt.expectedIdle, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHTTPCookieConfigMiddleware(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Realistic cookies that are always present in production.
|
|
// These cookies are added to every test.
|
|
baseCookies := []*http.Cookie{
|
|
{Name: "_ga", Value: "GA1.1.661026807.1770083336"},
|
|
{Name: "_ga_G0Q1B9GRC0", Value: "GS2.1.s1771343727$o49$g1$t1771343993$j48$l0$h0"},
|
|
{Name: "csrf_token", Value: "gDiKk8GjTM2iCUHAPfN9GlC+DGjzAprlLi2vJ+5TBU0="},
|
|
}
|
|
|
|
cases := []struct {
|
|
name string
|
|
cfg codersdk.HTTPCookieConfig
|
|
extraCookies []*http.Cookie
|
|
expectedCookies map[string]string // cookie name -> value that handler should see
|
|
expectedDeleted []string // if any cookies are supposed to be deleted via Set-Cookie
|
|
}{
|
|
{
|
|
name: "Disabled_PassesThrough",
|
|
cfg: codersdk.HTTPCookieConfig{},
|
|
extraCookies: []*http.Cookie{
|
|
{Name: codersdk.SessionTokenCookie, Value: "token123"},
|
|
},
|
|
expectedCookies: map[string]string{
|
|
codersdk.SessionTokenCookie: "token123",
|
|
},
|
|
},
|
|
{
|
|
name: "Enabled_StripsPrefixFromCookie",
|
|
cfg: codersdk.HTTPCookieConfig{EnableHostPrefix: true},
|
|
extraCookies: []*http.Cookie{
|
|
{Name: "__Host-" + codersdk.SessionTokenCookie, Value: "token123"},
|
|
},
|
|
expectedCookies: map[string]string{
|
|
codersdk.SessionTokenCookie: "token123",
|
|
},
|
|
},
|
|
{
|
|
name: "Enabled_DeletesUnprefixedCookie",
|
|
cfg: codersdk.HTTPCookieConfig{EnableHostPrefix: true},
|
|
extraCookies: []*http.Cookie{
|
|
// Unprefixed cookie that should be in the "to prefix" list.
|
|
{Name: codersdk.SessionTokenCookie, Value: "unprefixed-token"},
|
|
},
|
|
expectedCookies: map[string]string{
|
|
// Session token should NOT be present - it was deleted.
|
|
},
|
|
expectedDeleted: []string{codersdk.SessionTokenCookie},
|
|
},
|
|
{
|
|
name: "Enabled_BothPrefixedAndUnprefixed",
|
|
cfg: codersdk.HTTPCookieConfig{EnableHostPrefix: true},
|
|
extraCookies: []*http.Cookie{
|
|
// Browser might send both during migration.
|
|
{Name: codersdk.SessionTokenCookie, Value: "unprefixed-token"},
|
|
{Name: "__Host-" + codersdk.SessionTokenCookie, Value: "prefixed-token"},
|
|
},
|
|
expectedCookies: map[string]string{
|
|
codersdk.SessionTokenCookie: "prefixed-token", // Prefixed wins.
|
|
},
|
|
expectedDeleted: []string{codersdk.SessionTokenCookie},
|
|
},
|
|
{
|
|
name: "Enabled_MultiplePrefixedCookies",
|
|
cfg: codersdk.HTTPCookieConfig{EnableHostPrefix: true},
|
|
extraCookies: []*http.Cookie{
|
|
{Name: "__Host-" + codersdk.SessionTokenCookie, Value: "session"},
|
|
{Name: "__Host-SomeOtherCookie", Value: "other-cookie"},
|
|
{Name: "__Host-Santa", Value: "santa"},
|
|
},
|
|
expectedCookies: map[string]string{
|
|
codersdk.SessionTokenCookie: "session",
|
|
"__Host-SomeOtherCookie": "other-cookie",
|
|
"__Host-Santa": "santa",
|
|
},
|
|
},
|
|
{
|
|
name: "Enabled_UnrelatedCookiesUnchanged",
|
|
cfg: codersdk.HTTPCookieConfig{EnableHostPrefix: true},
|
|
extraCookies: []*http.Cookie{
|
|
{Name: "custom_cookie", Value: "custom-value"},
|
|
{Name: "__Host-" + codersdk.SessionTokenCookie, Value: "session"},
|
|
{Name: "__Host-foobar", Value: "do-not-change-me"},
|
|
},
|
|
expectedCookies: map[string]string{
|
|
"custom_cookie": "custom-value",
|
|
codersdk.SessionTokenCookie: "session",
|
|
"__Host-foobar": "do-not-change-me",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var handlerCookies []*http.Cookie
|
|
handler := tc.cfg.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCookies = r.Cookies()
|
|
}))
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
for _, c := range baseCookies {
|
|
req.AddCookie(c)
|
|
}
|
|
for _, c := range tc.extraCookies {
|
|
req.AddCookie(c)
|
|
}
|
|
|
|
rw := httptest.NewRecorder()
|
|
handler.ServeHTTP(rw, req)
|
|
|
|
// Verify cookies seen by handler.
|
|
gotCookies := make(map[string]string)
|
|
for _, c := range handlerCookies {
|
|
gotCookies[c.Name] = c.Value
|
|
}
|
|
|
|
for _, v := range baseCookies {
|
|
tc.expectedCookies[v.Name] = v.Value
|
|
}
|
|
assert.Equal(t, tc.expectedCookies, gotCookies)
|
|
|
|
// Verify Set-Cookie header for deletion.
|
|
setCookies := rw.Result().Cookies()
|
|
if len(tc.expectedDeleted) > 0 {
|
|
assert.NotEmpty(t, setCookies, "expected Set-Cookie header for cookie deletion")
|
|
expDel := make(map[string]struct{})
|
|
for _, name := range tc.expectedDeleted {
|
|
expDel[name] = struct{}{}
|
|
}
|
|
// Verify it's a deletion (MaxAge < 0).
|
|
for _, c := range setCookies {
|
|
assert.Less(t, c.MaxAge, 0, "Set-Cookie should have MaxAge < 0 for deletion")
|
|
delete(expDel, c.Name)
|
|
}
|
|
require.Empty(t, expDel, "expected Set-Cookie header for deletion")
|
|
} else {
|
|
assert.Empty(t, setCookies, "did not expect Set-Cookie header")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func BenchmarkHTTPCookieConfigMiddleware(b *testing.B) {
|
|
noop := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
|
|
// Realistic cookies that are always present in production.
|
|
baseCookies := []*http.Cookie{
|
|
{Name: "_ga", Value: "GA1.1.661026807.1770083336"},
|
|
{Name: "_ga_G0Q1B9GRC0", Value: "GS2.1.s1771343727$o49$g1$t1771343993$j48$l0$h0"},
|
|
{Name: "csrf_token", Value: "gDiKk8GjTM2iCUHAPfN9GlC+DGjzAprlLi2vJ+5TBU0="},
|
|
}
|
|
|
|
cases := []struct {
|
|
name string
|
|
cfg codersdk.HTTPCookieConfig
|
|
extraCookies []*http.Cookie
|
|
}{
|
|
{
|
|
name: "Disabled",
|
|
cfg: codersdk.HTTPCookieConfig{},
|
|
extraCookies: []*http.Cookie{
|
|
{Name: codersdk.SessionTokenCookie, Value: "KybJV9fNul-u11vlll9wiF6eLQDxBVucD"},
|
|
},
|
|
},
|
|
{
|
|
name: "Enabled_NoPrefixedCookies",
|
|
cfg: codersdk.HTTPCookieConfig{EnableHostPrefix: true},
|
|
extraCookies: []*http.Cookie{
|
|
{Name: codersdk.SessionTokenCookie, Value: "KybJV9fNul-u11vlll9wiF6eLQDxBVucD"},
|
|
},
|
|
},
|
|
{
|
|
name: "Enabled_WithPrefixedCookie",
|
|
cfg: codersdk.HTTPCookieConfig{EnableHostPrefix: true},
|
|
extraCookies: []*http.Cookie{
|
|
{Name: "__Host-" + codersdk.SessionTokenCookie, Value: "KybJV9fNul-u11vlll9wiF6eLQDxBVucD"},
|
|
},
|
|
},
|
|
{
|
|
name: "Enabled_MultiplePrefixedCookies",
|
|
cfg: codersdk.HTTPCookieConfig{EnableHostPrefix: true},
|
|
extraCookies: []*http.Cookie{
|
|
{Name: "__Host-" + codersdk.SessionTokenCookie, Value: "KybJV9fNul-u11vlll9wiF6eLQDxBVucD"},
|
|
{Name: "__Host-" + codersdk.PathAppSessionTokenCookie, Value: "xyz123"},
|
|
{Name: "__Host-" + codersdk.SubdomainAppSessionTokenCookie, Value: "abc456"},
|
|
{Name: "__Host-" + "foobar", Value: "do-not-change-me"},
|
|
},
|
|
},
|
|
{
|
|
name: "Enabled_NonSessionPrefixedCookies",
|
|
cfg: codersdk.HTTPCookieConfig{EnableHostPrefix: true},
|
|
extraCookies: []*http.Cookie{
|
|
{Name: "__Host-" + codersdk.SessionTokenCookie, Value: "KybJV9fNul-u11vlll9wiF6eLQDxBVucD"},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
b.Run(tc.name, func(b *testing.B) {
|
|
handler := tc.cfg.Middleware(noop)
|
|
rw := httptest.NewRecorder()
|
|
|
|
allCookies := make([]*http.Cookie, 1, len(baseCookies))
|
|
copy(allCookies, baseCookies)
|
|
// Combine base cookies with test-specific cookies.
|
|
allCookies = append(allCookies, tc.extraCookies...)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
for _, c := range allCookies {
|
|
req.AddCookie(c)
|
|
}
|
|
handler.ServeHTTP(rw, req)
|
|
}
|
|
})
|
|
}
|
|
}
|