feat: allow bypassing current CORS magic based on template config (#18706)

Solves https://github.com/coder/coder/issues/15096

This is a slight rework/refactor of the earlier PRs from @dannykopping
and @Emyrk:
- https://github.com/coder/coder/pull/15669
- https://github.com/coder/coder/pull/15684
- https://github.com/coder/coder/pull/17596

Rather than having a per-app CORS behaviour setting and additionally a
template level setting for ports, this PR adds a single template level
CORS behaviour setting that is then used by all apps/ports for
workspaces created from that template.

The main changes are in `proxy.go` and `request.go` to:
a) get the CORS behaviour setting from the template
b) have `HandleSubdomain` bypass the CORS middleware handler if the
selected behaviour is `passthru`
c) in `proxyWorkspaceApp`, do not modify the response if the selected
behaviour is `passthru`

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added support for configuring CORS behavior ("simple" or "passthru")
at the template level for all shared ports.
* Introduced a new "CORS Behavior" setting in the template creation and
settings forms.
* API endpoints and responses now include the optional `cors_behavior`
property for templates.
* Workspace apps and proxy now honor the specified CORS behavior,
enabling conditional CORS middleware application.
* Enhanced workspace app tests with comprehensive scenarios covering
CORS behaviors and authentication states.

* **Bug Fixes**
  * None.

* **Documentation**
* Updated API and admin documentation to describe the new
`cors_behavior` property and its usage.
* Added examples and schema references for CORS behavior in relevant API
docs.

* **Tests**
* Extended automated tests to cover different CORS behavior scenarios
for templates and workspace apps.

* **Chores**
* Updated audit logging to track changes to the `cors_behavior` field on
templates.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Callum Styan <callumstyan@gmail.com>
This commit is contained in:
Callum Styan
2025-07-30 13:42:39 -07:00
committed by GitHub
parent 96e32d60a2
commit ffbfaf2a6f
36 changed files with 1149 additions and 108 deletions
+22
View File
@@ -11467,6 +11467,17 @@ const docTemplate = `{
"BuildReasonJetbrainsConnection"
]
},
"codersdk.CORSBehavior": {
"type": "string",
"enum": [
"simple",
"passthru"
],
"x-enum-varnames": [
"CORSBehaviorSimple",
"CORSBehaviorPassthru"
]
},
"codersdk.ChangePasswordWithOneTimePasscodeRequest": {
"type": "object",
"required": [
@@ -11808,6 +11819,14 @@ const docTemplate = `{
}
]
},
"cors_behavior": {
"description": "CORSBehavior allows optionally specifying the CORS behavior for all shared ports.",
"allOf": [
{
"$ref": "#/definitions/codersdk.CORSBehavior"
}
]
},
"default_ttl_ms": {
"description": "DefaultTTLMillis allows optionally specifying the default TTL\nfor all workspaces created from this template.",
"type": "integer"
@@ -16215,6 +16234,9 @@ const docTemplate = `{
"build_time_stats": {
"$ref": "#/definitions/codersdk.TemplateBuildTimeStats"
},
"cors_behavior": {
"$ref": "#/definitions/codersdk.CORSBehavior"
},
"created_at": {
"type": "string",
"format": "date-time"
+16
View File
@@ -10202,6 +10202,11 @@
"BuildReasonJetbrainsConnection"
]
},
"codersdk.CORSBehavior": {
"type": "string",
"enum": ["simple", "passthru"],
"x-enum-varnames": ["CORSBehaviorSimple", "CORSBehaviorPassthru"]
},
"codersdk.ChangePasswordWithOneTimePasscodeRequest": {
"type": "object",
"required": ["email", "one_time_passcode", "password"],
@@ -10525,6 +10530,14 @@
}
]
},
"cors_behavior": {
"description": "CORSBehavior allows optionally specifying the CORS behavior for all shared ports.",
"allOf": [
{
"$ref": "#/definitions/codersdk.CORSBehavior"
}
]
},
"default_ttl_ms": {
"description": "DefaultTTLMillis allows optionally specifying the default TTL\nfor all workspaces created from this template.",
"type": "integer"
@@ -14774,6 +14787,9 @@
"build_time_stats": {
"$ref": "#/definitions/codersdk.TemplateBuildTimeStats"
},
"cors_behavior": {
"$ref": "#/definitions/codersdk.CORSBehavior"
},
"created_at": {
"type": "string",
"format": "date-time"
+2
View File
@@ -1462,6 +1462,7 @@ func (s *MethodTestSuite) TestTemplate() {
Provisioner: "echo",
OrganizationID: orgID,
MaxPortSharingLevel: database.AppSharingLevelOwner,
CorsBehavior: database.CorsBehaviorSimple,
}).Asserts(rbac.ResourceTemplate.InOrg(orgID), policy.ActionCreate)
}))
s.Run("InsertTemplateVersion", s.Subtest(func(db database.Store, check *expects) {
@@ -1582,6 +1583,7 @@ func (s *MethodTestSuite) TestTemplate() {
check.Args(database.UpdateTemplateMetaByIDParams{
ID: t1.ID,
MaxPortSharingLevel: "owner",
CorsBehavior: database.CorsBehaviorSimple,
}).Asserts(t1, policy.ActionUpdate)
}))
s.Run("UpdateTemplateVersionByID", s.Subtest(func(db database.Store, check *expects) {
+1
View File
@@ -148,6 +148,7 @@ func Template(t testing.TB, db database.Store, seed database.Template) database.
AllowUserCancelWorkspaceJobs: seed.AllowUserCancelWorkspaceJobs,
MaxPortSharingLevel: takeFirst(seed.MaxPortSharingLevel, database.AppSharingLevelOwner),
UseClassicParameterFlow: takeFirst(seed.UseClassicParameterFlow, false),
CorsBehavior: takeFirst(seed.CorsBehavior, database.CorsBehaviorSimple),
})
require.NoError(t, err, "insert template")
+8 -1
View File
@@ -73,6 +73,11 @@ CREATE TYPE connection_type AS ENUM (
'port_forwarding'
);
CREATE TYPE cors_behavior AS ENUM (
'simple',
'passthru'
);
CREATE TYPE crypto_key_feature AS ENUM (
'workspace_apps_token',
'workspace_apps_api_key',
@@ -1750,7 +1755,8 @@ CREATE TABLE templates (
deprecated text DEFAULT ''::text NOT NULL,
activity_bump bigint DEFAULT '3600000000000'::bigint NOT NULL,
max_port_sharing_level app_sharing_level DEFAULT 'owner'::app_sharing_level NOT NULL,
use_classic_parameter_flow boolean DEFAULT false NOT NULL
use_classic_parameter_flow boolean DEFAULT false NOT NULL,
cors_behavior cors_behavior DEFAULT 'simple'::cors_behavior NOT NULL
);
COMMENT ON COLUMN templates.default_ttl IS 'The default duration for autostop for workspaces created from this template.';
@@ -1803,6 +1809,7 @@ CREATE VIEW template_with_names AS
templates.activity_bump,
templates.max_port_sharing_level,
templates.use_classic_parameter_flow,
templates.cors_behavior,
COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url,
COALESCE(visible_users.username, ''::text) AS created_by_username,
COALESCE(visible_users.name, ''::text) AS created_by_name,
@@ -0,0 +1,46 @@
DROP VIEW IF EXISTS template_with_names;
CREATE VIEW template_with_names AS
SELECT templates.id,
templates.created_at,
templates.updated_at,
templates.organization_id,
templates.deleted,
templates.name,
templates.provisioner,
templates.active_version_id,
templates.description,
templates.default_ttl,
templates.created_by,
templates.icon,
templates.user_acl,
templates.group_acl,
templates.display_name,
templates.allow_user_cancel_workspace_jobs,
templates.allow_user_autostart,
templates.allow_user_autostop,
templates.failure_ttl,
templates.time_til_dormant,
templates.time_til_dormant_autodelete,
templates.autostop_requirement_days_of_week,
templates.autostop_requirement_weeks,
templates.autostart_block_days_of_week,
templates.require_active_version,
templates.deprecated,
templates.activity_bump,
templates.max_port_sharing_level,
templates.use_classic_parameter_flow,
COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url,
COALESCE(visible_users.username, ''::text) AS created_by_username,
COALESCE(visible_users.name, ''::text) AS created_by_name,
COALESCE(organizations.name, ''::text) AS organization_name,
COALESCE(organizations.display_name, ''::text) AS organization_display_name,
COALESCE(organizations.icon, ''::text) AS organization_icon
FROM ((templates
LEFT JOIN visible_users ON ((templates.created_by = visible_users.id)))
LEFT JOIN organizations ON ((templates.organization_id = organizations.id)));
COMMENT ON VIEW template_with_names IS 'Joins in the display name information such as username, avatar, and organization name.';
ALTER TABLE templates DROP COLUMN cors_behavior;
DROP TYPE IF EXISTS cors_behavior;
@@ -0,0 +1,52 @@
CREATE TYPE cors_behavior AS ENUM (
'simple',
'passthru'
);
ALTER TABLE templates
ADD COLUMN cors_behavior cors_behavior NOT NULL DEFAULT 'simple'::cors_behavior;
-- Update the template_with_users view by recreating it.
DROP VIEW IF EXISTS template_with_names;
CREATE VIEW template_with_names AS
SELECT templates.id,
templates.created_at,
templates.updated_at,
templates.organization_id,
templates.deleted,
templates.name,
templates.provisioner,
templates.active_version_id,
templates.description,
templates.default_ttl,
templates.created_by,
templates.icon,
templates.user_acl,
templates.group_acl,
templates.display_name,
templates.allow_user_cancel_workspace_jobs,
templates.allow_user_autostart,
templates.allow_user_autostop,
templates.failure_ttl,
templates.time_til_dormant,
templates.time_til_dormant_autodelete,
templates.autostop_requirement_days_of_week,
templates.autostop_requirement_weeks,
templates.autostart_block_days_of_week,
templates.require_active_version,
templates.deprecated,
templates.activity_bump,
templates.max_port_sharing_level,
templates.use_classic_parameter_flow,
templates.cors_behavior, -- <--- adding this column
COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url,
COALESCE(visible_users.username, ''::text) AS created_by_username,
COALESCE(visible_users.name, ''::text) AS created_by_name,
COALESCE(organizations.name, ''::text) AS organization_name,
COALESCE(organizations.display_name, ''::text) AS organization_display_name,
COALESCE(organizations.icon, ''::text) AS organization_icon
FROM ((templates
LEFT JOIN visible_users ON ((templates.created_by = visible_users.id)))
LEFT JOIN organizations ON ((templates.organization_id = organizations.id)));
COMMENT ON VIEW template_with_names IS 'Joins in the display name information such as username, avatar, and organization name.';
+1
View File
@@ -120,6 +120,7 @@ func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplate
&i.ActivityBump,
&i.MaxPortSharingLevel,
&i.UseClassicParameterFlow,
&i.CorsBehavior,
&i.CreatedByAvatarURL,
&i.CreatedByUsername,
&i.CreatedByName,
+61 -1
View File
@@ -559,6 +559,64 @@ func AllConnectionTypeValues() []ConnectionType {
}
}
type CorsBehavior string
const (
CorsBehaviorSimple CorsBehavior = "simple"
CorsBehaviorPassthru CorsBehavior = "passthru"
)
func (e *CorsBehavior) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = CorsBehavior(s)
case string:
*e = CorsBehavior(s)
default:
return fmt.Errorf("unsupported scan type for CorsBehavior: %T", src)
}
return nil
}
type NullCorsBehavior struct {
CorsBehavior CorsBehavior `json:"cors_behavior"`
Valid bool `json:"valid"` // Valid is true if CorsBehavior is not NULL
}
// Scan implements the Scanner interface.
func (ns *NullCorsBehavior) Scan(value interface{}) error {
if value == nil {
ns.CorsBehavior, ns.Valid = "", false
return nil
}
ns.Valid = true
return ns.CorsBehavior.Scan(value)
}
// Value implements the driver Valuer interface.
func (ns NullCorsBehavior) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return string(ns.CorsBehavior), nil
}
func (e CorsBehavior) Valid() bool {
switch e {
case CorsBehaviorSimple,
CorsBehaviorPassthru:
return true
}
return false
}
func AllCorsBehaviorValues() []CorsBehavior {
return []CorsBehavior{
CorsBehaviorSimple,
CorsBehaviorPassthru,
}
}
type CryptoKeyFeature string
const (
@@ -3474,6 +3532,7 @@ type Template struct {
ActivityBump int64 `db:"activity_bump" json:"activity_bump"`
MaxPortSharingLevel AppSharingLevel `db:"max_port_sharing_level" json:"max_port_sharing_level"`
UseClassicParameterFlow bool `db:"use_classic_parameter_flow" json:"use_classic_parameter_flow"`
CorsBehavior CorsBehavior `db:"cors_behavior" json:"cors_behavior"`
CreatedByAvatarURL string `db:"created_by_avatar_url" json:"created_by_avatar_url"`
CreatedByUsername string `db:"created_by_username" json:"created_by_username"`
CreatedByName string `db:"created_by_name" json:"created_by_name"`
@@ -3521,7 +3580,8 @@ type TemplateTable struct {
ActivityBump int64 `db:"activity_bump" json:"activity_bump"`
MaxPortSharingLevel AppSharingLevel `db:"max_port_sharing_level" json:"max_port_sharing_level"`
// Determines whether to default to the dynamic parameter creation flow for this template or continue using the legacy classic parameter creation flow.This is a template wide setting, the template admin can revert to the classic flow if there are any issues. An escape hatch is required, as workspace creation is a core workflow and cannot break. This column will be removed when the dynamic parameter creation flow is stable.
UseClassicParameterFlow bool `db:"use_classic_parameter_flow" json:"use_classic_parameter_flow"`
UseClassicParameterFlow bool `db:"use_classic_parameter_flow" json:"use_classic_parameter_flow"`
CorsBehavior CorsBehavior `db:"cors_behavior" json:"cors_behavior"`
}
// Records aggregated usage statistics for templates/users. All usage is rounded up to the nearest minute.
+18 -8
View File
@@ -11768,7 +11768,7 @@ func (q *sqlQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg GetTem
const getTemplateByID = `-- name: GetTemplateByID :one
SELECT
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level, use_classic_parameter_flow, created_by_avatar_url, created_by_username, created_by_name, organization_name, organization_display_name, organization_icon
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level, use_classic_parameter_flow, cors_behavior, created_by_avatar_url, created_by_username, created_by_name, organization_name, organization_display_name, organization_icon
FROM
template_with_names
WHERE
@@ -11810,6 +11810,7 @@ func (q *sqlQuerier) GetTemplateByID(ctx context.Context, id uuid.UUID) (Templat
&i.ActivityBump,
&i.MaxPortSharingLevel,
&i.UseClassicParameterFlow,
&i.CorsBehavior,
&i.CreatedByAvatarURL,
&i.CreatedByUsername,
&i.CreatedByName,
@@ -11822,7 +11823,7 @@ func (q *sqlQuerier) GetTemplateByID(ctx context.Context, id uuid.UUID) (Templat
const getTemplateByOrganizationAndName = `-- name: GetTemplateByOrganizationAndName :one
SELECT
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level, use_classic_parameter_flow, created_by_avatar_url, created_by_username, created_by_name, organization_name, organization_display_name, organization_icon
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level, use_classic_parameter_flow, cors_behavior, created_by_avatar_url, created_by_username, created_by_name, organization_name, organization_display_name, organization_icon
FROM
template_with_names AS templates
WHERE
@@ -11872,6 +11873,7 @@ func (q *sqlQuerier) GetTemplateByOrganizationAndName(ctx context.Context, arg G
&i.ActivityBump,
&i.MaxPortSharingLevel,
&i.UseClassicParameterFlow,
&i.CorsBehavior,
&i.CreatedByAvatarURL,
&i.CreatedByUsername,
&i.CreatedByName,
@@ -11883,7 +11885,7 @@ func (q *sqlQuerier) GetTemplateByOrganizationAndName(ctx context.Context, arg G
}
const getTemplates = `-- name: GetTemplates :many
SELECT id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level, use_classic_parameter_flow, created_by_avatar_url, created_by_username, created_by_name, organization_name, organization_display_name, organization_icon FROM template_with_names AS templates
SELECT id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level, use_classic_parameter_flow, cors_behavior, created_by_avatar_url, created_by_username, created_by_name, organization_name, organization_display_name, organization_icon FROM template_with_names AS templates
ORDER BY (name, id) ASC
`
@@ -11926,6 +11928,7 @@ func (q *sqlQuerier) GetTemplates(ctx context.Context) ([]Template, error) {
&i.ActivityBump,
&i.MaxPortSharingLevel,
&i.UseClassicParameterFlow,
&i.CorsBehavior,
&i.CreatedByAvatarURL,
&i.CreatedByUsername,
&i.CreatedByName,
@@ -11948,7 +11951,7 @@ func (q *sqlQuerier) GetTemplates(ctx context.Context) ([]Template, error) {
const getTemplatesWithFilter = `-- name: GetTemplatesWithFilter :many
SELECT
t.id, t.created_at, t.updated_at, t.organization_id, t.deleted, t.name, t.provisioner, t.active_version_id, t.description, t.default_ttl, t.created_by, t.icon, t.user_acl, t.group_acl, t.display_name, t.allow_user_cancel_workspace_jobs, t.allow_user_autostart, t.allow_user_autostop, t.failure_ttl, t.time_til_dormant, t.time_til_dormant_autodelete, t.autostop_requirement_days_of_week, t.autostop_requirement_weeks, t.autostart_block_days_of_week, t.require_active_version, t.deprecated, t.activity_bump, t.max_port_sharing_level, t.use_classic_parameter_flow, t.created_by_avatar_url, t.created_by_username, t.created_by_name, t.organization_name, t.organization_display_name, t.organization_icon
t.id, t.created_at, t.updated_at, t.organization_id, t.deleted, t.name, t.provisioner, t.active_version_id, t.description, t.default_ttl, t.created_by, t.icon, t.user_acl, t.group_acl, t.display_name, t.allow_user_cancel_workspace_jobs, t.allow_user_autostart, t.allow_user_autostop, t.failure_ttl, t.time_til_dormant, t.time_til_dormant_autodelete, t.autostop_requirement_days_of_week, t.autostop_requirement_weeks, t.autostart_block_days_of_week, t.require_active_version, t.deprecated, t.activity_bump, t.max_port_sharing_level, t.use_classic_parameter_flow, t.cors_behavior, t.created_by_avatar_url, t.created_by_username, t.created_by_name, t.organization_name, t.organization_display_name, t.organization_icon
FROM
template_with_names AS t
LEFT JOIN
@@ -12059,6 +12062,7 @@ func (q *sqlQuerier) GetTemplatesWithFilter(ctx context.Context, arg GetTemplate
&i.ActivityBump,
&i.MaxPortSharingLevel,
&i.UseClassicParameterFlow,
&i.CorsBehavior,
&i.CreatedByAvatarURL,
&i.CreatedByUsername,
&i.CreatedByName,
@@ -12097,10 +12101,11 @@ INSERT INTO
display_name,
allow_user_cancel_workspace_jobs,
max_port_sharing_level,
use_classic_parameter_flow
use_classic_parameter_flow,
cors_behavior
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
`
type InsertTemplateParams struct {
@@ -12120,6 +12125,7 @@ type InsertTemplateParams struct {
AllowUserCancelWorkspaceJobs bool `db:"allow_user_cancel_workspace_jobs" json:"allow_user_cancel_workspace_jobs"`
MaxPortSharingLevel AppSharingLevel `db:"max_port_sharing_level" json:"max_port_sharing_level"`
UseClassicParameterFlow bool `db:"use_classic_parameter_flow" json:"use_classic_parameter_flow"`
CorsBehavior CorsBehavior `db:"cors_behavior" json:"cors_behavior"`
}
func (q *sqlQuerier) InsertTemplate(ctx context.Context, arg InsertTemplateParams) error {
@@ -12140,6 +12146,7 @@ func (q *sqlQuerier) InsertTemplate(ctx context.Context, arg InsertTemplateParam
arg.AllowUserCancelWorkspaceJobs,
arg.MaxPortSharingLevel,
arg.UseClassicParameterFlow,
arg.CorsBehavior,
)
return err
}
@@ -12240,7 +12247,8 @@ SET
allow_user_cancel_workspace_jobs = $7,
group_acl = $8,
max_port_sharing_level = $9,
use_classic_parameter_flow = $10
use_classic_parameter_flow = $10,
cors_behavior = $11
WHERE
id = $1
`
@@ -12256,6 +12264,7 @@ type UpdateTemplateMetaByIDParams struct {
GroupACL TemplateACL `db:"group_acl" json:"group_acl"`
MaxPortSharingLevel AppSharingLevel `db:"max_port_sharing_level" json:"max_port_sharing_level"`
UseClassicParameterFlow bool `db:"use_classic_parameter_flow" json:"use_classic_parameter_flow"`
CorsBehavior CorsBehavior `db:"cors_behavior" json:"cors_behavior"`
}
func (q *sqlQuerier) UpdateTemplateMetaByID(ctx context.Context, arg UpdateTemplateMetaByIDParams) error {
@@ -12270,6 +12279,7 @@ func (q *sqlQuerier) UpdateTemplateMetaByID(ctx context.Context, arg UpdateTempl
arg.GroupACL,
arg.MaxPortSharingLevel,
arg.UseClassicParameterFlow,
arg.CorsBehavior,
)
return err
}
@@ -19911,7 +19921,7 @@ LEFT JOIN LATERAL (
) latest_build ON TRUE
LEFT JOIN LATERAL (
SELECT
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level, use_classic_parameter_flow
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level, use_classic_parameter_flow, cors_behavior
FROM
templates
WHERE
+5 -3
View File
@@ -99,10 +99,11 @@ INSERT INTO
display_name,
allow_user_cancel_workspace_jobs,
max_port_sharing_level,
use_classic_parameter_flow
use_classic_parameter_flow,
cors_behavior
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16);
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17);
-- name: UpdateTemplateActiveVersionByID :exec
UPDATE
@@ -134,7 +135,8 @@ SET
allow_user_cancel_workspace_jobs = $7,
group_acl = $8,
max_port_sharing_level = $9,
use_classic_parameter_flow = $10
use_classic_parameter_flow = $10,
cors_behavior = $11
WHERE
id = $1
;
+1
View File
@@ -150,6 +150,7 @@ sql:
has_ai_task: HasAITask
ai_task_sidebar_app_id: AITaskSidebarAppID
latest_build_has_ai_task: LatestBuildHasAITask
cors_behavior: CorsBehavior
rules:
- name: do-not-use-public-schema-in-queries
message: "do not use public schema in queries"
@@ -744,6 +744,7 @@ func insertTemplates(t *testing.T, db database.Store, u database.User, org datab
MaxPortSharingLevel: database.AppSharingLevelAuthenticated,
CreatedBy: u.ID,
OrganizationID: org.ID,
CorsBehavior: database.CorsBehaviorSimple,
}))
pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{})
@@ -763,6 +764,7 @@ func insertTemplates(t *testing.T, db database.Store, u database.User, org datab
MaxPortSharingLevel: database.AppSharingLevelAuthenticated,
CreatedBy: u.ID,
OrganizationID: org.ID,
CorsBehavior: database.CorsBehaviorSimple,
}))
require.NoError(t, db.InsertTemplateVersion(context.Background(), database.InsertTemplateVersionParams{
+35 -1
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"sort"
"strings"
"time"
"github.com/go-chi/chi/v5"
@@ -29,6 +30,7 @@ import (
"github.com/coder/coder/v2/coderd/searchquery"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/examples"
@@ -322,6 +324,7 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
autostopRequirementDaysOfWeekParsed uint8
autostartRequirementDaysOfWeekParsed uint8
maxPortShareLevel = database.AppSharingLevelOwner // default
corsBehavior = database.CorsBehaviorSimple // default
)
if defaultTTL < 0 {
validErrs = append(validErrs, codersdk.ValidationError{Field: "default_ttl_ms", Detail: "Must be a positive integer."})
@@ -351,6 +354,20 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
}
}
// Default the CORS behavior here to Simple so we don't break all existing templates.
val := database.CorsBehaviorSimple
if createTemplate.CORSBehavior != nil {
val = database.CorsBehavior(*createTemplate.CORSBehavior)
}
if !val.Valid() {
validErrs = append(validErrs, codersdk.ValidationError{
Field: "cors_behavior",
Detail: fmt.Sprintf("Invalid CORS behavior %q. Must be one of [%s]", *createTemplate.CORSBehavior, strings.Join(slice.ToStrings(database.AllCorsBehaviorValues()), ", ")),
})
} else {
corsBehavior = val
}
if autostopRequirementWeeks < 0 {
validErrs = append(validErrs, codersdk.ValidationError{Field: "autostop_requirement.weeks", Detail: "Must be a positive integer."})
}
@@ -409,6 +426,7 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
AllowUserCancelWorkspaceJobs: allowUserCancelWorkspaceJobs,
MaxPortSharingLevel: maxPortShareLevel,
UseClassicParameterFlow: useClassicParameterFlow,
CorsBehavior: corsBehavior,
})
if err != nil {
return xerrors.Errorf("insert template: %s", err)
@@ -725,6 +743,19 @@ func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
}
}
corsBehavior := template.CorsBehavior
if req.CORSBehavior != nil && *req.CORSBehavior != "" {
val := database.CorsBehavior(*req.CORSBehavior)
if !val.Valid() {
validErrs = append(validErrs, codersdk.ValidationError{
Field: "cors_behavior",
Detail: fmt.Sprintf("Invalid CORS behavior %q. Must be one of [%s]", *req.CORSBehavior, strings.Join(slice.ToStrings(database.AllCorsBehaviorValues()), ", ")),
})
} else {
corsBehavior = val
}
}
if len(validErrs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid request to update template metadata!",
@@ -759,7 +790,8 @@ func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
req.RequireActiveVersion == template.RequireActiveVersion &&
(deprecationMessage == template.Deprecated) &&
(classicTemplateFlow == template.UseClassicParameterFlow) &&
maxPortShareLevel == template.MaxPortSharingLevel {
maxPortShareLevel == template.MaxPortSharingLevel &&
corsBehavior == template.CorsBehavior {
return nil
}
@@ -801,6 +833,7 @@ func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
GroupACL: groupACL,
MaxPortSharingLevel: maxPortShareLevel,
UseClassicParameterFlow: classicTemplateFlow,
CorsBehavior: corsBehavior,
})
if err != nil {
return xerrors.Errorf("update template metadata: %w", err)
@@ -1084,6 +1117,7 @@ func (api *API) convertTemplate(
DeprecationMessage: templateAccessControl.Deprecated,
MaxPortShareLevel: maxPortShareLevel,
UseClassicParameterFlow: template.UseClassicParameterFlow,
CORSBehavior: codersdk.CORSBehavior(template.CorsBehavior),
}
}
+559 -5
View File
@@ -472,6 +472,409 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
})
})
t.Run("WorkspaceApplicationCORS", func(t *testing.T) {
t.Parallel()
const external = "https://example.com"
unauthenticatedClient := func(t *testing.T, appDetails *Details) *codersdk.Client {
c := appDetails.AppClient(t)
c.SetSessionToken("")
return c
}
authenticatedClient := func(t *testing.T, appDetails *Details) *codersdk.Client {
uc, _ := coderdtest.CreateAnotherUser(t, appDetails.SDKClient, appDetails.FirstUser.OrganizationID, rbac.RoleMember())
c := appDetails.AppClient(t)
c.SetSessionToken(uc.SessionToken())
return c
}
ownSubdomain := func(details *Details, app App) string {
url := details.SubdomainAppURL(app)
return url.Scheme + "://" + url.Host
}
externalOrigin := func(*Details, App) string {
return external
}
tests := []struct {
name string
app func(details *Details) App
client func(t *testing.T, appDetails *Details) *codersdk.Client
behavior codersdk.CORSBehavior
httpMethod string
origin func(details *Details, app App) string
expectedStatusCode int
checkRequestHeaders func(t *testing.T, origin string, req http.Header)
checkResponseHeaders func(t *testing.T, origin string, resp http.Header)
}{
// Public
{ // fails
// The default behavior is to accept preflight requests from the request origin if it matches the app's own subdomain.
name: "Default/Public/Preflight/Subdomain",
app: func(details *Details) App { return details.Apps.PublicCORSDefault },
behavior: codersdk.CORSBehaviorSimple,
client: unauthenticatedClient,
httpMethod: http.MethodOptions,
origin: ownSubdomain,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Equal(t, origin, resp.Get("Access-Control-Allow-Origin"))
assert.Contains(t, resp.Get("Access-Control-Allow-Methods"), http.MethodGet)
assert.Equal(t, "true", resp.Get("Access-Control-Allow-Credentials"))
},
},
{ // passes
// The default behavior is to reject preflight requests from origins other than the app's own subdomain.
name: "Default/Public/Preflight/External",
app: func(details *Details) App { return details.Apps.PublicCORSDefault },
behavior: codersdk.CORSBehaviorSimple,
client: unauthenticatedClient,
httpMethod: http.MethodOptions,
origin: externalOrigin,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
// We don't add a valid Allow-Origin header for requests we won't proxy.
assert.Empty(t, resp.Get("Access-Control-Allow-Origin"))
},
},
{ // fails
// A request without an Origin header would be rejected by an actual browser since it lacks CORS headers.
name: "Default/Public/GET/NoOrigin",
app: func(details *Details) App { return details.Apps.PublicCORSDefault },
behavior: codersdk.CORSBehaviorSimple,
client: unauthenticatedClient,
origin: func(*Details, App) string { return "" },
httpMethod: http.MethodGet,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Empty(t, resp.Get("Access-Control-Allow-Origin"))
assert.Empty(t, resp.Get("Access-Control-Allow-Headers"))
assert.Empty(t, resp.Get("Access-Control-Allow-Credentials"))
// Added by the app handler.
assert.Equal(t, "simple", resp.Get("X-CORS-Handler"))
},
},
{ // fails
// The passthru behavior will pass through the request headers to the upstream app.
name: "Passthru/Public/Preflight/Subdomain",
app: func(details *Details) App { return details.Apps.PublicCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: unauthenticatedClient,
origin: ownSubdomain,
httpMethod: http.MethodOptions,
expectedStatusCode: http.StatusOK,
checkRequestHeaders: func(t *testing.T, origin string, req http.Header) {
assert.Equal(t, origin, req.Get("Origin"))
assert.Equal(t, "GET", req.Get("Access-Control-Request-Method"))
},
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Equal(t, origin, resp.Get("Access-Control-Allow-Origin"))
assert.Equal(t, http.MethodGet, resp.Get("Access-Control-Allow-Methods"))
// Added by the app handler.
assert.Equal(t, "passthru", resp.Get("X-CORS-Handler"))
},
},
{ // fails
// Identical to the previous test, but the origin is different.
name: "Passthru/Public/PreflightOther",
app: func(details *Details) App { return details.Apps.PublicCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: unauthenticatedClient,
origin: externalOrigin,
httpMethod: http.MethodOptions,
expectedStatusCode: http.StatusOK,
checkRequestHeaders: func(t *testing.T, origin string, req http.Header) {
assert.Equal(t, origin, req.Get("Origin"))
assert.Equal(t, "GET", req.Get("Access-Control-Request-Method"))
assert.Equal(t, "X-Got-Host", req.Get("Access-Control-Request-Headers"))
},
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Equal(t, origin, resp.Get("Access-Control-Allow-Origin"))
assert.Equal(t, http.MethodGet, resp.Get("Access-Control-Allow-Methods"))
// Added by the app handler.
assert.Equal(t, "passthru", resp.Get("X-CORS-Handler"))
},
},
{
// A request without an Origin header would be rejected by an actual browser since it lacks CORS headers.
name: "Passthru/Public/GET/NoOrigin",
app: func(details *Details) App { return details.Apps.PublicCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: unauthenticatedClient,
origin: func(*Details, App) string { return "" },
httpMethod: http.MethodGet,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Empty(t, resp.Get("Access-Control-Allow-Origin"))
assert.Empty(t, resp.Get("Access-Control-Allow-Headers"))
assert.Empty(t, resp.Get("Access-Control-Allow-Credentials"))
// Added by the app handler.
assert.Equal(t, "passthru", resp.Get("X-CORS-Handler"))
},
},
// Authenticated
{
// Same behavior as Default/Public/Preflight/Subdomain.
name: "Default/Authenticated/Preflight/Subdomain",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSDefault },
behavior: codersdk.CORSBehaviorSimple,
client: authenticatedClient,
origin: ownSubdomain,
httpMethod: http.MethodOptions,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Equal(t, origin, resp.Get("Access-Control-Allow-Origin"))
assert.Contains(t, resp.Get("Access-Control-Allow-Methods"), http.MethodGet)
assert.Equal(t, "true", resp.Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "X-Got-Host", resp.Get("Access-Control-Allow-Headers"))
},
},
{
// Same behavior as Default/Public/Preflight/External.
name: "Default/Authenticated/Preflight/External",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSDefault },
behavior: codersdk.CORSBehaviorSimple,
client: authenticatedClient,
origin: externalOrigin,
httpMethod: http.MethodOptions,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Empty(t, resp.Get("Access-Control-Allow-Origin"))
},
},
{
// An authenticated request to the app is allowed from its own subdomain.
name: "Default/Authenticated/GET/Subdomain",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSDefault },
behavior: codersdk.CORSBehaviorSimple,
client: authenticatedClient,
origin: ownSubdomain,
httpMethod: http.MethodGet,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Equal(t, origin, resp.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", resp.Get("Access-Control-Allow-Credentials"))
// Added by the app handler.
assert.Equal(t, "simple", resp.Get("X-CORS-Handler"))
},
},
{
// An authenticated request to the app is allowed from an external origin.
// The origin doesn't match the app's own subdomain, so the CORS headers are not added.
name: "Default/Authenticated/GET/External",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSDefault },
behavior: codersdk.CORSBehaviorSimple,
client: authenticatedClient,
origin: externalOrigin,
httpMethod: http.MethodGet,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Empty(t, resp.Get("Access-Control-Allow-Origin"))
assert.Empty(t, resp.Get("Access-Control-Allow-Headers"))
assert.Empty(t, resp.Get("Access-Control-Allow-Credentials"))
// Added by the app handler.
assert.Equal(t, "simple", resp.Get("X-CORS-Handler"))
},
},
{
// The request is rejected because the client is unauthenticated.
name: "Passthru/Unauthenticated/Preflight/Subdomain",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: unauthenticatedClient,
origin: ownSubdomain,
httpMethod: http.MethodOptions,
expectedStatusCode: http.StatusSeeOther,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.NotEmpty(t, resp.Get("Location"))
},
},
{
// Same behavior as the above test, but the origin is different.
name: "Passthru/Unauthenticated/Preflight/External",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: unauthenticatedClient,
origin: externalOrigin,
httpMethod: http.MethodOptions,
expectedStatusCode: http.StatusSeeOther,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.NotEmpty(t, resp.Get("Location"))
},
},
{
// The request is rejected because the client is unauthenticated.
name: "Passthru/Unauthenticated/GET/Subdomain",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: unauthenticatedClient,
origin: ownSubdomain,
httpMethod: http.MethodGet,
expectedStatusCode: http.StatusSeeOther,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.NotEmpty(t, resp.Get("Location"))
},
},
{
// Same behavior as the above test, but the origin is different.
name: "Passthru/Unauthenticated/GET/External",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: unauthenticatedClient,
origin: externalOrigin,
httpMethod: http.MethodGet,
expectedStatusCode: http.StatusSeeOther,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.NotEmpty(t, resp.Get("Location"))
},
},
{
// The request is allowed because the client is authenticated.
name: "Passthru/Authenticated/Preflight/Subdomain",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: authenticatedClient,
origin: ownSubdomain,
httpMethod: http.MethodOptions,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Equal(t, origin, resp.Get("Access-Control-Allow-Origin"))
assert.Equal(t, http.MethodGet, resp.Get("Access-Control-Allow-Methods"))
// Added by the app handler.
assert.Equal(t, "passthru", resp.Get("X-CORS-Handler"))
},
},
{
// Same behavior as the above test, but the origin is different.
name: "Passthru/Authenticated/Preflight/External",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: authenticatedClient,
origin: externalOrigin,
httpMethod: http.MethodOptions,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Equal(t, origin, resp.Get("Access-Control-Allow-Origin"))
assert.Equal(t, http.MethodGet, resp.Get("Access-Control-Allow-Methods"))
// Added by the app handler.
assert.Equal(t, "passthru", resp.Get("X-CORS-Handler"))
},
},
{
// The request is allowed because the client is authenticated.
name: "Passthru/Authenticated/GET/Subdomain",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: authenticatedClient,
origin: ownSubdomain,
httpMethod: http.MethodGet,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Equal(t, origin, resp.Get("Access-Control-Allow-Origin"))
assert.Equal(t, http.MethodGet, resp.Get("Access-Control-Allow-Methods"))
// Added by the app handler.
assert.Equal(t, "passthru", resp.Get("X-CORS-Handler"))
},
},
{
// Same behavior as the above test, but the origin is different.
name: "Passthru/Authenticated/GET/External",
app: func(details *Details) App { return details.Apps.AuthenticatedCORSPassthru },
behavior: codersdk.CORSBehaviorPassthru,
client: authenticatedClient,
origin: externalOrigin,
httpMethod: http.MethodGet,
expectedStatusCode: http.StatusOK,
checkResponseHeaders: func(t *testing.T, origin string, resp http.Header) {
assert.Equal(t, origin, resp.Get("Access-Control-Allow-Origin"))
assert.Equal(t, http.MethodGet, resp.Get("Access-Control-Allow-Methods"))
// Added by the app handler.
assert.Equal(t, "passthru", resp.Get("X-CORS-Handler"))
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
var reqHeaders http.Header
// Setup an HTTP handler which is the "app"; this handler conditionally responds
// to requests based on the CORS behavior
appDetails := setupProxyTest(t, &DeploymentOptions{
handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := r.Cookie(codersdk.SessionTokenCookie)
assert.ErrorIs(t, err, http.ErrNoCookie)
// Store the request headers for later assertions
reqHeaders = r.Header
switch tc.behavior {
case codersdk.CORSBehaviorPassthru:
w.Header().Set("X-CORS-Handler", "passthru")
// Only allow GET and OPTIONS requests
if r.Method != http.MethodGet && r.Method != http.MethodOptions {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
// If the Origin header is present, add the CORS headers.
if origin := r.Header.Get("Origin"); origin != "" {
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", http.MethodGet)
}
w.WriteHeader(http.StatusOK)
case codersdk.CORSBehaviorSimple:
w.Header().Set("X-CORS-Handler", "simple")
}
}),
})
// Update the template CORS behavior.
b := tc.behavior
template, err := appDetails.SDKClient.UpdateTemplateMeta(ctx, appDetails.Workspace.TemplateID, codersdk.UpdateTemplateMeta{
CORSBehavior: &b,
})
require.NoError(t, err)
require.Equal(t, tc.behavior, template.CORSBehavior)
// Given: a client and a workspace app
client := tc.client(t, appDetails)
path := appDetails.SubdomainAppURL(tc.app(appDetails)).String()
origin := tc.origin(appDetails, tc.app(appDetails))
fmt.Println("method: ", tc.httpMethod)
// When: a preflight request is made to an app with a specified CORS behavior
resp, err := requestWithRetries(ctx, t, client, tc.httpMethod, path, nil, func(r *http.Request) {
// Mimic non-browser clients that don't send the Origin header.
if origin != "" {
r.Header.Set("Origin", origin)
}
r.Header.Set("Access-Control-Request-Method", "GET")
r.Header.Set("Access-Control-Request-Headers", "X-Got-Host")
})
require.NoError(t, err)
defer resp.Body.Close()
// Then: the request & response must match expectations
assert.Equal(t, tc.expectedStatusCode, resp.StatusCode)
assert.NoError(t, err)
if tc.checkRequestHeaders != nil {
tc.checkRequestHeaders(t, origin, reqHeaders)
}
tc.checkResponseHeaders(t, origin, resp.Header)
})
}
})
t.Run("WorkspaceApplicationAuth", func(t *testing.T) {
t.Parallel()
@@ -1340,6 +1743,153 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
})
})
t.Run("CORS", func(t *testing.T) {
t.Parallel()
// Set up test headers that should be returned by the app
testHeaders := http.Header{
"Access-Control-Allow-Origin": []string{"*"},
"Access-Control-Allow-Methods": []string{"GET, POST, OPTIONS"},
}
unauthenticatedClient := func(t *testing.T, appDetails *Details) *codersdk.Client {
c := appDetails.AppClient(t)
c.SetSessionToken("")
return c
}
authenticatedClient := func(t *testing.T, appDetails *Details) *codersdk.Client {
uc, _ := coderdtest.CreateAnotherUser(t, appDetails.SDKClient, appDetails.FirstUser.OrganizationID, rbac.RoleMember())
c := appDetails.AppClient(t)
c.SetSessionToken(uc.SessionToken())
return c
}
ownerClient := func(t *testing.T, appDetails *Details) *codersdk.Client {
c := appDetails.AppClient(t) // <-- Use same server as others
c.SetSessionToken(appDetails.SDKClient.SessionToken()) // But with owner auth
return c
}
tests := []struct {
name string
shareLevel codersdk.WorkspaceAgentPortShareLevel
behavior codersdk.CORSBehavior
client func(t *testing.T, appDetails *Details) *codersdk.Client
expectedStatusCode int
expectedCORSHeaders bool
}{
// Public
{
name: "Default/Public",
shareLevel: codersdk.WorkspaceAgentPortShareLevelPublic,
behavior: codersdk.CORSBehaviorSimple,
expectedCORSHeaders: false,
client: unauthenticatedClient,
expectedStatusCode: http.StatusOK,
},
{ // fails
name: "Passthru/Public",
shareLevel: codersdk.WorkspaceAgentPortShareLevelPublic,
behavior: codersdk.CORSBehaviorPassthru,
expectedCORSHeaders: true,
client: unauthenticatedClient,
expectedStatusCode: http.StatusOK,
},
// Authenticated
{
name: "Default/Authenticated",
shareLevel: codersdk.WorkspaceAgentPortShareLevelAuthenticated,
behavior: codersdk.CORSBehaviorSimple,
expectedCORSHeaders: false,
client: authenticatedClient,
expectedStatusCode: http.StatusOK,
},
{
name: "Passthru/Authenticated",
shareLevel: codersdk.WorkspaceAgentPortShareLevelAuthenticated,
behavior: codersdk.CORSBehaviorPassthru,
expectedCORSHeaders: true,
client: authenticatedClient,
expectedStatusCode: http.StatusOK,
},
{
// The CORS behavior will not affect unauthenticated requests.
// The request will be redirected to the login page.
name: "Passthru/Unauthenticated",
shareLevel: codersdk.WorkspaceAgentPortShareLevelAuthenticated,
behavior: codersdk.CORSBehaviorPassthru,
expectedCORSHeaders: false,
client: unauthenticatedClient,
expectedStatusCode: http.StatusSeeOther,
},
// Owner
{
name: "Default/Owner",
shareLevel: codersdk.WorkspaceAgentPortShareLevelAuthenticated, // Owner is not a valid share level for ports.
behavior: codersdk.CORSBehaviorSimple,
expectedCORSHeaders: false,
client: ownerClient,
expectedStatusCode: http.StatusOK,
},
{ // fails
name: "Passthru/Owner",
shareLevel: codersdk.WorkspaceAgentPortShareLevelAuthenticated, // Owner is not a valid share level for ports.
behavior: codersdk.CORSBehaviorPassthru,
expectedCORSHeaders: true,
client: ownerClient,
expectedStatusCode: http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
appDetails := setupProxyTest(t, &DeploymentOptions{
headers: testHeaders,
})
port, err := strconv.ParseInt(appDetails.Apps.Port.AppSlugOrPort, 10, 32)
require.NoError(t, err)
// Update the template CORS behavior.
b := tc.behavior
template, err := appDetails.SDKClient.UpdateTemplateMeta(ctx, appDetails.Workspace.TemplateID, codersdk.UpdateTemplateMeta{
CORSBehavior: &b,
})
require.NoError(t, err)
require.Equal(t, tc.behavior, template.CORSBehavior)
// Set the port we have to be shared.
_, err = appDetails.SDKClient.UpsertWorkspaceAgentPortShare(ctx, appDetails.Workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{
AgentName: proxyTestAgentName,
Port: int32(port),
ShareLevel: tc.shareLevel,
Protocol: codersdk.WorkspaceAgentPortShareProtocolHTTP,
})
require.NoError(t, err)
client := tc.client(t, appDetails)
resp, err := requestWithRetries(ctx, t, client, http.MethodGet, appDetails.SubdomainAppURL(appDetails.Apps.Port).String(), nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, tc.expectedStatusCode, resp.StatusCode)
if tc.expectedCORSHeaders {
require.Equal(t, testHeaders.Get("Access-Control-Allow-Origin"), resp.Header.Get("Access-Control-Allow-Origin"), "allow origin did not match")
require.Equal(t, testHeaders.Get("Access-Control-Allow-Methods"), resp.Header.Get("Access-Control-Allow-Methods"), "allow methods did not match")
} else {
require.Empty(t, resp.Header.Get("Access-Control-Allow-Origin"))
require.Empty(t, resp.Header.Get("Access-Control-Allow-Methods"))
}
})
}
})
t.Run("AppSharing", func(t *testing.T) {
t.Parallel()
@@ -1386,7 +1936,7 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
forceURLTransport(t, client)
// Create workspace.
port := appServer(t, nil, false)
port := appServer(t, nil, false, nil)
workspace, _ = createWorkspaceWithApps(t, client, user.OrganizationIDs[0], user, port, false)
// Verify that the apps have the correct sharing levels set.
@@ -1397,10 +1947,14 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
agnt = workspaceBuild.Resources[0].Agents[0]
found := map[string]codersdk.WorkspaceAppSharingLevel{}
expected := map[string]codersdk.WorkspaceAppSharingLevel{
proxyTestAppNameFake: codersdk.WorkspaceAppSharingLevelOwner,
proxyTestAppNameOwner: codersdk.WorkspaceAppSharingLevelOwner,
proxyTestAppNameAuthenticated: codersdk.WorkspaceAppSharingLevelAuthenticated,
proxyTestAppNamePublic: codersdk.WorkspaceAppSharingLevelPublic,
proxyTestAppNameFake: codersdk.WorkspaceAppSharingLevelOwner,
proxyTestAppNameOwner: codersdk.WorkspaceAppSharingLevelOwner,
proxyTestAppNameAuthenticated: codersdk.WorkspaceAppSharingLevelAuthenticated,
proxyTestAppNamePublic: codersdk.WorkspaceAppSharingLevelPublic,
proxyTestAppNameAuthenticatedCORSPassthru: codersdk.WorkspaceAppSharingLevelAuthenticated,
proxyTestAppNamePublicCORSPassthru: codersdk.WorkspaceAppSharingLevelPublic,
proxyTestAppNameAuthenticatedCORSDefault: codersdk.WorkspaceAppSharingLevelAuthenticated,
proxyTestAppNamePublicCORSDefault: codersdk.WorkspaceAppSharingLevelPublic,
}
for _, app := range agnt.Apps {
found[app.DisplayName] = app.SharingLevel
+102 -25
View File
@@ -36,8 +36,13 @@ const (
proxyTestAppNameOwner = "test-app-owner"
proxyTestAppNameAuthenticated = "test-app-authenticated"
proxyTestAppNamePublic = "test-app-public"
proxyTestAppQuery = "query=true"
proxyTestAppBody = "hello world from apps test"
// nolint:gosec // Not a secret
proxyTestAppNameAuthenticatedCORSPassthru = "test-app-authenticated-cors-passthru"
proxyTestAppNamePublicCORSPassthru = "test-app-public-cors-passthru"
proxyTestAppNameAuthenticatedCORSDefault = "test-app-authenticated-cors-default"
proxyTestAppNamePublicCORSDefault = "test-app-public-cors-default"
proxyTestAppQuery = "query=true"
proxyTestAppBody = "hello world from apps test"
proxyTestSubdomainRaw = "*.test.coder.com"
proxyTestSubdomain = "test.coder.com"
@@ -60,6 +65,7 @@ type DeploymentOptions struct {
noWorkspace bool
port uint16
headers http.Header
handler http.Handler
}
// Deployment is a license-agnostic deployment with all the fields that apps
@@ -93,6 +99,9 @@ type App struct {
// Prefix should have ---.
Prefix string
Query string
// Control the behavior of CORS handling.
CORSBehavior codersdk.CORSBehavior
}
// Details are the full test details returned from setupProxyTestWithFactory.
@@ -109,12 +118,16 @@ type Details struct {
AppPort uint16
Apps struct {
Fake App
Owner App
Authenticated App
Public App
Port App
PortHTTPS App
Fake App
Owner App
Authenticated App
Public App
Port App
PortHTTPS App
PublicCORSPassthru App
AuthenticatedCORSPassthru App
PublicCORSDefault App
AuthenticatedCORSDefault App
}
}
@@ -201,7 +214,7 @@ func setupProxyTestWithFactory(t *testing.T, factory DeploymentFactory, opts *De
}
if opts.port == 0 {
opts.port = appServer(t, opts.headers, opts.ServeHTTPS)
opts.port = appServer(t, opts.headers, opts.ServeHTTPS, opts.handler)
}
workspace, agnt := createWorkspaceWithApps(t, deployment.SDKClient, deployment.FirstUser.OrganizationID, me, opts.port, opts.ServeHTTPS)
@@ -252,30 +265,64 @@ func setupProxyTestWithFactory(t *testing.T, factory DeploymentFactory, opts *De
AgentName: agnt.Name,
AppSlugOrPort: strconv.Itoa(int(opts.port)) + "s",
}
details.Apps.PublicCORSPassthru = App{
Username: me.Username,
WorkspaceName: workspace.Name,
AgentName: agnt.Name,
AppSlugOrPort: proxyTestAppNamePublicCORSPassthru,
CORSBehavior: codersdk.CORSBehaviorPassthru,
Query: proxyTestAppQuery,
}
details.Apps.AuthenticatedCORSPassthru = App{
Username: me.Username,
WorkspaceName: workspace.Name,
AgentName: agnt.Name,
AppSlugOrPort: proxyTestAppNameAuthenticatedCORSPassthru,
CORSBehavior: codersdk.CORSBehaviorPassthru,
Query: proxyTestAppQuery,
}
details.Apps.PublicCORSDefault = App{
Username: me.Username,
WorkspaceName: workspace.Name,
AgentName: agnt.Name,
AppSlugOrPort: proxyTestAppNamePublicCORSDefault,
Query: proxyTestAppQuery,
}
details.Apps.AuthenticatedCORSDefault = App{
Username: me.Username,
WorkspaceName: workspace.Name,
AgentName: agnt.Name,
AppSlugOrPort: proxyTestAppNameAuthenticatedCORSDefault,
Query: proxyTestAppQuery,
}
return details
}
//nolint:revive
func appServer(t *testing.T, headers http.Header, isHTTPS bool) uint16 {
server := httptest.NewUnstartedServer(
http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
_, err := r.Cookie(codersdk.SessionTokenCookie)
assert.ErrorIs(t, err, http.ErrNoCookie)
w.Header().Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For"))
w.Header().Set("X-Got-Host", r.Host)
for name, values := range headers {
for _, value := range values {
w.Header().Add(name, value)
}
func appServer(t *testing.T, headers http.Header, isHTTPS bool, handler http.Handler) uint16 {
defaultHandler := http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
_, err := r.Cookie(codersdk.SessionTokenCookie)
assert.ErrorIs(t, err, http.ErrNoCookie)
w.Header().Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For"))
w.Header().Set("X-Got-Host", r.Host)
for name, values := range headers {
for _, value := range values {
w.Header().Add(name, value)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(proxyTestAppBody))
},
),
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(proxyTestAppBody))
},
)
if handler == nil {
handler = defaultHandler
}
server := httptest.NewUnstartedServer(handler)
server.Config.ReadHeaderTimeout = time.Minute
if isHTTPS {
server.StartTLS()
@@ -361,6 +408,36 @@ func createWorkspaceWithApps(t *testing.T, client *codersdk.Client, orgID uuid.U
Url: appURL,
Subdomain: true,
},
{
Slug: proxyTestAppNamePublicCORSPassthru,
DisplayName: proxyTestAppNamePublicCORSPassthru,
SharingLevel: proto.AppSharingLevel_PUBLIC,
Url: appURL,
Subdomain: true,
// CorsBehavior: proto.AppCORSBehavior_PASSTHRU,
},
{
Slug: proxyTestAppNameAuthenticatedCORSPassthru,
DisplayName: proxyTestAppNameAuthenticatedCORSPassthru,
SharingLevel: proto.AppSharingLevel_AUTHENTICATED,
Url: appURL,
Subdomain: true,
// CorsBehavior: proto.AppCORSBehavior_PASSTHRU,
},
{
Slug: proxyTestAppNamePublicCORSDefault,
DisplayName: proxyTestAppNamePublicCORSDefault,
SharingLevel: proto.AppSharingLevel_PUBLIC,
Url: appURL,
Subdomain: true,
},
{
Slug: proxyTestAppNameAuthenticatedCORSDefault,
DisplayName: proxyTestAppNameAuthenticatedCORSDefault,
SharingLevel: proto.AppSharingLevel_AUTHENTICATED,
Url: appURL,
Subdomain: true,
},
}
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
Parse: echo.ParseComplete,
+21
View File
@@ -0,0 +1,21 @@
package cors
import (
"context"
"github.com/coder/coder/v2/codersdk"
)
type contextKeyBehavior struct{}
// WithBehavior sets the CORS behavior for the given context.
func WithBehavior(ctx context.Context, behavior codersdk.CORSBehavior) context.Context {
return context.WithValue(ctx, contextKeyBehavior{}, behavior)
}
// HasBehavior returns true if the given context has the specified CORS behavior.
func HasBehavior(ctx context.Context, behavior codersdk.CORSBehavior) bool {
val := ctx.Value(contextKeyBehavior{})
b, ok := val.(codersdk.CORSBehavior)
return ok && b == behavior
}
+1
View File
@@ -151,6 +151,7 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *
if dbReq.AppURL != nil {
token.AppURL = dbReq.AppURL.String()
}
token.CORSBehavior = codersdk.CORSBehavior(dbReq.CorsBehavior)
// Verify the user has access to the app.
authed, warnings, err := p.authorizeRequest(r.Context(), authz, dbReq)
+6 -5
View File
@@ -301,11 +301,12 @@ func Test_ResolveRequest(t *testing.T) {
RegisteredClaims: jwtutils.RegisteredClaims{
Expiry: jwt.NewNumericDate(token.Expiry.Time()),
},
Request: req,
UserID: me.ID,
WorkspaceID: workspace.ID,
AgentID: agentID,
AppURL: appURL,
Request: req,
UserID: me.ID,
WorkspaceID: workspace.ID,
AgentID: agentID,
AppURL: appURL,
CORSBehavior: codersdk.CORSBehaviorSimple,
}, token)
require.NotZero(t, token.Expiry)
require.WithinDuration(t, time.Now().Add(workspaceapps.DefaultTokenExpiry), token.Expiry.Time(), time.Minute)
+65 -29
View File
@@ -28,6 +28,7 @@ import (
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/coderd/workspaceapps/cors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/site"
@@ -323,6 +324,37 @@ func (s *Server) workspaceAppsProxyPath(rw http.ResponseWriter, r *http.Request)
s.proxyWorkspaceApp(rw, r, *token, chiPath, appurl.ApplicationURL{})
}
// determineCORSBehavior examines the given token and conditionally applies
// CORS middleware if the token specifies that behavior.
func (s *Server) determineCORSBehavior(token *SignedToken, app appurl.ApplicationURL) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
// Create the CORS middleware handler upfront.
corsHandler := httpmw.WorkspaceAppCors(s.HostnameRegex, app)(next)
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
var behavior codersdk.CORSBehavior
if token != nil {
behavior = token.CORSBehavior
}
// Add behavior to context regardless of which handler we use,
// since we will use this later on to determine if we should strip
// CORS headers in the response.
r = r.WithContext(cors.WithBehavior(r.Context(), behavior))
switch behavior {
case codersdk.CORSBehaviorPassthru:
// Bypass the CORS middleware.
next.ServeHTTP(rw, r)
return
default:
// Apply the CORS middleware.
corsHandler.ServeHTTP(rw, r)
}
})
}
}
// HandleSubdomain handles subdomain-based application proxy requests (aka.
// DevURLs in Coder V1).
//
@@ -394,36 +426,36 @@ func (s *Server) HandleSubdomain(middlewares ...func(http.Handler) http.Handler)
return
}
// Use the passed in app middlewares before checking authentication and
// passing to the proxy app.
mws := chi.Middlewares(append(middlewares, httpmw.WorkspaceAppCors(s.HostnameRegex, app)))
mws.Handler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if !s.handleAPIKeySmuggling(rw, r, AccessMethodSubdomain) {
return
}
if !s.handleAPIKeySmuggling(rw, r, AccessMethodSubdomain) {
return
}
token, ok := ResolveRequest(rw, r, ResolveRequestOptions{
Logger: s.Logger,
CookieCfg: s.Cookies,
SignedTokenProvider: s.SignedTokenProvider,
DashboardURL: s.DashboardURL,
PathAppBaseURL: s.AccessURL,
AppHostname: s.Hostname,
AppRequest: Request{
AccessMethod: AccessMethodSubdomain,
BasePath: "/",
Prefix: app.Prefix,
UsernameOrID: app.Username,
WorkspaceNameOrID: app.WorkspaceName,
AgentNameOrID: app.AgentName,
AppSlugOrPort: app.AppSlugOrPort,
},
AppPath: r.URL.Path,
AppQuery: r.URL.RawQuery,
})
if !ok {
return
}
// Generate a signed token for the request.
token, ok := ResolveRequest(rw, r, ResolveRequestOptions{
Logger: s.Logger,
SignedTokenProvider: s.SignedTokenProvider,
DashboardURL: s.DashboardURL,
PathAppBaseURL: s.AccessURL,
AppHostname: s.Hostname,
AppRequest: Request{
AccessMethod: AccessMethodSubdomain,
BasePath: "/",
Prefix: app.Prefix,
UsernameOrID: app.Username,
WorkspaceNameOrID: app.WorkspaceName,
AgentNameOrID: app.AgentName,
AppSlugOrPort: app.AppSlugOrPort,
},
AppPath: r.URL.Path,
AppQuery: r.URL.RawQuery,
})
if !ok {
return
}
// Proxy the request (possibly with the CORS middleware).
mws := chi.Middlewares(append(middlewares, s.determineCORSBehavior(token, app)))
mws.Handler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
s.proxyWorkspaceApp(rw, r, *token, r.URL.Path, app)
})).ServeHTTP(rw, r.WithContext(ctx))
})
@@ -560,6 +592,10 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT
proxy := s.AgentProvider.ReverseProxy(appURL, s.DashboardURL, appToken.AgentID, app, s.Hostname)
proxy.ModifyResponse = func(r *http.Response) error {
// If passthru behavior is set, disable our CORS header stripping.
if cors.HasBehavior(r.Request.Context(), codersdk.CORSBehaviorPassthru) {
return nil
}
r.Header.Del(httpmw.AccessControlAllowOriginHeader)
r.Header.Del(httpmw.AccessControlAllowCredentialsHeader)
r.Header.Del(httpmw.AccessControlAllowMethodsHeader)
+11
View File
@@ -204,6 +204,9 @@ type databaseRequest struct {
// AppSharingLevel is the sharing level of the app. This is forced to be set
// to AppSharingLevelOwner if the access method is terminal.
AppSharingLevel database.AppSharingLevel
// CorsBehavior is set at the template level for all apps/ports in a workspace, and can
// either be the current CORS middleware 'simple' or bypass the cors middleware with 'passthru'.
CorsBehavior database.CorsBehavior
}
// getDatabase does queries to get the owner user, workspace and agent
@@ -296,7 +299,14 @@ func (r Request) getDatabase(ctx context.Context, db database.Store) (*databaseR
// First check if it's a port-based URL with an optional "s" suffix for HTTPS.
potentialPortStr = strings.TrimSuffix(r.AppSlugOrPort, "s")
portUint, portUintErr = strconv.ParseUint(potentialPortStr, 10, 16)
corsBehavior database.CorsBehavior
)
tmpl, err := db.GetTemplateByID(ctx, workspace.TemplateID)
if err != nil {
return nil, xerrors.Errorf("get template %q: %w", workspace.TemplateID, err)
}
corsBehavior = tmpl.CorsBehavior
//nolint:nestif
if portUintErr == nil {
protocol := "http"
@@ -417,6 +427,7 @@ func (r Request) getDatabase(ctx context.Context, db database.Store) (*databaseR
App: app,
AppURL: appURLParsed,
AppSharingLevel: appSharingLevel,
CorsBehavior: corsBehavior,
}, nil
}
+5 -4
View File
@@ -22,10 +22,11 @@ type SignedToken struct {
// Request details.
Request `json:"request"`
UserID uuid.UUID `json:"user_id"`
WorkspaceID uuid.UUID `json:"workspace_id"`
AgentID uuid.UUID `json:"agent_id"`
AppURL string `json:"app_url"`
UserID uuid.UUID `json:"user_id"`
WorkspaceID uuid.UUID `json:"workspace_id"`
AgentID uuid.UUID `json:"agent_id"`
AppURL string `json:"app_url"`
CORSBehavior codersdk.CORSBehavior `json:"cors_behavior"`
}
// MatchesRequest returns true if the token matches the request. Any token that