diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index bd79455415..21104167ad 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -1143,7 +1143,7 @@ func MustWorkspace(t testing.TB, client *codersdk.Client, workspaceID uuid.UUID) // RequestExternalAuthCallback makes a request with the proper OAuth2 state cookie // to the external auth callback endpoint. -func RequestExternalAuthCallback(t testing.TB, providerID string, client *codersdk.Client) *http.Response { +func RequestExternalAuthCallback(t testing.TB, providerID string, client *codersdk.Client, opts ...func(*http.Request)) *http.Response { client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } @@ -1160,6 +1160,9 @@ func RequestExternalAuthCallback(t testing.TB, providerID string, client *coders Name: codersdk.SessionTokenCookie, Value: client.SessionToken(), }) + for _, opt := range opts { + opt(req) + } res, err := client.HTTPClient.Do(req) require.NoError(t, err) t.Cleanup(func() { diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 09e4c61b68..5cc235fbda 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -479,7 +479,6 @@ func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenCla // This is a niche case, but it is needed for testing ConvertLoginType. func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { t.Helper() - path := "/api/v2/users/oidc/callback" if f.callbackPath != "" { path = f.callbackPath @@ -489,13 +488,23 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken f.SetRedirect(t, coderOauthURL.String()) cli := f.HTTPClient(client.HTTPClient) - cli.CheckRedirect = func(req *http.Request, via []*http.Request) error { + redirectFn := cli.CheckRedirect + checkRedirect := func(req *http.Request, via []*http.Request) error { // Store the idTokenClaims to the specific state request. This ties // the claims 1:1 with a given authentication flow. - state := req.URL.Query().Get("state") - f.stateToIDTokenClaims.Store(state, idTokenClaims) + if state := req.URL.Query().Get("state"); state != "" { + f.stateToIDTokenClaims.Store(state, idTokenClaims) + return nil + } + // This is mainly intended to prevent the _last_ redirect + // The one involving the state param is a core part of the + // OIDC flow and shouldn't be redirected. + if redirectFn != nil { + return redirectFn(req, via) + } return nil } + cli.CheckRedirect = checkRedirect req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil) require.NoError(t, err) diff --git a/coderd/externalauth.go b/coderd/externalauth.go index 25f362e737..a07f6d486c 100644 --- a/coderd/externalauth.go +++ b/coderd/externalauth.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "github.com/sqlc-dev/pqtype" "golang.org/x/sync/errgroup" @@ -306,6 +307,7 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht // FE know not to enter the authentication loop again, and instead display an error. redirect = fmt.Sprintf("/external-auth/%s?redirected=true", externalAuthConfig.ID) } + redirect = uriFromURL(redirect) http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) } } @@ -401,3 +403,12 @@ func ExternalAuthConfig(cfg *externalauth.Config) codersdk.ExternalAuthLinkProvi AllowValidate: cfg.ValidateURL != "", } } + +func uriFromURL(u string) string { + uri, err := url.Parse(u) + if err != nil { + return "/" + } + + return uri.RequestURI() +} diff --git a/coderd/externalauth_test.go b/coderd/externalauth_test.go index a62e7eab74..87197528fc 100644 --- a/coderd/externalauth_test.go +++ b/coderd/externalauth_test.go @@ -207,12 +207,12 @@ func TestExternalAuthManagement(t *testing.T) { const gitlabID = "fake-gitlab" githubCalled := false - githubApp := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(email string) error { + githubApp := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(_ string) error { githubCalled = true return nil })) gitlabCalled := false - gitlab := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(email string) error { + gitlab := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(_ string) error { gitlabCalled = true return nil })) @@ -508,6 +508,35 @@ func TestExternalAuthCallback(t *testing.T) { resp = coderdtest.RequestExternalAuthCallback(t, "github", client) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) }) + + t.Run("CustomRedirect", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{ + IncludeProvisionerDaemon: true, + ExternalAuthConfigs: []*externalauth.Config{{ + InstrumentedOAuth2Config: &testutil.OAuth2Config{}, + ID: "github", + Regex: regexp.MustCompile(`github\.com`), + Type: codersdk.EnhancedExternalAuthProviderGitHub.String(), + }}, + }) + maliciousHost := "https://malicious.com" + expectedURI := "/some/path?param=1" + _ = coderdtest.CreateFirstUser(t, client) + resp := coderdtest.RequestExternalAuthCallback(t, "github", client, func(req *http.Request) { + req.AddCookie(&http.Cookie{ + Name: codersdk.OAuth2RedirectCookie, + Value: maliciousHost + expectedURI, + }) + }) + require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + location, err := resp.Location() + require.NoError(t, err) + require.Equal(t, expectedURI, location.RequestURI()) + require.Equal(t, client.URL.Host, location.Host) + require.NotContains(t, location.String(), maliciousHost) + }) + t.Run("ValidateURL", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) diff --git a/coderd/httpmw/oauth2.go b/coderd/httpmw/oauth2.go index 98baaae4c4..7afa622d97 100644 --- a/coderd/httpmw/oauth2.go +++ b/coderd/httpmw/oauth2.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/url" "reflect" "github.com/go-chi/chi/v5" @@ -85,6 +86,15 @@ func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOp code := r.URL.Query().Get("code") state := r.URL.Query().Get("state") + redirect := r.URL.Query().Get("redirect") + if redirect != "" { + // We want to ensure that we're only ever redirecting to the application. + // We could be more strict here and check to see if the host matches + // the host of the AccessURL but ultimately as long as our redirect + // url omits a host we're ensuring that we're routing to a path + // local to the application. + redirect = uriFromURL(redirect) + } if code == "" { // If the code isn't provided, we'll redirect! @@ -119,7 +129,7 @@ func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOp // an old redirect could apply! http.SetCookie(rw, &http.Cookie{ Name: codersdk.OAuth2RedirectCookie, - Value: r.URL.Query().Get("redirect"), + Value: redirect, Path: "/", HttpOnly: true, SameSite: http.SameSiteLaxMode, @@ -150,7 +160,6 @@ func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOp return } - var redirect string stateRedirect, err := r.Cookie(codersdk.OAuth2RedirectCookie) if err == nil { redirect = stateRedirect.Value @@ -302,3 +311,12 @@ func ExtractOAuth2ProviderAppSecret(db database.Store) func(http.Handler) http.H }) } } + +func uriFromURL(u string) string { + uri, err := url.Parse(u) + if err != nil { + return "/" + } + + return uri.RequestURI() +} diff --git a/coderd/httpmw/oauth2_test.go b/coderd/httpmw/oauth2_test.go index 571e4fd9c4..ca5dcf5f8a 100644 --- a/coderd/httpmw/oauth2_test.go +++ b/coderd/httpmw/oauth2_test.go @@ -67,6 +67,31 @@ func TestOAuth2(t *testing.T) { cookie := res.Result().Cookies()[1] require.Equal(t, "/dashboard", cookie.Value) }) + t.Run("OnlyPathBaseRedirect", func(t *testing.T) { + t.Parallel() + // Construct a URI to a potentially malicious + // site and assert that we omit the host + // when redirecting the request. + uri := &url.URL{ + Scheme: "https", + Host: "some.bad.domain.com", + Path: "/sadf/asdfasdf", + RawQuery: "foo=hello&bar=world", + } + expectedValue := uri.Path + "?" + uri.RawQuery + req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape(uri.String()), nil) + res := httptest.NewRecorder() + tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline) + httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req) + location := res.Header().Get("Location") + if !assert.NotEmpty(t, location) { + return + } + require.Len(t, res.Result().Cookies(), 2) + cookie := res.Result().Cookies()[1] + require.Equal(t, expectedValue, cookie.Value) + }) + t.Run("NoState", func(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/?code=something", nil) @@ -108,7 +133,7 @@ func TestOAuth2(t *testing.T) { }) res := httptest.NewRecorder() tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline) - httpmw.ExtractOAuth2(tp, nil, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + httpmw.ExtractOAuth2(tp, nil, nil)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { state := httpmw.OAuth2(r) require.Equal(t, "/dashboard", state.Redirect) })).ServeHTTP(res, req) diff --git a/coderd/userauth.go b/coderd/userauth.go index bb149d9d07..644ed207df 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -707,9 +707,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) { http.SetCookie(rw, cookie) } - if redirect == "" { - redirect = "/" - } + redirect = uriFromURL(redirect) http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) } @@ -1085,9 +1083,9 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) { } redirect := state.Redirect - if redirect == "" { - redirect = "/" - } + // Strip the host if it exists on the URL to prevent + // any nefarious redirects. + redirect = uriFromURL(redirect) http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) } @@ -1687,7 +1685,7 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data } } var claims OAuthConvertStateClaims - token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(_ *jwt.Token) (interface{}, error) { return api.OAuthSigningKey[:], nil }) if xerrors.Is(err, jwt.ErrSignatureInvalid) || !token.Valid { diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 8e1f07e24d..6302bee390 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -354,11 +354,25 @@ func TestUserOAuth2Github(t *testing.T) { }) numLogs := len(auditor.AuditLogs()) - resp := oauth2Callback(t, client) + // Validate that attempting to redirect away from the + // site does not work. + maliciousHost := "https://malicious.com" + expectedPath := "/my/path" + resp := oauth2Callback(t, client, func(req *http.Request) { + // Add the cookie to bypass the parsing in httpmw/oauth2.go + req.AddCookie(&http.Cookie{ + Name: codersdk.OAuth2RedirectCookie, + Value: maliciousHost + expectedPath, + }) + }) numLogs++ // add an audit log for login require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - + redirect, err := resp.Location() + require.NoError(t, err) + require.Equal(t, expectedPath, redirect.Path) + require.Equal(t, client.URL.Host, redirect.Host) + require.NotContains(t, redirect.String(), maliciousHost) client.SetSessionToken(authCookieValue(resp.Cookies())) user, err := client.User(context.Background(), "me") require.NoError(t, err) @@ -1436,6 +1450,59 @@ func TestUserOIDC(t *testing.T) { _, resp := fake.AttemptLogin(t, client, jwt.MapClaims{}) require.Equal(t, http.StatusBadRequest, resp.StatusCode) }) + + t.Run("StripRedirectHost", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + expectedRedirect := "/foo/bar?hello=world&bar=baz" + redirectURL := "https://malicious" + expectedRedirect + + callbackPath := fmt.Sprintf("/api/v2/users/oidc/callback?redirect=%s", url.QueryEscape(redirectURL)) + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefresh(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + oidctest.WithCallbackPath(callbackPath), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + }) + + client := coderdtest.New(t, &coderdtest.Options{ + OIDCConfig: cfg, + }) + + client.HTTPClient.Transport = http.DefaultTransport + + client.HTTPClient.CheckRedirect = func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + } + + claims := jwt.MapClaims{ + "email": "user@example.com", + "email_verified": true, + } + + // Perform the login + loginClient, resp := fake.LoginWithClient(t, client, claims) + require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + + // Get the location from the response + location, err := resp.Location() + require.NoError(t, err) + + // Check that the redirect URL has been stripped of its malicious host + require.Equal(t, expectedRedirect, location.RequestURI()) + require.Equal(t, client.URL.Host, location.Host) + require.NotContains(t, location.String(), "malicious") + + // Verify the user was created + user, err := loginClient.User(ctx, "me") + require.NoError(t, err) + require.Equal(t, "user@example.com", user.Email) + }) } func TestUserLogout(t *testing.T) { @@ -1587,7 +1654,7 @@ func TestOIDCSkipIssuer(t *testing.T) { require.Equal(t, found.LoginType, codersdk.LoginTypeOIDC) } -func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response { +func oauth2Callback(t *testing.T, client *codersdk.Client, opts ...func(*http.Request)) *http.Response { client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } @@ -1597,6 +1664,9 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response { require.NoError(t, err) req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil) require.NoError(t, err) + for _, opt := range opts { + opt(req) + } req.AddCookie(&http.Cookie{ Name: codersdk.OAuth2StateCookie, Value: state,