Files
coder/coderd/httpmw/httpmw_internal_test.go
T
Danielle Maywood 92a6d6c2c0 chore: remove unnecessary loop variable captures (#22180)
Since Go 1.22, the loop variable capture issue is resolved. Variables
declared by for loops are now per-iteration rather than per-loop, making
the 'v := v' pattern unnecessary.
2026-02-19 09:02:19 +00:00

262 lines
5.6 KiB
Go

package httpmw
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk"
)
const (
testParam = "workspaceagent"
testWorkspaceAgentID = "8a70c576-12dc-42bc-b791-112a32b5bd43"
)
func TestParseUUID_Valid(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/{workspaceagent}", nil)
ctx := chi.NewRouteContext()
ctx.URLParams.Add(testParam, testWorkspaceAgentID)
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
parsed, ok := ParseUUIDParam(rw, r, "workspaceagent")
assert.True(t, ok, "UUID should be parsed")
assert.Equal(t, testWorkspaceAgentID, parsed.String())
}
func TestParseUUID_Invalid(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/{workspaceagent}", nil)
ctx := chi.NewRouteContext()
ctx.URLParams.Add(testParam, "wrong-id")
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
_, ok := ParseUUIDParam(rw, r, "workspaceagent")
assert.False(t, ok, "UUID should not be parsed")
assert.Equal(t, http.StatusBadRequest, rw.Code)
var response codersdk.Response
err := json.Unmarshal(rw.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response.Message, `Invalid UUID "wrong-id"`)
}
// TestNormalizeAudienceURI tests URI normalization for OAuth2 audience validation
func TestNormalizeAudienceURI(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "EmptyString",
input: "",
expected: "",
},
{
name: "SimpleHTTPWithoutTrailingSlash",
input: "http://example.com",
expected: "http://example.com/",
},
{
name: "SimpleHTTPWithTrailingSlash",
input: "http://example.com/",
expected: "http://example.com/",
},
{
name: "HTTPSWithPath",
input: "https://api.example.com/v1/",
expected: "https://api.example.com/v1",
},
{
name: "CaseNormalization",
input: "HTTPS://API.EXAMPLE.COM/V1/",
expected: "https://api.example.com/V1",
},
{
name: "DefaultHTTPPort",
input: "http://example.com:80/api/",
expected: "http://example.com/api",
},
{
name: "DefaultHTTPSPort",
input: "https://example.com:443/api/",
expected: "https://example.com/api",
},
{
name: "NonDefaultPort",
input: "http://example.com:8080/api/",
expected: "http://example.com:8080/api",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := normalizeAudienceURI(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// TestNormalizeHost tests host normalization including IDN support
func TestNormalizeHost(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "EmptyString",
input: "",
expected: "",
},
{
name: "SimpleHost",
input: "example.com",
expected: "example.com",
},
{
name: "HostWithPort",
input: "example.com:8080",
expected: "example.com:8080",
},
{
name: "CaseNormalization",
input: "EXAMPLE.COM",
expected: "example.com",
},
{
name: "IPv4Address",
input: "192.168.1.1",
expected: "192.168.1.1",
},
{
name: "IPv6Address",
input: "[::1]:8080",
expected: "[::1]:8080",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := normalizeHost(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// TestNormalizePathSegments tests path normalization including dot-segment removal
func TestNormalizePathSegments(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "EmptyString",
input: "",
expected: "/",
},
{
name: "SimplePath",
input: "/api/v1",
expected: "/api/v1",
},
{
name: "PathWithDotSegments",
input: "/api/../v1/./test",
expected: "/v1/test",
},
{
name: "TrailingSlash",
input: "/api/v1/",
expected: "/api/v1",
},
{
name: "MultipleSlashes",
input: "/api//v1///test",
expected: "/api//v1///test",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := normalizePathSegments(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// TestExtractExpectedAudience tests audience extraction from HTTP requests
func TestExtractExpectedAudience(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
scheme string
host string
path string
expected string
}{
{
name: "SimpleHTTP",
scheme: "http",
host: "example.com",
path: "/api/test",
expected: "http://example.com/",
},
{
name: "HTTPS",
scheme: "https",
host: "api.example.com",
path: "/v1/users",
expected: "https://api.example.com/",
},
{
name: "WithPort",
scheme: "http",
host: "localhost:8080",
path: "/api",
expected: "http://localhost:8080/",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var req *http.Request
if tc.scheme == "https" {
req = httptest.NewRequest("GET", "https://"+tc.host+tc.path, nil)
} else {
req = httptest.NewRequest("GET", "http://"+tc.host+tc.path, nil)
}
req.Host = tc.host
result := extractExpectedAudience(nil, req)
assert.Equal(t, tc.expected, result)
})
}
}