diff --git a/cli/server.go b/cli/server.go index 9a36ff6e9b..09674a1c91 100644 --- a/cli/server.go +++ b/cli/server.go @@ -2909,6 +2909,8 @@ func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]coder provider.MCPToolDenyRegex = v.Value case "PKCE_METHODS": provider.CodeChallengeMethodsSupported = strings.Split(v.Value, " ") + case "API_BASE_URL": + provider.APIBaseURL = v.Value } providers[providerNum] = provider } diff --git a/cli/server_test.go b/cli/server_test.go index b0b493570c..a0020b5f9a 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -108,6 +108,29 @@ func TestReadExternalAuthProvidersFromEnv(t *testing.T) { }) } +func TestReadExternalAuthProvidersFromEnv_APIBaseURL(t *testing.T) { + t.Parallel() + providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{ + "CODER_EXTERNAL_AUTH_0_TYPE=github", + "CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx", + "CODER_EXTERNAL_AUTH_0_API_BASE_URL=https://ghes.corp.com/api/v3", + }) + require.NoError(t, err) + require.Len(t, providers, 1) + assert.Equal(t, "https://ghes.corp.com/api/v3", providers[0].APIBaseURL) +} + +func TestReadExternalAuthProvidersFromEnv_APIBaseURLDefault(t *testing.T) { + t.Parallel() + providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{ + "CODER_EXTERNAL_AUTH_0_TYPE=github", + "CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx", + }) + require.NoError(t, err) + require.Len(t, providers, 1) + assert.Equal(t, "", providers[0].APIBaseURL) +} + // TestReadGitAuthProvidersFromEnv ensures that the deprecated `CODER_GITAUTH_` // environment variables are still supported. func TestReadGitAuthProvidersFromEnv(t *testing.T) { diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index f47c449057..de14a7dc30 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -15269,6 +15269,10 @@ const docTemplate = `{ "codersdk.ExternalAuthConfig": { "type": "object", "properties": { + "api_base_url": { + "description": "APIBaseURL is the base URL for provider REST API calls\n(e.g., \"https://api.github.com\" for GitHub). Derived from\ndefaults when not explicitly configured.", + "type": "string" + }, "app_install_url": { "type": "string" }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index bc1b58c27b..5d49746cae 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -13792,6 +13792,10 @@ "codersdk.ExternalAuthConfig": { "type": "object", "properties": { + "api_base_url": { + "description": "APIBaseURL is the base URL for provider REST API calls\n(e.g., \"https://api.github.com\" for GitHub). Derived from\ndefaults when not explicitly configured.", + "type": "string" + }, "app_install_url": { "type": "string" }, diff --git a/coderd/chats.go b/coderd/chats.go index 24fd6f3d97..67ad866ee2 100644 --- a/coderd/chats.go +++ b/coderd/chats.go @@ -13,7 +13,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "regexp" "strconv" "strings" "sync" @@ -32,6 +31,8 @@ import ( "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/externalauth" + "github.com/coder/coder/v2/coderd/externalauth/gitprovider" + "github.com/coder/coder/v2/coderd/gitsync" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi/httperror" "github.com/coder/coder/v2/coderd/httpmw" @@ -39,16 +40,15 @@ import ( "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/websocket" ) const ( - chatDiffStatusTTL = 120 * time.Second - chatDiffBackgroundRefreshTimeout = 20 * time.Second - githubAPIBaseURL = "https://api.github.com" - chatStreamBatchSize = 256 + chatDiffStatusTTL = gitsync.DiffStatusTTL + chatStreamBatchSize = 256 chatContextLimitModelConfigKey = "context_limit" chatContextCompressionThresholdModelConfigKey = "context_compression_threshold" @@ -58,19 +58,6 @@ const ( maxSystemPromptLenBytes = 131072 // 128 KiB ) -// chatDiffRefreshBackoffSchedule defines the delays between successive -// background diff refresh attempts. The trigger fires when the agent -// obtains a GitHub token, which is typically right before a git push -// or PR creation. The backoff gives progressively more time for the -// push and any PR workflow to complete before querying the GitHub API. -var chatDiffRefreshBackoffSchedule = []time.Duration{ - 1 * time.Second, - 3 * time.Second, - 5 * time.Second, - 10 * time.Second, - 20 * time.Second, -} - // chatGitRef holds the branch and remote origin reported by the // workspace agent during a git operation. type chatGitRef struct { @@ -78,32 +65,6 @@ type chatGitRef struct { RemoteOrigin string } -var ( - githubPullRequestPathPattern = regexp.MustCompile( - `^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)/pull/([0-9]+)(?:[/?#].*)?$`, - ) - githubRepositoryHTTPSPattern = regexp.MustCompile( - `^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`, - ) - githubRepositorySSHPathPattern = regexp.MustCompile( - `^(?:ssh://)?git@github\.com[:/]([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`, - ) -) - -type githubPullRequestRef struct { - Owner string - Repo string - Number int -} - -type githubPullRequestStatus struct { - PullRequestState string - ChangesRequested bool - Additions int32 - Deletions int32 - ChangedFiles int32 -} - type chatRepositoryRef struct { Provider string RemoteOrigin string @@ -1249,193 +1210,6 @@ func shouldRefreshChatDiffStatus(status database.ChatDiffStatus, now time.Time, return chatDiffStatusIsStale(status, now) } -func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) { - if workspace.ID == uuid.Nil || workspace.OwnerID == uuid.Nil { - return - } - - go func(workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) { - ctx := api.ctx - if ctx == nil { - ctx = context.Background() - } - //nolint:gocritic // Background goroutine for diff status refresh has no user context. - ctx = dbauthz.AsSystemRestricted(ctx) - - // Always store the git ref so the data is persisted even - // before a PR exists. The frontend can show branch info - // and the refresh loop can resolve a PR later. - api.storeChatGitRef(ctx, workspaceID, workspaceOwnerID, chatID, gitRef) - - for _, delay := range chatDiffRefreshBackoffSchedule { - t := api.Clock.NewTimer(delay, "chat_diff_refresh") - select { - case <-ctx.Done(): - t.Stop() - return - case <-t.C: - } - - // Refresh and publish status on every iteration. - // Stop the loop once a PR is discovered — there's - // nothing more to wait for after that. - if api.refreshWorkspaceChatDiffStatuses(ctx, workspaceID, workspaceOwnerID, chatID) { - return - } - } - }(workspace.ID, workspace.OwnerID, chatID, gitRef) -} - -// storeChatGitRef persists the git branch and remote origin reported -// by the workspace agent on the chat that initiated the git operation. -// When chatID is set, only that specific chat is updated; otherwise all -// chats associated with the workspace are updated (legacy fallback). -func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) { - var chatsToUpdate []database.Chat - - if chatID.Valid { - chat, err := api.Database.GetChatByID(ctx, chatID.UUID) - if err != nil { - api.Logger.Warn(ctx, "failed to get chat for git ref storage", - slog.F("chat_id", chatID.UUID), - slog.F("workspace_id", workspaceID), - slog.Error(err), - ) - return - } - chatsToUpdate = []database.Chat{chat} - } else { - chats, err := api.Database.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{ - OwnerID: workspaceOwnerID, - }) - if err != nil { - api.Logger.Warn(ctx, "failed to list chats for git ref storage", - slog.F("workspace_id", workspaceID), - slog.Error(err), - ) - return - } - chatsToUpdate = filterChatsByWorkspaceID(chats, workspaceID) - } - - for _, chat := range chatsToUpdate { - _, err := api.Database.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{ - ChatID: chat.ID, - GitBranch: gitRef.Branch, - GitRemoteOrigin: gitRef.RemoteOrigin, - StaleAt: time.Now().UTC().Add(-time.Second), - Url: sql.NullString{}, - }) - if err != nil { - api.Logger.Warn(ctx, "failed to store git ref on chat diff status", - slog.F("chat_id", chat.ID), - slog.F("workspace_id", workspaceID), - slog.Error(err), - ) - continue - } - api.publishChatDiffStatusEvent(ctx, chat.ID) - } -} - -// refreshWorkspaceChatDiffStatuses refreshes the diff status for chats -// associated with the given workspace. When chatID is set, only that -// specific chat is refreshed; otherwise all chats for the workspace -// are refreshed (legacy fallback). It returns true when every -// refreshed chat has a PR URL resolved, signaling that the caller -// can stop polling. -func (api *API) refreshWorkspaceChatDiffStatuses(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID) bool { - var filtered []database.Chat - - if chatID.Valid { - chat, err := api.Database.GetChatByID(ctx, chatID.UUID) - if err != nil { - api.Logger.Warn(ctx, "failed to get chat for diff refresh", - slog.F("chat_id", chatID.UUID), - slog.F("workspace_id", workspaceID), - slog.Error(err), - ) - return false - } - filtered = []database.Chat{chat} - } else { - chats, err := api.Database.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{ - OwnerID: workspaceOwnerID, - }) - if err != nil { - api.Logger.Warn(ctx, "failed to list workspace owner chats for diff refresh", - slog.F("workspace_id", workspaceID), - slog.F("workspace_owner_id", workspaceOwnerID), - slog.Error(err), - ) - return false - } - filtered = filterChatsByWorkspaceID(chats, workspaceID) - } - if len(filtered) == 0 { - return false - } - - allHavePR := true - for _, chat := range filtered { - refreshCtx, cancel := context.WithTimeout(ctx, chatDiffBackgroundRefreshTimeout) - status, err := api.resolveChatDiffStatusWithOptions(refreshCtx, chat, true) - cancel() - if err != nil { - api.Logger.Warn(ctx, "failed to refresh chat diff status after workspace external auth", - slog.F("workspace_id", workspaceID), - slog.F("chat_id", chat.ID), - slog.Error(err), - ) - allHavePR = false - } else if status == nil || !status.Url.Valid || strings.TrimSpace(status.Url.String) == "" { - allHavePR = false - } - - api.publishChatStatusEvent(ctx, chat.ID) - api.publishChatDiffStatusEvent(ctx, chat.ID) - } - - return allHavePR -} - -func filterChatsByWorkspaceID(chats []database.Chat, workspaceID uuid.UUID) []database.Chat { - filteredChats := make([]database.Chat, 0, len(chats)) - for _, chat := range chats { - if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID { - continue - } - filteredChats = append(filteredChats, chat) - } - return filteredChats -} - -func (api *API) publishChatStatusEvent(ctx context.Context, chatID uuid.UUID) { - if api.chatDaemon == nil { - return - } - - if err := api.chatDaemon.RefreshStatus(ctx, chatID); err != nil { - api.Logger.Debug(ctx, "failed to refresh published chat status", - slog.F("chat_id", chatID), - slog.Error(err), - ) - } -} - -func (api *API) publishChatDiffStatusEvent(ctx context.Context, chatID uuid.UUID) { - if api.chatDaemon == nil { - return - } - - if err := api.chatDaemon.PublishDiffStatusChange(ctx, chatID); err != nil { - api.Logger.Debug(ctx, "failed to publish chat diff status change", - slog.F("chat_id", chatID), - slog.Error(err), - ) - } -} - func (api *API) resolveChatDiffContents( ctx context.Context, chat database.Chat, @@ -1483,22 +1257,36 @@ func (api *API) resolveChatDiffContents( if reference.RepositoryRef == nil { return result, nil } - if !strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) { + + gp := api.resolveGitProvider(reference.RepositoryRef.RemoteOrigin) + if gp == nil { return result, nil } - token := api.resolveChatGitHubAccessToken(ctx, chat.OwnerID) + token, err := api.resolveChatGitAccessToken(ctx, chat.OwnerID, reference.RepositoryRef.RemoteOrigin) + if err != nil { + return result, xerrors.Errorf("resolve git access token: %w", err) + } else if token == nil { + return result, xerrors.New("nil git access token") + } if reference.PullRequestURL != "" { - diff, err := api.fetchGitHubPullRequestDiff(ctx, reference.PullRequestURL, token) + ref, ok := gp.ParsePullRequestURL(reference.PullRequestURL) + if !ok { + return result, xerrors.Errorf("invalid pull request URL %q", reference.PullRequestURL) + } + diff, err := gp.FetchPullRequestDiff(ctx, *token, ref) if err != nil { return result, err } result.Diff = diff return result, nil } - - diff, err := api.fetchGitHubCompareDiff(ctx, *reference.RepositoryRef, token) + diff, err := gp.FetchBranchDiff(ctx, *token, gitprovider.BranchRef{ + Owner: reference.RepositoryRef.Owner, + Repo: reference.RepositoryRef.Repo, + Branch: reference.RepositoryRef.Branch, + }) if err != nil { return result, err } @@ -1532,34 +1320,53 @@ func (api *API) resolveChatDiffReference( // If we have a repo ref with a branch, try to resolve the // current open PR. This picks up new PRs after the previous // one was closed. - if reference.RepositoryRef != nil && - strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) { - pullRequestURL, lookupErr := api.resolveGitHubPullRequestURLFromRepositoryRef(ctx, chat.OwnerID, *reference.RepositoryRef) - if lookupErr != nil { - api.Logger.Debug(ctx, "failed to resolve pull request from repository reference", - slog.F("chat_id", chat.ID), - slog.F("provider", reference.RepositoryRef.Provider), - slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin), - slog.F("branch", reference.RepositoryRef.Branch), - slog.Error(lookupErr), - ) - } else if pullRequestURL != "" { - reference.PullRequestURL = pullRequestURL + if reference.RepositoryRef != nil && reference.RepositoryRef.Owner != "" { + gp := api.resolveGitProvider(reference.RepositoryRef.RemoteOrigin) + if gp != nil { + token, err := api.resolveChatGitAccessToken(ctx, chat.OwnerID, reference.RepositoryRef.RemoteOrigin) + if token == nil || errors.Is(err, gitsync.ErrNoTokenAvailable) { + // No token available yet. + return reference, nil + } else if err != nil { + return chatDiffReference{}, xerrors.Errorf("resolve git access token: %w", err) + } + prRef, lookupErr := gp.ResolveBranchPullRequest(ctx, *token, gitprovider.BranchRef{ + Owner: reference.RepositoryRef.Owner, + Repo: reference.RepositoryRef.Repo, + Branch: reference.RepositoryRef.Branch, + }) + if lookupErr != nil { + api.Logger.Debug(ctx, "failed to resolve pull request from repository reference", + slog.F("chat_id", chat.ID), + slog.F("provider", reference.RepositoryRef.Provider), + slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin), + slog.F("branch", reference.RepositoryRef.Branch), + slog.Error(lookupErr), + ) + } else if prRef != nil { + reference.PullRequestURL = gp.BuildPullRequestURL(*prRef) + } + reference.PullRequestURL = gp.NormalizePullRequestURL(reference.PullRequestURL) } } - reference.PullRequestURL = normalizeGitHubPullRequestURL(reference.PullRequestURL) - // If we have a PR URL but no repo ref (e.g. the agent hasn't // reported branch/origin yet), derive a partial ref from the // PR URL so the caller can still show provider/owner/repo. if reference.RepositoryRef == nil && reference.PullRequestURL != "" { - if parsed, ok := parseGitHubPullRequestURL(reference.PullRequestURL); ok { - reference.RepositoryRef = &chatRepositoryRef{ - Provider: string(codersdk.EnhancedExternalAuthProviderGitHub), - RemoteOrigin: fmt.Sprintf("https://github.com/%s/%s", parsed.Owner, parsed.Repo), - Owner: parsed.Owner, - Repo: parsed.Repo, + for _, extAuth := range api.ExternalAuthConfigs { + gp := extAuth.Git(api.HTTPClient) + if gp == nil { + continue + } + if parsed, ok := gp.ParsePullRequestURL(reference.PullRequestURL); ok { + reference.RepositoryRef = &chatRepositoryRef{ + Provider: strings.ToLower(extAuth.Type), + Owner: parsed.Owner, + Repo: parsed.Repo, + RemoteOrigin: gp.BuildRepositoryURL(parsed.Owner, parsed.Repo), + } + break } } } @@ -1577,19 +1384,18 @@ func (api *API) buildChatRepositoryRefFromStatus(status database.ChatDiffStatus) return nil } + providerType, gp := api.resolveExternalAuth(origin) repoRef := &chatRepositoryRef{ - Provider: strings.TrimSpace(api.resolveExternalAuthProviderType(origin)), + Provider: providerType, RemoteOrigin: origin, Branch: branch, } - - if owner, repo, normalizedOrigin, ok := parseGitHubRepositoryOrigin(repoRef.RemoteOrigin); ok { - if repoRef.Provider == "" { - repoRef.Provider = string(codersdk.EnhancedExternalAuthProviderGitHub) + if gp != nil { + if owner, repo, normalizedOrigin, ok := gp.ParseRepositoryOrigin(repoRef.RemoteOrigin); ok { + repoRef.RemoteOrigin = normalizedOrigin + repoRef.Owner = owner + repoRef.Repo = repo } - repoRef.RemoteOrigin = normalizedOrigin - repoRef.Owner = owner - repoRef.Repo = repo } if repoRef.Provider == "" { @@ -1643,60 +1449,31 @@ func (api *API) getCachedChatDiffStatus( ) } -func (api *API) resolveExternalAuthProviderType(match string) string { - match = strings.TrimSpace(match) - if match == "" { - return "" +// resolveExternalAuth finds the external auth config matching the +// given remote origin URL and returns both the provider type string +// (e.g. "github") and the gitprovider.Provider. Returns ("", nil) +// if no matching config is found. +func (api *API) resolveExternalAuth(origin string) (providerType string, gp gitprovider.Provider) { + origin = strings.TrimSpace(origin) + if origin == "" { + return "", nil } - for _, extAuth := range api.ExternalAuthConfigs { - if extAuth.Regex == nil || !extAuth.Regex.MatchString(match) { + if extAuth.Regex == nil || !extAuth.Regex.MatchString(origin) { continue } - return strings.ToLower(strings.TrimSpace(extAuth.Type)) + return strings.ToLower(strings.TrimSpace(extAuth.Type)), + extAuth.Git(api.HTTPClient) } - - return "" + return "", nil } -func parseGitHubRepositoryOrigin(raw string) (owner string, repo string, normalizedOrigin string, ok bool) { - raw = strings.TrimSpace(raw) - if raw == "" { - return "", "", "", false - } - - matches := githubRepositoryHTTPSPattern.FindStringSubmatch(raw) - if len(matches) != 3 { - matches = githubRepositorySSHPathPattern.FindStringSubmatch(raw) - } - if len(matches) != 3 { - return "", "", "", false - } - - owner = strings.TrimSpace(matches[1]) - repo = strings.TrimSpace(matches[2]) - repo = strings.TrimSuffix(repo, ".git") - if owner == "" || repo == "" { - return "", "", "", false - } - - return owner, repo, fmt.Sprintf("https://github.com/%s/%s", owner, repo), true -} - -func buildGitHubBranchURL(owner string, repo string, branch string) string { - owner = strings.TrimSpace(owner) - repo = strings.TrimSpace(repo) - branch = strings.TrimSpace(branch) - if owner == "" || repo == "" || branch == "" { - return "" - } - - return fmt.Sprintf( - "https://github.com/%s/%s/tree/%s", - owner, - repo, - url.PathEscape(branch), - ) +// resolveGitProvider finds the external auth config matching the +// given remote origin URL and returns its git provider. Returns +// nil if no matching git provider is configured. +func (api *API) resolveGitProvider(origin string) gitprovider.Provider { + _, gp := api.resolveExternalAuth(origin) + return gp } func chatDiffStatusIsStale(status database.ChatDiffStatus, now time.Time) bool { @@ -1712,11 +1489,32 @@ func (api *API) refreshChatDiffStatus( chatID uuid.UUID, pullRequestURL string, ) (database.ChatDiffStatus, error) { - status, err := api.fetchGitHubPullRequestStatus( - ctx, - pullRequestURL, - api.resolveChatGitHubAccessToken(ctx, chatOwnerID), - ) + // Find a provider that can handle this PR URL. + var gp gitprovider.Provider + var ref gitprovider.PRRef + for _, extAuth := range api.ExternalAuthConfigs { + p := extAuth.Git(api.HTTPClient) + if p == nil { + continue + } + if parsed, ok := p.ParsePullRequestURL(pullRequestURL); ok { + gp = p + ref = parsed + break + } + } + if gp == nil { + return database.ChatDiffStatus{}, xerrors.Errorf("no git provider found for PR URL %q", pullRequestURL) + } + + origin := gp.BuildRepositoryURL(ref.Owner, ref.Repo) + token, err := api.resolveChatGitAccessToken(ctx, chatOwnerID, origin) + if err != nil { + return database.ChatDiffStatus{}, xerrors.Errorf("resolve git access token: %w", err) + } else if token == nil { + return database.ChatDiffStatus{}, xerrors.New("nil git access token") + } + status, err := gp.FetchPullRequestStatus(ctx, *token, ref) if err != nil { return database.ChatDiffStatus{}, err } @@ -1728,13 +1526,13 @@ func (api *API) refreshChatDiffStatus( ChatID: chatID, Url: sql.NullString{String: pullRequestURL, Valid: true}, PullRequestState: sql.NullString{ - String: status.PullRequestState, - Valid: status.PullRequestState != "", + String: string(status.State), + Valid: status.State != "", }, ChangesRequested: status.ChangesRequested, - Additions: status.Additions, - Deletions: status.Deletions, - ChangedFiles: status.ChangedFiles, + Additions: status.DiffStats.Additions, + Deletions: status.DiffStats.Deletions, + ChangedFiles: status.DiffStats.ChangedFiles, RefreshedAt: refreshedAt, StaleAt: refreshedAt.Add(chatDiffStatusTTL), }, @@ -1745,23 +1543,49 @@ func (api *API) refreshChatDiffStatus( return refreshedStatus, nil } -func (api *API) resolveChatGitHubAccessToken( +func (api *API) resolveChatGitAccessToken( ctx context.Context, userID uuid.UUID, -) string { - // Build a map of provider ID -> config so we can refresh tokens - // using the same code path as provisionerdserver. - ghConfigs := make(map[string]*externalauth.Config) - providerIDs := []string{"github"} - for _, config := range api.ExternalAuthConfigs { - if !strings.EqualFold( - config.Type, - string(codersdk.EnhancedExternalAuthProviderGitHub), - ) { - continue + origin string, +) (*string, error) { + origin = strings.TrimSpace(origin) + + // If we have an origin, find the specific matching config first. + // This ensures multi-provider setups (github.com + GHE) get the + // correct token. + if origin != "" { + for _, config := range api.ExternalAuthConfigs { + if config.Regex == nil || !config.Regex.MatchString(origin) { + continue + } + link, err := api.Database.GetExternalAuthLink(ctx, + database.GetExternalAuthLinkParams{ + ProviderID: config.ID, + UserID: userID, + }, + ) + if err != nil { + continue + } + refreshed, refreshErr := config.RefreshToken(ctx, api.Database, link) + if refreshErr == nil { + link = refreshed + } + token := strings.TrimSpace(link.OAuthAccessToken) + if token != "" { + return ptr.Ref(token), nil + } } + } + + // Fallback: iterate all external auth configs. + // Used when origin is empty (inline refresh from HTTP handler) + // or when the origin-specific lookup above failed. + configs := make(map[string]*externalauth.Config) + providerIDs := []string{} + for _, config := range api.ExternalAuthConfigs { providerIDs = append(providerIDs, config.ID) - ghConfigs[config.ID] = config + configs[config.ID] = config } seen := map[string]struct{}{} @@ -1785,7 +1609,7 @@ func (api *API) resolveChatGitHubAccessToken( // Refresh the token if there is a matching config, mirroring // the same code path used by provisionerdserver when handing // tokens to provisioners. - if cfg, ok := ghConfigs[providerID]; ok { + if cfg, ok := configs[providerID]; ok { refreshed, refreshErr := cfg.RefreshToken(ctx, api.Database, link) if refreshErr != nil { api.Logger.Debug(ctx, "failed to refresh external auth token for chat diff", @@ -1802,336 +1626,11 @@ func (api *API) resolveChatGitHubAccessToken( token := strings.TrimSpace(link.OAuthAccessToken) if token != "" { - return token + return ptr.Ref(token), nil } } - return "" -} - -func (api *API) resolveGitHubPullRequestURLFromRepositoryRef( - ctx context.Context, - userID uuid.UUID, - repositoryRef chatRepositoryRef, -) (string, error) { - if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" { - return "", nil - } - - query := url.Values{} - query.Set("state", "open") - query.Set("head", fmt.Sprintf("%s:%s", repositoryRef.Owner, repositoryRef.Branch)) - query.Set("sort", "updated") - query.Set("direction", "desc") - query.Set("per_page", "1") - - requestURL := fmt.Sprintf( - "%s/repos/%s/%s/pulls?%s", - githubAPIBaseURL, - repositoryRef.Owner, - repositoryRef.Repo, - query.Encode(), - ) - - var pulls []struct { - HTMLURL string `json:"html_url"` - } - - token := api.resolveChatGitHubAccessToken(ctx, userID) - if err := api.decodeGitHubJSON(ctx, requestURL, token, &pulls); err != nil { - return "", err - } - if len(pulls) == 0 { - return "", nil - } - - return normalizeGitHubPullRequestURL(pulls[0].HTMLURL), nil -} - -func (api *API) fetchGitHubPullRequestDiff( - ctx context.Context, - pullRequestURL string, - token string, -) (string, error) { - ref, ok := parseGitHubPullRequestURL(pullRequestURL) - if !ok { - return "", xerrors.Errorf("invalid GitHub pull request URL %q", pullRequestURL) - } - - requestURL := fmt.Sprintf( - "%s/repos/%s/%s/pulls/%d", - githubAPIBaseURL, - ref.Owner, - ref.Repo, - ref.Number, - ) - - return api.fetchGitHubDiff(ctx, requestURL, token) -} - -func (api *API) fetchGitHubCompareDiff( - ctx context.Context, - repositoryRef chatRepositoryRef, - token string, -) (string, error) { - if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" { - return "", nil - } - - var repository struct { - DefaultBranch string `json:"default_branch"` - } - - repositoryURL := fmt.Sprintf( - "%s/repos/%s/%s", - githubAPIBaseURL, - repositoryRef.Owner, - repositoryRef.Repo, - ) - if err := api.decodeGitHubJSON(ctx, repositoryURL, token, &repository); err != nil { - return "", err - } - defaultBranch := strings.TrimSpace(repository.DefaultBranch) - if defaultBranch == "" { - return "", xerrors.New("github repository default branch is empty") - } - - requestURL := fmt.Sprintf( - "%s/repos/%s/%s/compare/%s...%s", - githubAPIBaseURL, - repositoryRef.Owner, - repositoryRef.Repo, - url.PathEscape(defaultBranch), - url.PathEscape(repositoryRef.Branch), - ) - - return api.fetchGitHubDiff(ctx, requestURL, token) -} - -func (api *API) fetchGitHubDiff( - ctx context.Context, - requestURL string, - token string, -) (string, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) - if err != nil { - return "", xerrors.Errorf("create github diff request: %w", err) - } - req.Header.Set("Accept", "application/vnd.github.diff") - req.Header.Set("X-GitHub-Api-Version", "2022-11-28") - req.Header.Set("User-Agent", "coder-chat-diff") - if token != "" { - req.Header.Set("Authorization", "Bearer "+token) - } - - httpClient := api.HTTPClient - if httpClient == nil { - httpClient = http.DefaultClient - } - - resp, err := httpClient.Do(req) - if err != nil { - return "", xerrors.Errorf("execute github diff request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192)) - if readErr != nil { - return "", xerrors.Errorf("github diff request failed with status %d", resp.StatusCode) - } - return "", xerrors.Errorf( - "github diff request failed with status %d: %s", - resp.StatusCode, - strings.TrimSpace(string(body)), - ) - } - - diff, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) - if err != nil { - return "", xerrors.Errorf("read github diff response: %w", err) - } - return string(diff), nil -} - -func (api *API) fetchGitHubPullRequestStatus( - ctx context.Context, - pullRequestURL string, - token string, -) (githubPullRequestStatus, error) { - ref, ok := parseGitHubPullRequestURL(pullRequestURL) - if !ok { - return githubPullRequestStatus{}, xerrors.Errorf( - "invalid GitHub pull request URL %q", - pullRequestURL, - ) - } - - pullEndpoint := fmt.Sprintf( - "%s/repos/%s/%s/pulls/%d", - githubAPIBaseURL, - ref.Owner, - ref.Repo, - ref.Number, - ) - - var pull struct { - State string `json:"state"` - Additions int32 `json:"additions"` - Deletions int32 `json:"deletions"` - ChangedFiles int32 `json:"changed_files"` - } - if err := api.decodeGitHubJSON(ctx, pullEndpoint, token, &pull); err != nil { - return githubPullRequestStatus{}, err - } - - var reviews []struct { - ID int64 `json:"id"` - State string `json:"state"` - User struct { - Login string `json:"login"` - } `json:"user"` - } - if err := api.decodeGitHubJSON( - ctx, - pullEndpoint+"/reviews?per_page=100", - token, - &reviews, - ); err != nil { - return githubPullRequestStatus{}, err - } - - return githubPullRequestStatus{ - PullRequestState: strings.ToLower(strings.TrimSpace(pull.State)), - ChangesRequested: hasOutstandingGitHubChangesRequested(reviews), - Additions: pull.Additions, - Deletions: pull.Deletions, - ChangedFiles: pull.ChangedFiles, - }, nil -} - -func (api *API) decodeGitHubJSON( - ctx context.Context, - requestURL string, - token string, - dest any, -) error { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) - if err != nil { - return xerrors.Errorf("create github request: %w", err) - } - req.Header.Set("Accept", "application/vnd.github+json") - req.Header.Set("X-GitHub-Api-Version", "2022-11-28") - req.Header.Set("User-Agent", "coder-chat-diff-status") - if token != "" { - req.Header.Set("Authorization", "Bearer "+token) - } - - httpClient := api.HTTPClient - if httpClient == nil { - httpClient = http.DefaultClient - } - - resp, err := httpClient.Do(req) - if err != nil { - return xerrors.Errorf("execute github request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192)) - if readErr != nil { - return xerrors.Errorf( - "github request failed with status %d", - resp.StatusCode, - ) - } - return xerrors.Errorf( - "github request failed with status %d: %s", - resp.StatusCode, - strings.TrimSpace(string(body)), - ) - } - - if err := json.NewDecoder(resp.Body).Decode(dest); err != nil { - return xerrors.Errorf("decode github response: %w", err) - } - return nil -} - -func hasOutstandingGitHubChangesRequested( - reviews []struct { - ID int64 `json:"id"` - State string `json:"state"` - User struct { - Login string `json:"login"` - } `json:"user"` - }, -) bool { - type reviewerState struct { - reviewID int64 - state string - } - - statesByReviewer := make(map[string]reviewerState) - for _, review := range reviews { - login := strings.ToLower(strings.TrimSpace(review.User.Login)) - if login == "" { - continue - } - - state := strings.ToUpper(strings.TrimSpace(review.State)) - switch state { - case "CHANGES_REQUESTED", "APPROVED", "DISMISSED": - default: - continue - } - - current, exists := statesByReviewer[login] - if exists && current.reviewID > review.ID { - continue - } - statesByReviewer[login] = reviewerState{ - reviewID: review.ID, - state: state, - } - } - - for _, state := range statesByReviewer { - if state.state == "CHANGES_REQUESTED" { - return true - } - } - return false -} - -func normalizeGitHubPullRequestURL(raw string) string { - ref, ok := parseGitHubPullRequestURL(strings.TrimRight( - strings.TrimSpace(raw), - "),.;", - )) - if !ok { - return "" - } - return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number) -} - -func parseGitHubPullRequestURL(raw string) (githubPullRequestRef, bool) { - matches := githubPullRequestPathPattern.FindStringSubmatch(strings.TrimSpace(raw)) - if len(matches) != 4 { - return githubPullRequestRef{}, false - } - - number, err := strconv.Atoi(matches[3]) - if err != nil { - return githubPullRequestRef{}, false - } - - return githubPullRequestRef{ - Owner: matches[1], - Repo: matches[2], - Number: number, - }, true + return nil, gitsync.ErrNoTokenAvailable } type createChatWorkspaceSelection struct { @@ -2786,11 +2285,21 @@ func convertChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) co } } if result.URL == nil { - owner, repo, _, ok := parseGitHubRepositoryOrigin(status.GitRemoteOrigin) - if ok { - branchURL := buildGitHubBranchURL(owner, repo, status.GitBranch) - if branchURL != "" { - result.URL = &branchURL + // Try to build a branch URL from the stored origin. + // Since convertChatDiffStatus does not have access to + // the API instance, we construct a GitHub provider + // directly as a best-effort fallback. + // TODO: This uses the default github.com API base URL, + // so branch URLs for GitHub Enterprise instances will + // be incorrect. To fix this, convertChatDiffStatus + // would need access to the external auth configs. + gp := gitprovider.New("github", "", nil) + if gp != nil { + if owner, repo, _, ok := gp.ParseRepositoryOrigin(status.GitRemoteOrigin); ok { + branchURL := gp.BuildBranchURL(owner, repo, status.GitBranch) + if branchURL != "" { + result.URL = &branchURL + } } } } diff --git a/coderd/chats_test.go b/coderd/chats_test.go index fe5321b960..b016c6bf9d 100644 --- a/coderd/chats_test.go +++ b/coderd/chats_test.go @@ -2605,7 +2605,7 @@ func TestGetChatDiffStatus(t *testing.T) { require.NoError(t, err) require.Equal(t, cachedStatusChat.ID, cachedStatus.ChatID) require.NotNil(t, cachedStatus.URL) - require.Equal(t, "https://github.com/coder/coder/tree/feature%2Fdiff-status", *cachedStatus.URL) + require.Equal(t, "https://github.com/coder/coder/tree/feature/diff-status", *cachedStatus.URL) require.NotNil(t, cachedStatus.PullRequestState) require.Equal(t, "open", *cachedStatus.PullRequestState) require.True(t, cachedStatus.ChangesRequested) diff --git a/coderd/coderd.go b/coderd/coderd.go index 9d2a25c360..f1c842639d 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -61,6 +61,7 @@ import ( "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/files" "github.com/coder/coder/v2/coderd/gitsshkey" + "github.com/coder/coder/v2/coderd/gitsync" "github.com/coder/coder/v2/coderd/healthcheck" "github.com/coder/coder/v2/coderd/healthcheck/derphealth" "github.com/coder/coder/v2/coderd/httpapi" @@ -773,6 +774,21 @@ func New(options *Options) *API { Pubsub: options.Pubsub, WebpushDispatcher: options.WebPushDispatcher, }) + gitSyncLogger := options.Logger.Named("gitsync") + refresher := gitsync.NewRefresher( + api.resolveGitProvider, + api.resolveChatGitAccessToken, + gitSyncLogger.Named("refresher"), + quartz.NewReal(), + ) + api.gitSyncWorker = gitsync.NewWorker(options.Database, + refresher, + api.chatDaemon.PublishDiffStatusChange, + quartz.NewReal(), + gitSyncLogger, + ) + // nolint:gocritic // chat diff worker needs to be able to CRUD chats. + go api.gitSyncWorker.Start(dbauthz.AsChatd(api.ctx)) if options.DeploymentValues.Prometheus.Enable { options.PrometheusRegistry.MustRegister(stn) api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry) @@ -1999,6 +2015,9 @@ type API struct { dbRolluper *dbrollup.Rolluper // chatDaemon handles background processing of pending chats. chatDaemon *chatd.Server + // gitSyncWorker refreshes stale chat diff statuses in the + // background. + gitSyncWorker *gitsync.Worker } // Close waits for all WebSocket connections to drain before returning. @@ -2028,6 +2047,13 @@ func (api *API) Close() error { api.Logger.Warn(api.ctx, "websocket shutdown timed out after 10 seconds") } api.dbRolluper.Close() + // chatDiffWorker is unconditionally initialized in New(). + select { + case <-api.gitSyncWorker.Done(): + case <-time.After(10 * time.Second): + api.Logger.Warn(context.Background(), + "chat diff refresh worker did not exit in time") + } if err := api.chatDaemon.Close(); err != nil { api.Logger.Warn(api.ctx, "close chat processor", slog.Error(err)) } diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a6f0a47e3e..e48a455348 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1539,6 +1539,17 @@ func (q *querier) AcquireProvisionerJob(ctx context.Context, arg database.Acquir return q.db.AcquireProvisionerJob(ctx, arg) } +func (q *querier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) { + // This is a system-level batch operation used by the gitsync + // background worker. Per-object authorization is impractical + // for a SKIP LOCKED acquisition query; callers must use + // AsChatd context. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil { + return nil, err + } + return q.db.AcquireStaleChatDiffStatuses(ctx, limitVal) +} + func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error { fetch := func(ctx context.Context, arg database.ActivityBumpWorkspaceParams) (database.Workspace, error) { return q.db.GetWorkspaceByID(ctx, arg.WorkspaceID) @@ -1577,6 +1588,16 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas return q.db.ArchiveUnusedTemplateVersions(ctx, arg) } +func (q *querier) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error { + // This is a system-level operation used by the gitsync + // background worker to reschedule failed refreshes. Same + // authorization pattern as AcquireStaleChatDiffStatuses. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil { + return err + } + return q.db.BackoffChatDiffStatus(ctx, arg) +} + func (q *querier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error { // Could be any workspace agent and checking auth to each workspace agent is overkill for // the purpose of this function. diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index f5b16a76c0..6af3477f60 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -770,6 +770,18 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus) })) + s.Run("AcquireStaleChatDiffStatuses", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), int32(10)).Return([]database.AcquireStaleChatDiffStatusesRow{}, nil).AnyTimes() + check.Args(int32(10)).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]database.AcquireStaleChatDiffStatusesRow{}) + })) + s.Run("BackoffChatDiffStatus", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.BackoffChatDiffStatusParams{ + ChatID: uuid.New(), + StaleAt: dbtime.Now(), + } + dbm.EXPECT().BackoffChatDiffStatus(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns() + })) s.Run("UpsertChatSystemPrompt", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes() check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 8436e6db9b..46adaeec89 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -136,6 +136,14 @@ func (m queryMetricsStore) AcquireProvisionerJob(ctx context.Context, arg databa return r0, r1 } +func (m queryMetricsStore) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) { + start := time.Now() + r0, r1 := m.s.AcquireStaleChatDiffStatuses(ctx, limitVal) + m.queryLatencies.WithLabelValues("AcquireStaleChatDiffStatuses").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireStaleChatDiffStatuses").Inc() + return r0, r1 +} + func (m queryMetricsStore) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error { start := time.Now() r0 := m.s.ActivityBumpWorkspace(ctx, arg) @@ -168,6 +176,14 @@ func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, ar return r0, r1 } +func (m queryMetricsStore) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error { + start := time.Now() + r0 := m.s.BackoffChatDiffStatus(ctx, arg) + m.queryLatencies.WithLabelValues("BackoffChatDiffStatus").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BackoffChatDiffStatus").Inc() + return r0 +} + func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error { start := time.Now() r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index f914011018..cd3ad0d843 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -103,6 +103,21 @@ func (mr *MockStoreMockRecorder) AcquireProvisionerJob(ctx, arg any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireProvisionerJob", reflect.TypeOf((*MockStore)(nil).AcquireProvisionerJob), ctx, arg) } +// AcquireStaleChatDiffStatuses mocks base method. +func (m *MockStore) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcquireStaleChatDiffStatuses", ctx, limitVal) + ret0, _ := ret[0].([]database.AcquireStaleChatDiffStatusesRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcquireStaleChatDiffStatuses indicates an expected call of AcquireStaleChatDiffStatuses. +func (mr *MockStoreMockRecorder) AcquireStaleChatDiffStatuses(ctx, limitVal any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireStaleChatDiffStatuses", reflect.TypeOf((*MockStore)(nil).AcquireStaleChatDiffStatuses), ctx, limitVal) +} + // ActivityBumpWorkspace mocks base method. func (m *MockStore) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error { m.ctrl.T.Helper() @@ -161,6 +176,20 @@ func (mr *MockStoreMockRecorder) ArchiveUnusedTemplateVersions(ctx, arg any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArchiveUnusedTemplateVersions", reflect.TypeOf((*MockStore)(nil).ArchiveUnusedTemplateVersions), ctx, arg) } +// BackoffChatDiffStatus mocks base method. +func (m *MockStore) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BackoffChatDiffStatus", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// BackoffChatDiffStatus indicates an expected call of BackoffChatDiffStatus. +func (mr *MockStoreMockRecorder) BackoffChatDiffStatus(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackoffChatDiffStatus", reflect.TypeOf((*MockStore)(nil).BackoffChatDiffStatus), ctx, arg) +} + // BatchUpdateWorkspaceAgentMetadata mocks base method. func (m *MockStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 84bec69a61..98964b54a5 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -39,6 +39,7 @@ type sqlcQuerier interface { // multiple provisioners from acquiring the same jobs. See: // https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error) + AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]AcquireStaleChatDiffStatusesRow, error) // Bumps the workspace deadline by the template's configured "activity_bump" // duration (default 1h). If the workspace bump will cross an autostart // threshold, then the bump is autostart + TTL. This is the deadline behavior if @@ -60,6 +61,7 @@ type sqlcQuerier interface { // Only unused template versions will be archived, which are any versions not // referenced by the latest build of a workspace. ArchiveUnusedTemplateVersions(ctx context.Context, arg ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) + BackoffChatDiffStatus(ctx context.Context, arg BackoffChatDiffStatusParams) error BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 9044d303b0..b3262a792d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3026,6 +3026,102 @@ func (q *sqlQuerier) AcquireChat(ctx context.Context, arg AcquireChatParams) (Ch return i, err } +const acquireStaleChatDiffStatuses = `-- name: AcquireStaleChatDiffStatuses :many +WITH acquired AS ( + UPDATE + chat_diff_statuses + SET + -- Claim for 5 minutes. The worker sets the real stale_at + -- after refresh. If the worker crashes, rows become eligible + -- again after this interval. + stale_at = NOW() + INTERVAL '5 minutes', + updated_at = NOW() + WHERE + chat_id IN ( + SELECT + cds.chat_id + FROM + chat_diff_statuses cds + INNER JOIN + chats c ON c.id = cds.chat_id + WHERE + cds.stale_at <= NOW() + AND cds.git_remote_origin != '' + AND cds.git_branch != '' + AND c.archived = FALSE + ORDER BY + cds.stale_at ASC + FOR UPDATE OF cds + SKIP LOCKED + LIMIT + $1::int + ) + RETURNING chat_id, url, pull_request_state, changes_requested, additions, deletions, changed_files, refreshed_at, stale_at, created_at, updated_at, git_branch, git_remote_origin +) +SELECT + acquired.chat_id, acquired.url, acquired.pull_request_state, acquired.changes_requested, acquired.additions, acquired.deletions, acquired.changed_files, acquired.refreshed_at, acquired.stale_at, acquired.created_at, acquired.updated_at, acquired.git_branch, acquired.git_remote_origin, + c.owner_id +FROM + acquired +INNER JOIN + chats c ON c.id = acquired.chat_id +` + +type AcquireStaleChatDiffStatusesRow struct { + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + Url sql.NullString `db:"url" json:"url"` + PullRequestState sql.NullString `db:"pull_request_state" json:"pull_request_state"` + ChangesRequested bool `db:"changes_requested" json:"changes_requested"` + Additions int32 `db:"additions" json:"additions"` + Deletions int32 `db:"deletions" json:"deletions"` + ChangedFiles int32 `db:"changed_files" json:"changed_files"` + RefreshedAt sql.NullTime `db:"refreshed_at" json:"refreshed_at"` + StaleAt time.Time `db:"stale_at" json:"stale_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + GitBranch string `db:"git_branch" json:"git_branch"` + GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` +} + +func (q *sqlQuerier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]AcquireStaleChatDiffStatusesRow, error) { + rows, err := q.db.QueryContext(ctx, acquireStaleChatDiffStatuses, limitVal) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AcquireStaleChatDiffStatusesRow + for rows.Next() { + var i AcquireStaleChatDiffStatusesRow + if err := rows.Scan( + &i.ChatID, + &i.Url, + &i.PullRequestState, + &i.ChangesRequested, + &i.Additions, + &i.Deletions, + &i.ChangedFiles, + &i.RefreshedAt, + &i.StaleAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.GitBranch, + &i.GitRemoteOrigin, + &i.OwnerID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const archiveChatByID = `-- name: ArchiveChatByID :exec UPDATE chats SET archived = true, updated_at = NOW() WHERE id = $1 OR root_chat_id = $1 @@ -3036,6 +3132,26 @@ func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error { return err } +const backoffChatDiffStatus = `-- name: BackoffChatDiffStatus :exec +UPDATE + chat_diff_statuses +SET + stale_at = $1::timestamptz, + updated_at = NOW() +WHERE + chat_id = $2::uuid +` + +type BackoffChatDiffStatusParams struct { + StaleAt time.Time `db:"stale_at" json:"stale_at"` + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` +} + +func (q *sqlQuerier) BackoffChatDiffStatus(ctx context.Context, arg BackoffChatDiffStatusParams) error { + _, err := q.db.ExecContext(ctx, backoffChatDiffStatus, arg.StaleAt, arg.ChatID) + return err +} + const deleteAllChatQueuedMessages = `-- name: DeleteAllChatQueuedMessages :exec DELETE FROM chat_queued_messages WHERE chat_id = $1 ` diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index d4d73fc136..70e90ffa77 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -448,3 +448,52 @@ LIMIT -- name: GetChatByIDForUpdate :one SELECT * FROM chats WHERE id = @id::uuid FOR UPDATE; + +-- name: AcquireStaleChatDiffStatuses :many +WITH acquired AS ( + UPDATE + chat_diff_statuses + SET + -- Claim for 5 minutes. The worker sets the real stale_at + -- after refresh. If the worker crashes, rows become eligible + -- again after this interval. + stale_at = NOW() + INTERVAL '5 minutes', + updated_at = NOW() + WHERE + chat_id IN ( + SELECT + cds.chat_id + FROM + chat_diff_statuses cds + INNER JOIN + chats c ON c.id = cds.chat_id + WHERE + cds.stale_at <= NOW() + AND cds.git_remote_origin != '' + AND cds.git_branch != '' + AND c.archived = FALSE + ORDER BY + cds.stale_at ASC + FOR UPDATE OF cds + SKIP LOCKED + LIMIT + @limit_val::int + ) + RETURNING * +) +SELECT + acquired.*, + c.owner_id +FROM + acquired +INNER JOIN + chats c ON c.id = acquired.chat_id; + +-- name: BackoffChatDiffStatus :exec +UPDATE + chat_diff_statuses +SET + stale_at = @stale_at::timestamptz, + updated_at = NOW() +WHERE + chat_id = @chat_id::uuid; diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 67923d18c2..2572154131 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -23,6 +23,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/externalauth/gitprovider" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" @@ -82,6 +83,10 @@ type Config struct { // a Git clone. e.g. "Username for 'https://github.com':" // The regex would be `github\.com`.. Regex *regexp.Regexp + // APIBaseURL is the base URL for provider REST API calls + // (e.g., "https://api.github.com" for GitHub). Derived from + // defaults when not explicitly configured. + APIBaseURL string // AppInstallURL is for GitHub App's (and hopefully others eventually) // to provide a link to install the app. There's installation // of the application, and user authentication. It's possible @@ -106,12 +111,23 @@ type Config struct { CodeChallengeMethodsSupported []promoauth.Oauth2PKCEChallengeMethod } +// Git returns a Provider for this config if the provider type +// is a supported git hosting provider. Returns nil for non-git +// providers (e.g. Slack, JFrog). +func (c *Config) Git(client *http.Client) gitprovider.Provider { + norm := strings.ToLower(c.Type) + if !codersdk.EnhancedExternalAuthProvider(norm).Git() { + return nil + } + return gitprovider.New(norm, c.APIBaseURL, client) +} + // GenerateTokenExtra generates the extra token data to store in the database. func (c *Config) GenerateTokenExtra(token *oauth2.Token) (pqtype.NullRawMessage, error) { if len(c.ExtraTokenKeys) == 0 { return pqtype.NullRawMessage{}, nil } - extraMap := map[string]interface{}{} + extraMap := map[string]any{} for _, key := range c.ExtraTokenKeys { extraMap[key] = token.Extra(key) } @@ -730,6 +746,7 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut ClientID: entry.ClientID, ClientSecret: entry.ClientSecret, Regex: regex, + APIBaseURL: entry.APIBaseURL, Type: entry.Type, NoRefresh: entry.NoRefresh, ValidateURL: entry.ValidateURL, @@ -766,7 +783,7 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut // applyDefaultsToConfig applies defaults to the config entry. func applyDefaultsToConfig(config *codersdk.ExternalAuthConfig) { - configType := codersdk.EnhancedExternalAuthProvider(config.Type) + configType := codersdk.EnhancedExternalAuthProvider(strings.ToLower(config.Type)) if configType == "bitbucket" { // For backwards compatibility, we need to support the "bitbucket" string. configType = codersdk.EnhancedExternalAuthProviderBitBucketCloud @@ -783,7 +800,7 @@ func applyDefaultsToConfig(config *codersdk.ExternalAuthConfig) { } // Dynamic defaults - switch codersdk.EnhancedExternalAuthProvider(config.Type) { + switch configType { case codersdk.EnhancedExternalAuthProviderGitHub: copyDefaultSettings(config, gitHubDefaults(config)) return @@ -864,6 +881,19 @@ func copyDefaultSettings(config *codersdk.ExternalAuthConfig, defaults codersdk. if config.CodeChallengeMethodsSupported == nil { config.CodeChallengeMethodsSupported = []string{string(promoauth.PKCEChallengeMethodSha256)} } + + // Set default API base URL for providers that need one. + if config.APIBaseURL == "" { + normType := strings.ToLower(config.Type) + switch codersdk.EnhancedExternalAuthProvider(normType) { + case codersdk.EnhancedExternalAuthProviderGitHub: + config.APIBaseURL = "https://api.github.com" + case codersdk.EnhancedExternalAuthProviderGitLab: + config.APIBaseURL = "https://gitlab.com/api/v4" + case codersdk.EnhancedExternalAuthProviderGitea: + config.APIBaseURL = "https://gitea.com/api/v1" + } + } } // gitHubDefaults returns default config values for GitHub. diff --git a/coderd/externalauth/externalauth_internal_test.go b/coderd/externalauth/externalauth_internal_test.go index f88299412e..d845d92a86 100644 --- a/coderd/externalauth/externalauth_internal_test.go +++ b/coderd/externalauth/externalauth_internal_test.go @@ -25,6 +25,7 @@ func TestGitlabDefaults(t *testing.T) { DisplayName: "GitLab", DisplayIcon: "/icon/gitlab.svg", Regex: `^(https?://)?gitlab\.com(/.*)?$`, + APIBaseURL: "https://gitlab.com/api/v4", Scopes: []string{"write_repository"}, CodeChallengeMethodsSupported: []string{string(promoauth.PKCEChallengeMethodSha256)}, } diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index d04634670c..daf5927e21 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -845,6 +845,40 @@ func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *ext return fake, config, link } +func TestApplyDefaultsToConfig_CaseInsensitive(t *testing.T) { + t.Parallel() + + instrument := promoauth.NewFactory(prometheus.NewRegistry()) + accessURL, err := url.Parse("https://coder.example.com") + require.NoError(t, err) + + for _, tc := range []struct { + Name string + Type string + }{ + {Name: "GitHub", Type: "GitHub"}, + {Name: "GITLAB", Type: "GITLAB"}, + {Name: "Gitea", Type: "Gitea"}, + } { + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + configs, err := externalauth.ConvertConfig( + instrument, + []codersdk.ExternalAuthConfig{{ + Type: tc.Type, + ClientID: "test-id", + ClientSecret: "test-secret", + }}, + accessURL, + ) + require.NoError(t, err) + require.Len(t, configs, 1) + // Defaults should have been applied despite mixed-case Type. + assert.NotEmpty(t, configs[0].AuthCodeURL("state"), "auth URL should be populated from defaults") + }) + } +} + type roundTripper func(req *http.Request) (*http.Response, error) func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { diff --git a/coderd/externalauth/gitprovider/github.go b/coderd/externalauth/gitprovider/github.go new file mode 100644 index 0000000000..bced64ad39 --- /dev/null +++ b/coderd/externalauth/gitprovider/github.go @@ -0,0 +1,540 @@ +package gitprovider + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "golang.org/x/xerrors" + + "github.com/coder/quartz" +) + +const ( + defaultGitHubAPIBaseURL = "https://api.github.com" + // Adding padding to our retry times to guard against over-consumption of request quotas. + RateLimitPadding = 5 * time.Minute +) + +type githubProvider struct { + apiBaseURL string + webBaseURL string + httpClient *http.Client + clock quartz.Clock + + // Compiled per-instance to support GitHub Enterprise hosts. + pullRequestPathPattern *regexp.Regexp + repositoryHTTPSPattern *regexp.Regexp + repositorySSHPathPattern *regexp.Regexp +} + +func newGitHub(apiBaseURL string, httpClient *http.Client, clock quartz.Clock) *githubProvider { + if apiBaseURL == "" { + apiBaseURL = defaultGitHubAPIBaseURL + } + apiBaseURL = strings.TrimRight(apiBaseURL, "/") + if httpClient == nil { + httpClient = http.DefaultClient + } + + // Derive the web base URL from the API base URL. + // github.com: api.github.com → github.com + // GHE: ghes.corp.com/api/v3 → ghes.corp.com + webBaseURL := deriveWebBaseURL(apiBaseURL) + + // Parse the host for regex construction. + host := extractHost(webBaseURL) + + // Escape the host for use in regex patterns. + escapedHost := regexp.QuoteMeta(host) + + return &githubProvider{ + apiBaseURL: apiBaseURL, + webBaseURL: webBaseURL, + httpClient: httpClient, + clock: clock, + pullRequestPathPattern: regexp.MustCompile( + `^https://` + escapedHost + `/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)/pull/([0-9]+)(?:[/?#].*)?$`, + ), + repositoryHTTPSPattern: regexp.MustCompile( + `^https://` + escapedHost + `/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`, + ), + repositorySSHPathPattern: regexp.MustCompile( + `^(?:ssh://)?git@` + escapedHost + `[:/]([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`, + ), + } +} + +// deriveWebBaseURL converts a GitHub API base URL to the +// corresponding web base URL. +// +// github.com: https://api.github.com → https://github.com +// GHE: https://ghes.corp.com/api/v3 → https://ghes.corp.com +func deriveWebBaseURL(apiBaseURL string) string { + u, err := url.Parse(apiBaseURL) + if err != nil { + return "https://github.com" + } + + // Standard github.com: API host is api.github.com. + if strings.EqualFold(u.Host, "api.github.com") { + return "https://github.com" + } + + // GHE: strip /api/v3 path suffix. + u.Path = strings.TrimSuffix(u.Path, "/api/v3") + u.Path = strings.TrimSuffix(u.Path, "/") + return u.String() +} + +// extractHost returns the host portion of a URL. +func extractHost(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return "github.com" + } + return u.Host +} + +func (g *githubProvider) ParseRepositoryOrigin(raw string) (owner string, repo string, normalizedOrigin string, ok bool) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", "", "", false + } + + matches := g.repositoryHTTPSPattern.FindStringSubmatch(raw) + if len(matches) != 3 { + matches = g.repositorySSHPathPattern.FindStringSubmatch(raw) + } + if len(matches) != 3 { + return "", "", "", false + } + + owner = strings.TrimSpace(matches[1]) + repo = strings.TrimSpace(matches[2]) + repo = strings.TrimSuffix(repo, ".git") + if owner == "" || repo == "" { + return "", "", "", false + } + + return owner, repo, fmt.Sprintf("%s/%s/%s", g.webBaseURL, url.PathEscape(owner), url.PathEscape(repo)), true +} + +func (g *githubProvider) ParsePullRequestURL(raw string) (PRRef, bool) { + matches := g.pullRequestPathPattern.FindStringSubmatch(strings.TrimSpace(raw)) + if len(matches) != 4 { + return PRRef{}, false + } + + number, err := strconv.Atoi(matches[3]) + if err != nil { + return PRRef{}, false + } + + return PRRef{ + Owner: matches[1], + Repo: matches[2], + Number: number, + }, true +} + +func (g *githubProvider) NormalizePullRequestURL(raw string) string { + ref, ok := g.ParsePullRequestURL(strings.TrimRight( + strings.TrimSpace(raw), + "),.;", + )) + if !ok { + return "" + } + return fmt.Sprintf("%s/%s/%s/pull/%d", g.webBaseURL, url.PathEscape(ref.Owner), url.PathEscape(ref.Repo), ref.Number) +} + +// escapePathPreserveSlashes escapes each segment of a path +// individually, preserving `/` separators. This is needed for +// web URLs where GitHub expects literal slashes (e.g. +// /tree/feat/new-thing). +func escapePathPreserveSlashes(s string) string { + segments := strings.Split(s, "/") + for i, seg := range segments { + segments[i] = url.PathEscape(seg) + } + return strings.Join(segments, "/") +} + +func (g *githubProvider) BuildBranchURL(owner string, repo string, branch string) string { + owner = strings.TrimSpace(owner) + repo = strings.TrimSpace(repo) + branch = strings.TrimSpace(branch) + if owner == "" || repo == "" || branch == "" { + return "" + } + + return fmt.Sprintf( + "%s/%s/%s/tree/%s", + g.webBaseURL, + url.PathEscape(owner), + url.PathEscape(repo), + escapePathPreserveSlashes(branch), + ) +} + +func (g *githubProvider) BuildRepositoryURL(owner string, repo string) string { + owner = strings.TrimSpace(owner) + repo = strings.TrimSpace(repo) + if owner == "" || repo == "" { + return "" + } + return fmt.Sprintf("%s/%s/%s", g.webBaseURL, url.PathEscape(owner), url.PathEscape(repo)) +} + +func (g *githubProvider) BuildPullRequestURL(ref PRRef) string { + if ref.Owner == "" || ref.Repo == "" || ref.Number <= 0 { + return "" + } + return fmt.Sprintf("%s/%s/%s/pull/%d", g.webBaseURL, url.PathEscape(ref.Owner), url.PathEscape(ref.Repo), ref.Number) +} + +func (g *githubProvider) ResolveBranchPullRequest( + ctx context.Context, + token string, + ref BranchRef, +) (*PRRef, error) { + if ref.Owner == "" || ref.Repo == "" || ref.Branch == "" { + return nil, nil + } + + query := url.Values{} + query.Set("state", "open") + query.Set("head", fmt.Sprintf("%s:%s", ref.Owner, ref.Branch)) + query.Set("sort", "updated") + query.Set("direction", "desc") + query.Set("per_page", "1") + + requestURL := fmt.Sprintf( + "%s/repos/%s/%s/pulls?%s", + g.apiBaseURL, + url.PathEscape(ref.Owner), + url.PathEscape(ref.Repo), + query.Encode(), + ) + + var pulls []struct { + HTMLURL string `json:"html_url"` + Number int `json:"number"` + } + + if err := g.decodeJSON(ctx, requestURL, token, &pulls); err != nil { + return nil, err + } + if len(pulls) == 0 { + return nil, nil + } + + prRef, ok := g.ParsePullRequestURL(pulls[0].HTMLURL) + if !ok { + return nil, nil + } + return &prRef, nil +} + +func (g *githubProvider) FetchPullRequestStatus( + ctx context.Context, + token string, + ref PRRef, +) (*PRStatus, error) { + pullEndpoint := fmt.Sprintf( + "%s/repos/%s/%s/pulls/%d", + g.apiBaseURL, + url.PathEscape(ref.Owner), + url.PathEscape(ref.Repo), + ref.Number, + ) + + var pull struct { + State string `json:"state"` + Merged bool `json:"merged"` + Draft bool `json:"draft"` + Additions int32 `json:"additions"` + Deletions int32 `json:"deletions"` + ChangedFiles int32 `json:"changed_files"` + Head struct { + SHA string `json:"sha"` + } `json:"head"` + } + if err := g.decodeJSON(ctx, pullEndpoint, token, &pull); err != nil { + return nil, err + } + + var reviews []struct { + ID int64 `json:"id"` + State string `json:"state"` + User struct { + Login string `json:"login"` + } `json:"user"` + } + // GitHub returns at most 100 reviews per page. We do not + // paginate because PRs with >100 reviews are extremely rare, + // and the cost of multiple API calls per refresh is not + // justified. If needed, pagination can be added later. + if err := g.decodeJSON( + ctx, + pullEndpoint+"/reviews?per_page=100", + token, + &reviews, + ); err != nil { + return nil, err + } + + state := PRState(strings.ToLower(strings.TrimSpace(pull.State))) + if pull.Merged { + state = PRStateMerged + } + + return &PRStatus{ + State: state, + Draft: pull.Draft, + HeadSHA: pull.Head.SHA, + DiffStats: DiffStats{ + Additions: pull.Additions, + Deletions: pull.Deletions, + ChangedFiles: pull.ChangedFiles, + }, + ChangesRequested: hasOutstandingChangesRequested(reviews), + FetchedAt: g.clock.Now().UTC(), + }, nil +} + +func (g *githubProvider) FetchPullRequestDiff( + ctx context.Context, + token string, + ref PRRef, +) (string, error) { + requestURL := fmt.Sprintf( + "%s/repos/%s/%s/pulls/%d", + g.apiBaseURL, + url.PathEscape(ref.Owner), + url.PathEscape(ref.Repo), + ref.Number, + ) + return g.fetchDiff(ctx, requestURL, token) +} + +func (g *githubProvider) FetchBranchDiff( + ctx context.Context, + token string, + ref BranchRef, +) (string, error) { + if ref.Owner == "" || ref.Repo == "" || ref.Branch == "" { + return "", nil + } + + var repository struct { + DefaultBranch string `json:"default_branch"` + } + + repositoryURL := fmt.Sprintf( + "%s/repos/%s/%s", + g.apiBaseURL, + url.PathEscape(ref.Owner), + url.PathEscape(ref.Repo), + ) + if err := g.decodeJSON(ctx, repositoryURL, token, &repository); err != nil { + return "", err + } + defaultBranch := strings.TrimSpace(repository.DefaultBranch) + if defaultBranch == "" { + return "", xerrors.New("github repository default branch is empty") + } + + requestURL := fmt.Sprintf( + "%s/repos/%s/%s/compare/%s...%s", + g.apiBaseURL, + url.PathEscape(ref.Owner), + url.PathEscape(ref.Repo), + url.PathEscape(defaultBranch), + url.PathEscape(ref.Branch), + ) + + return g.fetchDiff(ctx, requestURL, token) +} + +func (g *githubProvider) decodeJSON( + ctx context.Context, + requestURL string, + token string, + dest any, +) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) + if err != nil { + return xerrors.Errorf("create github request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github+json") + req.Header.Set("X-GitHub-Api-Version", "2022-11-28") + req.Header.Set("User-Agent", "coder-chat-diff-status") + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + resp, err := g.httpClient.Do(req) + if err != nil { + return xerrors.Errorf("execute github request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests { + retryAfter := ParseRetryAfter(resp.Header, g.clock) + if retryAfter > 0 { + return &RateLimitError{RetryAfter: g.clock.Now().Add(retryAfter + RateLimitPadding)} + } + // No rate-limit headers — fall through to generic error. + } + body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192)) + if readErr != nil { + return xerrors.Errorf( + "github request failed with status %d", + resp.StatusCode, + ) + } + return xerrors.Errorf( + "github request failed with status %d: %s", + resp.StatusCode, + strings.TrimSpace(string(body)), + ) + } + + if err := json.NewDecoder(resp.Body).Decode(dest); err != nil { + return xerrors.Errorf("decode github response: %w", err) + } + return nil +} + +func (g *githubProvider) fetchDiff( + ctx context.Context, + requestURL string, + token string, +) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) + if err != nil { + return "", xerrors.Errorf("create github diff request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github.diff") + req.Header.Set("X-GitHub-Api-Version", "2022-11-28") + req.Header.Set("User-Agent", "coder-chat-diff") + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + resp, err := g.httpClient.Do(req) + if err != nil { + return "", xerrors.Errorf("execute github diff request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests { + retryAfter := ParseRetryAfter(resp.Header, g.clock) + if retryAfter > 0 { + return "", &RateLimitError{RetryAfter: g.clock.Now().Add(retryAfter + RateLimitPadding)} + } + } + body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192)) + if readErr != nil { + return "", xerrors.Errorf("github diff request failed with status %d", resp.StatusCode) + } + return "", xerrors.Errorf( + "github diff request failed with status %d: %s", + resp.StatusCode, + strings.TrimSpace(string(body)), + ) + } + + // Read one extra byte beyond MaxDiffSize so we can detect + // whether the diff exceeds the limit. LimitReader stops us + // allocating an arbitrarily large buffer by accident. + buf, err := io.ReadAll(io.LimitReader(resp.Body, MaxDiffSize+1)) + if err != nil { + return "", xerrors.Errorf("read github diff response: %w", err) + } + if len(buf) > MaxDiffSize { + return "", ErrDiffTooLarge + } + return string(buf), nil +} + +// ParseRetryAfter extracts a retry-after time from GitHub +// rate-limit headers. Returns zero value if no recognizable header is +// present. +func ParseRetryAfter(h http.Header, clk quartz.Clock) time.Duration { + if clk == nil { + clk = quartz.NewReal() + } + // Retry-After header: seconds until retry. + if ra := h.Get("Retry-After"); ra != "" { + if secs, err := strconv.Atoi(ra); err == nil { + return time.Duration(secs) * time.Second + } + } + // X-Ratelimit-Reset header: unix timestamp. We compute the + // duration from now according to the caller's clock. + if reset := h.Get("X-Ratelimit-Reset"); reset != "" { + if ts, err := strconv.ParseInt(reset, 10, 64); err == nil { + d := time.Unix(ts, 0).Sub(clk.Now()) + return d + } + } + return 0 +} + +func hasOutstandingChangesRequested( + reviews []struct { + ID int64 `json:"id"` + State string `json:"state"` + User struct { + Login string `json:"login"` + } `json:"user"` + }, +) bool { + type reviewerState struct { + reviewID int64 + state string + } + + statesByReviewer := make(map[string]reviewerState) + for _, review := range reviews { + login := strings.ToLower(strings.TrimSpace(review.User.Login)) + if login == "" { + continue + } + + state := strings.ToUpper(strings.TrimSpace(review.State)) + switch state { + case "CHANGES_REQUESTED", "APPROVED", "DISMISSED": + default: + continue + } + + current, exists := statesByReviewer[login] + if exists && current.reviewID > review.ID { + continue + } + statesByReviewer[login] = reviewerState{ + reviewID: review.ID, + state: state, + } + } + + for _, state := range statesByReviewer { + if state.state == "CHANGES_REQUESTED" { + return true + } + } + return false +} diff --git a/coderd/externalauth/gitprovider/github_test.go b/coderd/externalauth/gitprovider/github_test.go new file mode 100644 index 0000000000..346e7066b0 --- /dev/null +++ b/coderd/externalauth/gitprovider/github_test.go @@ -0,0 +1,994 @@ +package gitprovider_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/externalauth/gitprovider" + "github.com/coder/quartz" +) + +func TestGitHubParseRepositoryOrigin(t *testing.T) { + t.Parallel() + gp := gitprovider.New("github", "", nil) + require.NotNil(t, gp) + + tests := []struct { + name string + raw string + expectOK bool + expectOwner string + expectRepo string + expectNormalized string + }{ + { + name: "HTTPS URL", + raw: "https://github.com/coder/coder", + expectOK: true, + expectOwner: "coder", + expectRepo: "coder", + expectNormalized: "https://github.com/coder/coder", + }, + { + name: "HTTPS URL with .git", + raw: "https://github.com/coder/coder.git", + expectOK: true, + expectOwner: "coder", + expectRepo: "coder", + expectNormalized: "https://github.com/coder/coder", + }, + { + name: "HTTPS URL with trailing slash", + raw: "https://github.com/coder/coder/", + expectOK: true, + expectOwner: "coder", + expectRepo: "coder", + expectNormalized: "https://github.com/coder/coder", + }, + { + name: "SSH URL", + raw: "git@github.com:coder/coder.git", + expectOK: true, + expectOwner: "coder", + expectRepo: "coder", + expectNormalized: "https://github.com/coder/coder", + }, + { + name: "SSH URL without .git", + raw: "git@github.com:coder/coder", + expectOK: true, + expectOwner: "coder", + expectRepo: "coder", + expectNormalized: "https://github.com/coder/coder", + }, + { + name: "SSH URL with ssh:// prefix", + raw: "ssh://git@github.com/coder/coder.git", + expectOK: true, + expectOwner: "coder", + expectRepo: "coder", + expectNormalized: "https://github.com/coder/coder", + }, + { + name: "GitLab URL does not match", + raw: "https://gitlab.com/coder/coder", + expectOK: false, + }, + { + name: "Empty string", + raw: "", + expectOK: false, + }, + { + name: "Not a URL", + raw: "not-a-url", + expectOK: false, + }, + { + name: "Hyphenated owner and repo", + raw: "https://github.com/my-org/my-repo.git", + expectOK: true, + expectOwner: "my-org", + expectRepo: "my-repo", + expectNormalized: "https://github.com/my-org/my-repo", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + owner, repo, normalized, ok := gp.ParseRepositoryOrigin(tt.raw) + assert.Equal(t, tt.expectOK, ok) + if tt.expectOK { + assert.Equal(t, tt.expectOwner, owner) + assert.Equal(t, tt.expectRepo, repo) + assert.Equal(t, tt.expectNormalized, normalized) + } + }) + } +} + +func TestGitHubParsePullRequestURL(t *testing.T) { + t.Parallel() + gp := gitprovider.New("github", "", nil) + require.NotNil(t, gp) + + tests := []struct { + name string + raw string + expectOK bool + expectOwner string + expectRepo string + expectNumber int + }{ + { + name: "Standard PR URL", + raw: "https://github.com/coder/coder/pull/123", + expectOK: true, + expectOwner: "coder", + expectRepo: "coder", + expectNumber: 123, + }, + { + name: "PR URL with query string", + raw: "https://github.com/coder/coder/pull/456?diff=split", + expectOK: true, + expectOwner: "coder", + expectRepo: "coder", + expectNumber: 456, + }, + { + name: "PR URL with fragment", + raw: "https://github.com/coder/coder/pull/789#discussion", + expectOK: true, + expectOwner: "coder", + expectRepo: "coder", + expectNumber: 789, + }, + { + name: "Not a PR URL", + raw: "https://github.com/coder/coder", + expectOK: false, + }, + { + name: "Issue URL (not PR)", + raw: "https://github.com/coder/coder/issues/123", + expectOK: false, + }, + { + name: "GitLab MR URL", + raw: "https://gitlab.com/coder/coder/-/merge_requests/123", + expectOK: false, + }, + { + name: "Empty string", + raw: "", + expectOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ref, ok := gp.ParsePullRequestURL(tt.raw) + assert.Equal(t, tt.expectOK, ok) + if tt.expectOK { + assert.Equal(t, tt.expectOwner, ref.Owner) + assert.Equal(t, tt.expectRepo, ref.Repo) + assert.Equal(t, tt.expectNumber, ref.Number) + } + }) + } +} + +func TestGitHubNormalizePullRequestURL(t *testing.T) { + t.Parallel() + gp := gitprovider.New("github", "", nil) + require.NotNil(t, gp) + + tests := []struct { + name string + raw string + expected string + }{ + { + name: "Already normalized", + raw: "https://github.com/coder/coder/pull/123", + expected: "https://github.com/coder/coder/pull/123", + }, + { + name: "With trailing punctuation", + raw: "https://github.com/coder/coder/pull/123).", + expected: "https://github.com/coder/coder/pull/123", + }, + { + name: "With query string", + raw: "https://github.com/coder/coder/pull/123?diff=split", + expected: "https://github.com/coder/coder/pull/123", + }, + { + name: "With whitespace", + raw: " https://github.com/coder/coder/pull/123 ", + expected: "https://github.com/coder/coder/pull/123", + }, + { + name: "Not a PR URL", + raw: "https://example.com", + expected: "", + }, + { + name: "Empty string", + raw: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := gp.NormalizePullRequestURL(tt.raw) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGitHubBuildBranchURL(t *testing.T) { + t.Parallel() + gp := gitprovider.New("github", "", nil) + require.NotNil(t, gp) + + tests := []struct { + name string + owner string + repo string + branch string + expected string + }{ + { + name: "Simple branch", + owner: "coder", + repo: "coder", + branch: "main", + expected: "https://github.com/coder/coder/tree/main", + }, + { + name: "Branch with slash", + owner: "coder", + repo: "coder", + branch: "feat/new-thing", + expected: "https://github.com/coder/coder/tree/feat/new-thing", + }, + { + name: "Empty owner", + owner: "", + repo: "coder", + branch: "main", + expected: "", + }, + { + name: "Empty repo", + owner: "coder", + repo: "", + branch: "main", + expected: "", + }, + { + name: "Empty branch", + owner: "coder", + repo: "coder", + branch: "", + expected: "", + }, + { + name: "Branch with slashes", + owner: "my-org", + repo: "my-repo", + branch: "feat/new-thing", + expected: "https://github.com/my-org/my-repo/tree/feat/new-thing", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := gp.BuildBranchURL(tt.owner, tt.repo, tt.branch) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGitHubBuildPullRequestURL(t *testing.T) { + t.Parallel() + gp := gitprovider.New("github", "", nil) + require.NotNil(t, gp) + + tests := []struct { + name string + ref gitprovider.PRRef + expected string + }{ + { + name: "Valid PR ref", + ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: 123}, + expected: "https://github.com/coder/coder/pull/123", + }, + { + name: "Empty owner", + ref: gitprovider.PRRef{Owner: "", Repo: "coder", Number: 123}, + expected: "", + }, + { + name: "Empty repo", + ref: gitprovider.PRRef{Owner: "coder", Repo: "", Number: 123}, + expected: "", + }, + { + name: "Zero number", + ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: 0}, + expected: "", + }, + { + name: "Negative number", + ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: -1}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := gp.BuildPullRequestURL(tt.ref) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGitHubEnterpriseURLs(t *testing.T) { + t.Parallel() + gp := gitprovider.New("github", "https://ghes.corp.com/api/v3", nil) + require.NotNil(t, gp) + + t.Run("ParseRepositoryOrigin HTTPS", func(t *testing.T) { + t.Parallel() + owner, repo, normalized, ok := gp.ParseRepositoryOrigin("https://ghes.corp.com/org/repo.git") + assert.True(t, ok) + assert.Equal(t, "org", owner) + assert.Equal(t, "repo", repo) + assert.Equal(t, "https://ghes.corp.com/org/repo", normalized) + }) + + t.Run("ParseRepositoryOrigin SSH", func(t *testing.T) { + t.Parallel() + owner, repo, normalized, ok := gp.ParseRepositoryOrigin("git@ghes.corp.com:org/repo.git") + assert.True(t, ok) + assert.Equal(t, "org", owner) + assert.Equal(t, "repo", repo) + assert.Equal(t, "https://ghes.corp.com/org/repo", normalized) + }) + + t.Run("ParsePullRequestURL", func(t *testing.T) { + t.Parallel() + ref, ok := gp.ParsePullRequestURL("https://ghes.corp.com/org/repo/pull/42") + assert.True(t, ok) + assert.Equal(t, "org", ref.Owner) + assert.Equal(t, "repo", ref.Repo) + assert.Equal(t, 42, ref.Number) + }) + + t.Run("NormalizePullRequestURL", func(t *testing.T) { + t.Parallel() + result := gp.NormalizePullRequestURL("https://ghes.corp.com/org/repo/pull/42?x=y") + assert.Equal(t, "https://ghes.corp.com/org/repo/pull/42", result) + }) + + t.Run("BuildBranchURL", func(t *testing.T) { + t.Parallel() + result := gp.BuildBranchURL("org", "repo", "main") + assert.Equal(t, "https://ghes.corp.com/org/repo/tree/main", result) + }) + + t.Run("BuildPullRequestURL", func(t *testing.T) { + t.Parallel() + result := gp.BuildPullRequestURL(gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}) + assert.Equal(t, "https://ghes.corp.com/org/repo/pull/42", result) + }) + + t.Run("github.com URLs do not match GHE instance", func(t *testing.T) { + t.Parallel() + _, _, _, ok := gp.ParseRepositoryOrigin("https://github.com/coder/coder") + assert.False(t, ok, "github.com HTTPS URL should not match GHE instance") + + _, _, _, ok = gp.ParseRepositoryOrigin("git@github.com:coder/coder.git") + assert.False(t, ok, "github.com SSH URL should not match GHE instance") + + _, ok = gp.ParsePullRequestURL("https://github.com/coder/coder/pull/123") + assert.False(t, ok, "github.com PR URL should not match GHE instance") + }) +} + +func TestNewUnsupportedProvider(t *testing.T) { + t.Parallel() + gp := gitprovider.New("unsupported", "", nil) + assert.Nil(t, gp, "unsupported provider type should return nil") +} + +func TestGitHubRatelimit_403WithResetHeader(t *testing.T) { + t.Parallel() + + resetTime := time.Now().Add(60 * time.Second) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("X-Ratelimit-Reset", fmt.Sprintf("%d", resetTime.Unix())) + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "API rate limit exceeded"}`)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + _, err := gp.FetchPullRequestStatus( + context.Background(), + "test-token", + gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, + ) + require.Error(t, err) + + var rlErr *gitprovider.RateLimitError + require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err) + assert.WithinDuration(t, resetTime.Add(gitprovider.RateLimitPadding), rlErr.RetryAfter, 2*time.Second) +} + +func TestGitHubRatelimit_429WithRetryAfter(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Retry-After", "120") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"message": "secondary rate limit"}`)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + _, err := gp.FetchPullRequestStatus( + context.Background(), + "test-token", + gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, + ) + require.Error(t, err) + + var rlErr *gitprovider.RateLimitError + require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err) + + // Retry-After: 120 means ~120s from now. + expected := time.Now().Add(120 * time.Second) + assert.WithinDuration(t, expected.Add(gitprovider.RateLimitPadding), rlErr.RetryAfter, 5*time.Second) +} + +func TestGitHubRatelimit_403NormalError(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "Bad credentials"}`)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + _, err := gp.FetchPullRequestStatus( + context.Background(), + "bad-token", + gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, + ) + require.Error(t, err) + + var rlErr *gitprovider.RateLimitError + assert.False(t, errors.As(err, &rlErr), "error should NOT be *RateLimitError") + assert.Contains(t, err.Error(), "403") +} + +func TestGitHubFetchPullRequestDiff(t *testing.T) { + t.Parallel() + + const smallDiff = "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1 +1 @@\n-old\n+new\n" + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte(smallDiff)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + diff, err := gp.FetchPullRequestDiff( + context.Background(), + "test-token", + gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, + ) + require.NoError(t, err) + assert.Equal(t, smallDiff, diff) + }) + + t.Run("ExactlyMaxSize", func(t *testing.T) { + t.Parallel() + + exactDiff := string(make([]byte, gitprovider.MaxDiffSize)) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte(exactDiff)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + diff, err := gp.FetchPullRequestDiff( + context.Background(), + "test-token", + gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, + ) + require.NoError(t, err) + assert.Len(t, diff, gitprovider.MaxDiffSize) + }) + + t.Run("TooLarge", func(t *testing.T) { + t.Parallel() + + oversizeDiff := string(make([]byte, gitprovider.MaxDiffSize+1024)) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte(oversizeDiff)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + _, err := gp.FetchPullRequestDiff( + context.Background(), + "test-token", + gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, + ) + assert.ErrorIs(t, err, gitprovider.ErrDiffTooLarge) + }) +} + +func TestFetchPullRequestDiff_Ratelimit(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Retry-After", "60") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"message": "rate limit"}`)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + _, err := gp.FetchPullRequestDiff( + context.Background(), + "test-token", + gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, + ) + require.Error(t, err) + + var rlErr *gitprovider.RateLimitError + require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err) + expected := time.Now().Add(60 * time.Second) + assert.WithinDuration(t, expected.Add(gitprovider.RateLimitPadding), rlErr.RetryAfter, 5*time.Second) +} + +func TestFetchBranchDiff_Ratelimit(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/compare/") { + // Second request: compare endpoint returns 429. + w.Header().Set("Retry-After", "60") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"message": "rate limit"}`)) + return + } + // First request: repo metadata. + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"default_branch":"main"}`)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + _, err := gp.FetchBranchDiff( + context.Background(), + "test-token", + gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"}, + ) + require.Error(t, err) + + var rlErr *gitprovider.RateLimitError + require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err) + expected := time.Now().Add(60 * time.Second) + assert.WithinDuration(t, expected.Add(gitprovider.RateLimitPadding), rlErr.RetryAfter, 5*time.Second) +} + +func TestFetchPullRequestStatus(t *testing.T) { + t.Parallel() + + type review struct { + ID int64 `json:"id"` + State string `json:"state"` + User struct { + Login string `json:"login"` + } `json:"user"` + } + + makeReview := func(id int64, state, login string) review { + r := review{ID: id, State: state} + r.User.Login = login + return r + } + + tests := []struct { + name string + pullJSON string + reviews []review + expectedState gitprovider.PRState + expectedDraft bool + changesRequested bool + }{ + { + name: "OpenPR/NoReviews", + pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`, + reviews: []review{}, + expectedState: gitprovider.PRStateOpen, + expectedDraft: false, + changesRequested: false, + }, + { + name: "OpenPR/SingleChangesRequested", + pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`, + reviews: []review{makeReview(1, "CHANGES_REQUESTED", "alice")}, + expectedState: gitprovider.PRStateOpen, + changesRequested: true, + }, + { + name: "OpenPR/ChangesRequestedThenApproved", + pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`, + reviews: []review{ + makeReview(1, "CHANGES_REQUESTED", "alice"), + makeReview(2, "APPROVED", "alice"), + }, + expectedState: gitprovider.PRStateOpen, + changesRequested: false, + }, + { + name: "OpenPR/ChangesRequestedThenDismissed", + pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`, + reviews: []review{ + makeReview(1, "CHANGES_REQUESTED", "alice"), + makeReview(2, "DISMISSED", "alice"), + }, + expectedState: gitprovider.PRStateOpen, + changesRequested: false, + }, + { + name: "OpenPR/MultipleReviewersMixed", + pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`, + reviews: []review{ + makeReview(1, "APPROVED", "alice"), + makeReview(2, "CHANGES_REQUESTED", "bob"), + }, + expectedState: gitprovider.PRStateOpen, + changesRequested: true, + }, + { + name: "OpenPR/CommentedDoesNotAffect", + pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`, + reviews: []review{ + makeReview(1, "COMMENTED", "alice"), + }, + expectedState: gitprovider.PRStateOpen, + changesRequested: false, + }, + { + name: "MergedPR", + pullJSON: `{"state":"closed","merged":true,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`, + reviews: []review{}, + expectedState: gitprovider.PRStateMerged, + changesRequested: false, + }, + { + name: "DraftPR", + pullJSON: `{"state":"open","merged":false,"draft":true,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`, + reviews: []review{}, + expectedState: gitprovider.PRStateOpen, + expectedDraft: true, + changesRequested: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + reviewsJSON, err := json.Marshal(tc.reviews) + require.NoError(t, err) + + mux := http.NewServeMux() + mux.HandleFunc("/api/v3/repos/owner/repo/pulls/1/reviews", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(reviewsJSON) + }) + mux.HandleFunc("/api/v3/repos/owner/repo/pulls/1", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(tc.pullJSON)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + before := time.Now().UTC() + status, err := gp.FetchPullRequestStatus( + context.Background(), + "test-token", + gitprovider.PRRef{Owner: "owner", Repo: "repo", Number: 1}, + ) + require.NoError(t, err) + + assert.Equal(t, tc.expectedState, status.State) + assert.Equal(t, tc.expectedDraft, status.Draft) + assert.Equal(t, tc.changesRequested, status.ChangesRequested) + assert.Equal(t, "abc123", status.HeadSHA) + assert.Equal(t, int32(10), status.DiffStats.Additions) + assert.Equal(t, int32(5), status.DiffStats.Deletions) + assert.Equal(t, int32(3), status.DiffStats.ChangedFiles) + assert.False(t, status.FetchedAt.IsZero()) + assert.True(t, !status.FetchedAt.Before(before), "FetchedAt should be >= test start time") + }) + } +} + +func TestResolveBranchPullRequest(t *testing.T) { + t.Parallel() + + t.Run("Found", func(t *testing.T) { + t.Parallel() + + var srvURL string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify query parameters. + assert.Equal(t, "open", r.URL.Query().Get("state")) + assert.Equal(t, "owner:feat", r.URL.Query().Get("head")) + w.Header().Set("Content-Type", "application/json") + // Use the test server's URL so ParsePullRequestURL + // matches the provider's derived web host. + htmlURL := fmt.Sprintf("https://%s/owner/repo/pull/42", + strings.TrimPrefix(strings.TrimPrefix(srvURL, "http://"), "https://")) + _, _ = w.Write([]byte(fmt.Sprintf(`[{"html_url":%q,"number":42}]`, htmlURL))) + })) + defer srv.Close() + srvURL = srv.URL + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + prRef, err := gp.ResolveBranchPullRequest( + context.Background(), + "test-token", + gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"}, + ) + require.NoError(t, err) + require.NotNil(t, prRef) + assert.Equal(t, "owner", prRef.Owner) + assert.Equal(t, "repo", prRef.Repo) + assert.Equal(t, 42, prRef.Number) + }) + + t.Run("NoneOpen", func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`[]`)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + prRef, err := gp.ResolveBranchPullRequest( + context.Background(), + "test-token", + gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"}, + ) + require.NoError(t, err) + assert.Nil(t, prRef) + }) + + t.Run("InvalidHTMLURL", func(t *testing.T) { + t.Parallel() + + // If html_url can't be parsed as a PR URL, ResolveBranchPullRequest + // returns nil, nil. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`[{"html_url":"not-a-valid-url","number":42}]`)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + prRef, err := gp.ResolveBranchPullRequest( + context.Background(), + "test-token", + gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"}, + ) + require.NoError(t, err) + assert.Nil(t, prRef) + }) +} + +func TestFetchBranchDiff(t *testing.T) { + t.Parallel() + + const smallDiff = "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1 +1 @@\n-old\n+new\n" + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/compare/") { + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte(smallDiff)) + return + } + // Repo metadata. + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"default_branch":"main"}`)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + diff, err := gp.FetchBranchDiff( + context.Background(), + "test-token", + gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"}, + ) + require.NoError(t, err) + assert.Equal(t, smallDiff, diff) + }) + + t.Run("EmptyDefaultBranch", func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"default_branch":""}`)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + _, err := gp.FetchBranchDiff( + context.Background(), + "test-token", + gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"}, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "default branch is empty") + }) + + t.Run("DiffTooLarge", func(t *testing.T) { + t.Parallel() + + oversizeDiff := string(make([]byte, gitprovider.MaxDiffSize+1024)) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/compare/") { + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte(oversizeDiff)) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"default_branch":"main"}`)) + })) + defer srv.Close() + + gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client()) + require.NotNil(t, gp) + + _, err := gp.FetchBranchDiff( + context.Background(), + "test-token", + gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"}, + ) + assert.ErrorIs(t, err, gitprovider.ErrDiffTooLarge) + }) +} + +func TestEscapePathPreserveSlashes(t *testing.T) { + t.Parallel() + // The function is unexported, so test it indirectly via BuildBranchURL. + // A branch with a space in a segment should be escaped, but slashes preserved. + gp := gitprovider.New("github", "", nil) + require.NotNil(t, gp) + got := gp.BuildBranchURL("owner", "repo", "feat/my thing") + assert.Equal(t, "https://github.com/owner/repo/tree/feat/my%20thing", got) +} + +func TestParseRetryAfter(t *testing.T) { + t.Parallel() + + clk := quartz.NewMock(t) + clk.Set(time.Now()) + + t.Run("RetryAfterSeconds", func(t *testing.T) { + t.Parallel() + h := http.Header{} + h.Set("Retry-After", "120") + d := gitprovider.ParseRetryAfter(h, clk) + assert.Equal(t, 120*time.Second, d) + }) + + t.Run("XRatelimitReset", func(t *testing.T) { + t.Parallel() + future := clk.Now().Add(90 * time.Second) + t.Logf("now: %d future: %d", clk.Now().Unix(), future.Unix()) + h := http.Header{} + h.Set("X-Ratelimit-Reset", strconv.FormatInt(future.Unix(), 10)) + d := gitprovider.ParseRetryAfter(h, clk) + assert.WithinDuration(t, future, clk.Now().Add(d), time.Second) + }) + + t.Run("NoHeaders", func(t *testing.T) { + t.Parallel() + h := http.Header{} + d := gitprovider.ParseRetryAfter(h, clk) + assert.Equal(t, time.Duration(0), d) + }) + + t.Run("InvalidValue", func(t *testing.T) { + t.Parallel() + h := http.Header{} + h.Set("Retry-After", "not-a-number") + d := gitprovider.ParseRetryAfter(h, clk) + assert.Equal(t, time.Duration(0), d) + }) + + t.Run("RetryAfterTakesPrecedence", func(t *testing.T) { + t.Parallel() + h := http.Header{} + h.Set("Retry-After", "60") + h.Set("X-Ratelimit-Reset", strconv.FormatInt( + clk.Now().Unix()+120, 10, + )) + d := gitprovider.ParseRetryAfter(h, clk) + assert.Equal(t, 60*time.Second, d) + }) +} diff --git a/coderd/externalauth/gitprovider/gitprovider.go b/coderd/externalauth/gitprovider/gitprovider.go new file mode 100644 index 0000000000..7806b0bf48 --- /dev/null +++ b/coderd/externalauth/gitprovider/gitprovider.go @@ -0,0 +1,179 @@ +package gitprovider + +import ( + "context" + "fmt" + "net/http" + "time" + + "golang.org/x/xerrors" + + "github.com/coder/quartz" +) + +// providerOptions holds optional configuration for provider +// construction. +type providerOptions struct { + clock quartz.Clock +} + +// Option configures optional behavior for a Provider. +type Option func(*providerOptions) + +// WithClock sets the clock used by the provider. Defaults to +// quartz.NewReal() if not provided. +func WithClock(c quartz.Clock) Option { + return func(o *providerOptions) { + o.clock = c + } +} + +// PRState is the normalized state of a pull/merge request across +// all providers. +type PRState string + +const ( + PRStateOpen PRState = "open" + PRStateClosed PRState = "closed" + PRStateMerged PRState = "merged" +) + +// PRRef identifies a pull request on any provider. +type PRRef struct { + // Owner is the repository owner / project / workspace. + Owner string + // Repo is the repository name or slug. + Repo string + // Number is the PR number / IID / index. + Number int +} + +// BranchRef identifies a branch in a repository, used for +// branch-to-PR resolution. +type BranchRef struct { + Owner string + Repo string + Branch string +} + +// DiffStats summarizes the size of a PR's changes. +type DiffStats struct { + Additions int32 + Deletions int32 + ChangedFiles int32 +} + +// PRStatus is the complete status of a pull/merge request. +// This is the universal return type that all providers populate. +type PRStatus struct { + // State is the PR's lifecycle state. + State PRState + // Draft indicates the PR is marked as draft/WIP. + Draft bool + // HeadSHA is the SHA of the head commit. + HeadSHA string + // DiffStats summarizes additions/deletions/files changed. + DiffStats DiffStats + // ChangesRequested is a convenience boolean: true if any + // reviewer's current state is "changes_requested". + ChangesRequested bool + // FetchedAt is when this status was fetched. + FetchedAt time.Time +} + +// MaxDiffSize is the maximum number of bytes read from a diff +// response. Diffs exceeding this limit are rejected with +// ErrDiffTooLarge. +const MaxDiffSize = 4 << 20 // 4 MiB + +// ErrDiffTooLarge is returned when a diff exceeds MaxDiffSize. +var ErrDiffTooLarge = xerrors.Errorf("diff exceeds maximum size of %d bytes", MaxDiffSize) + +// Provider defines the interface that all Git hosting providers +// implement. Each method is designed to minimize API round-trips +// for the specific provider. +type Provider interface { + // FetchPullRequestStatus retrieves the complete status of a + // pull request in the minimum number of API calls for this + // provider. + FetchPullRequestStatus(ctx context.Context, token string, ref PRRef) (*PRStatus, error) + + // ResolveBranchPullRequest finds the open PR (if any) for + // the given branch. Returns nil, nil if no open PR exists. + ResolveBranchPullRequest(ctx context.Context, token string, ref BranchRef) (*PRRef, error) + + // FetchPullRequestDiff returns the raw unified diff for a + // pull request. This uses the PR's actual base branch (which + // may differ from the repo default branch, e.g. a PR + // targeting "staging" instead of "main"), so it matches what + // the provider shows on the PR's "Files changed" tab. + // Returns ErrDiffTooLarge if the diff exceeds MaxDiffSize. + FetchPullRequestDiff(ctx context.Context, token string, ref PRRef) (string, error) + + // FetchBranchDiff returns the diff of a branch compared + // against the repository's default branch. This is the + // fallback when no pull request exists yet (e.g. the agent + // pushed a branch but hasn't opened a PR). Returns + // ErrDiffTooLarge if the diff exceeds MaxDiffSize. + FetchBranchDiff(ctx context.Context, token string, ref BranchRef) (string, error) + + // ParseRepositoryOrigin parses a remote origin URL (HTTPS + // or SSH) into owner and repo components, returning the + // normalized HTTPS URL. Returns false if the URL does not + // match this provider. + ParseRepositoryOrigin(raw string) (owner, repo, normalizedOrigin string, ok bool) + + // ParsePullRequestURL parses a pull request URL into a + // PRRef. Returns false if the URL does not match this + // provider. + ParsePullRequestURL(raw string) (PRRef, bool) + + // NormalizePullRequestURL normalizes a pull request URL, + // stripping trailing punctuation, query strings, and + // fragments. Returns empty string if the URL does not + // match this provider. + NormalizePullRequestURL(raw string) string + + // BuildBranchURL constructs a URL to view a branch on + // the provider's web UI. + BuildBranchURL(owner, repo, branch string) string + + // BuildRepositoryURL constructs a URL to view a repository + // on the provider's web UI. + BuildRepositoryURL(owner, repo string) string + + // BuildPullRequestURL constructs a URL to view a pull + // request on the provider's web UI. + BuildPullRequestURL(ref PRRef) string +} + +// New creates a Provider for the given provider type and API base +// URL. Returns nil if the provider type is not a supported git +// provider. +func New(providerType string, apiBaseURL string, httpClient *http.Client, opts ...Option) Provider { + o := providerOptions{} + for _, opt := range opts { + opt(&o) + } + if o.clock == nil { + o.clock = quartz.NewReal() + } + + switch providerType { + case "github": + return newGitHub(apiBaseURL, httpClient, o.clock) + default: + // Other providers (gitlab, bitbucket-cloud, etc.) will be + // added here as they are implemented. + return nil + } +} + +// RateLimitError indicates the git provider's API rate limit was hit. +type RateLimitError struct { + RetryAfter time.Time +} + +func (e *RateLimitError) Error() string { + return fmt.Sprintf("rate limited until %s", e.RetryAfter.Format(time.RFC3339)) +} diff --git a/coderd/gitsync/gitsync.go b/coderd/gitsync/gitsync.go new file mode 100644 index 0000000000..4f702fd8e7 --- /dev/null +++ b/coderd/gitsync/gitsync.go @@ -0,0 +1,230 @@ +package gitsync + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/externalauth/gitprovider" + "github.com/coder/quartz" +) + +const ( + // DiffStatusTTL is how long a successfully refreshed + // diff status remains fresh before becoming stale again. + DiffStatusTTL = 120 * time.Second +) + +// ProviderResolver maps a git remote origin to the gitprovider +// that handles it. Returns nil if no provider matches. +type ProviderResolver func(origin string) gitprovider.Provider + +var ErrNoTokenAvailable error = errors.New("no token available") + +// TokenResolver obtains the user's git access token for a given +// remote origin. Should return nil if no token is available, in +// which case ErrNoTokenAvailable will be returned. +type TokenResolver func( + ctx context.Context, + userID uuid.UUID, + origin string, +) (*string, error) + +// Refresher contains the stateless business logic for fetching +// fresh PR data from a git provider given a stale +// database.ChatDiffStatus row. +type Refresher struct { + providers ProviderResolver + tokens TokenResolver + logger slog.Logger + clock quartz.Clock +} + +// NewRefresher creates a Refresher with the given dependency +// functions. +func NewRefresher( + providers ProviderResolver, + tokens TokenResolver, + logger slog.Logger, + clock quartz.Clock, +) *Refresher { + return &Refresher{ + providers: providers, + tokens: tokens, + logger: logger, + clock: clock, + } +} + +// RefreshRequest pairs a stale row with the chat owner who +// holds the git token needed for API calls. +type RefreshRequest struct { + Row database.ChatDiffStatus + OwnerID uuid.UUID +} + +// RefreshResult is the outcome for a single row. +// - Params != nil, Error == nil → success, caller should upsert. +// - Params == nil, Error == nil → no PR yet, caller should skip. +// - Params == nil, Error != nil → row-level failure. +type RefreshResult struct { + Request RefreshRequest + Params *database.UpsertChatDiffStatusParams + Error error +} + +// groupKey identifies a unique (owner, origin) pair so that +// provider and token resolution happen once per group. +type groupKey struct { + ownerID uuid.UUID + origin string +} + +// Refresh fetches fresh PR data for a batch of stale rows. +// Rows are grouped internally by (ownerID, origin) so that +// provider and token resolution happen once per group. A +// top-level error is returned only when the entire batch +// fails catastrophically. Per-row outcomes are in the +// returned RefreshResult slice (one per input request, same +// order). +func (r *Refresher) Refresh( + ctx context.Context, + requests []RefreshRequest, +) ([]RefreshResult, error) { + results := make([]RefreshResult, len(requests)) + for i, req := range requests { + results[i].Request = req + } + + // Group request indices by (ownerID, origin). + groups := make(map[groupKey][]int) + for i, req := range requests { + key := groupKey{ + ownerID: req.OwnerID, + origin: req.Row.GitRemoteOrigin, + } + groups[key] = append(groups[key], i) + } + + for key, indices := range groups { + provider := r.providers(key.origin) + if provider == nil { + err := xerrors.Errorf("no provider for origin %q", key.origin) + for _, i := range indices { + results[i].Error = err + } + continue + } + + token, err := r.tokens(ctx, key.ownerID, key.origin) + if err != nil { + err = xerrors.Errorf("resolve token: %w", err) + } else if token == nil || len(*token) == 0 { + err = ErrNoTokenAvailable + } + if err != nil { + for _, i := range indices { + results[i].Error = err + } + continue + } + // This is technically unnecessary but kept here as a future molly-guard. + if token == nil { + continue + } + + for i, idx := range indices { + req := requests[idx] + params, err := r.refreshOne(ctx, provider, *token, req.Row) + results[idx] = RefreshResult{Request: req, Params: params, Error: err} + + // If rate-limited, skip remaining rows in this group. + var rlErr *gitprovider.RateLimitError + if errors.As(err, &rlErr) { + for _, remaining := range indices[i+1:] { + results[remaining] = RefreshResult{ + Request: requests[remaining], + Error: fmt.Errorf("skipped: %w", rlErr), + } + } + break + } + } + } + + return results, nil +} + +// refreshOne processes a single row using an already-resolved +// provider and token. This is the old Refresh logic, unchanged. +func (r *Refresher) refreshOne( + ctx context.Context, + provider gitprovider.Provider, + token string, + row database.ChatDiffStatus, +) (*database.UpsertChatDiffStatusParams, error) { + var ref gitprovider.PRRef + var prURL string + + if row.Url.Valid && row.Url.String != "" { + // Row already has a PR URL — parse it directly. + parsed, ok := provider.ParsePullRequestURL(row.Url.String) + if !ok { + return nil, xerrors.Errorf("parse pull request URL %q", row.Url.String) + } + ref = parsed + prURL = row.Url.String + } else { + // No PR URL — resolve owner/repo from the remote origin, + // then look up the open PR for this branch. + owner, repo, _, ok := provider.ParseRepositoryOrigin(row.GitRemoteOrigin) + if !ok { + return nil, xerrors.Errorf("parse repository origin %q", row.GitRemoteOrigin) + } + + resolved, err := provider.ResolveBranchPullRequest(ctx, token, gitprovider.BranchRef{ + Owner: owner, + Repo: repo, + Branch: row.GitBranch, + }) + if err != nil { + return nil, xerrors.Errorf("resolve branch pull request: %w", err) + } + if resolved == nil { + // No PR exists yet for this branch. + return nil, nil + } + ref = *resolved + prURL = provider.BuildPullRequestURL(ref) + } + + status, err := provider.FetchPullRequestStatus(ctx, token, ref) + if err != nil { + return nil, xerrors.Errorf("fetch pull request status: %w", err) + } + + now := r.clock.Now().UTC() + params := &database.UpsertChatDiffStatusParams{ + ChatID: row.ChatID, + Url: sql.NullString{String: prURL, Valid: prURL != ""}, + PullRequestState: sql.NullString{ + String: string(status.State), + Valid: status.State != "", + }, + ChangesRequested: status.ChangesRequested, + Additions: status.DiffStats.Additions, + Deletions: status.DiffStats.Deletions, + ChangedFiles: status.DiffStats.ChangedFiles, + RefreshedAt: now, + StaleAt: now.Add(DiffStatusTTL), + } + + return params, nil +} diff --git a/coderd/gitsync/gitsync_test.go b/coderd/gitsync/gitsync_test.go new file mode 100644 index 0000000000..6d720131ff --- /dev/null +++ b/coderd/gitsync/gitsync_test.go @@ -0,0 +1,775 @@ +package gitsync_test + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/externalauth/gitprovider" + "github.com/coder/coder/v2/coderd/gitsync" + "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/quartz" +) + +// mockProvider implements gitprovider.Provider with function fields +// so each test can wire only the methods it needs. Any method left +// nil panics with "unexpected call". +type mockProvider struct { + fetchPullRequestStatus func(ctx context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) + resolveBranchPR func(ctx context.Context, token string, ref gitprovider.BranchRef) (*gitprovider.PRRef, error) + fetchPullRequestDiff func(ctx context.Context, token string, ref gitprovider.PRRef) (string, error) + fetchBranchDiff func(ctx context.Context, token string, ref gitprovider.BranchRef) (string, error) + parseRepositoryOrigin func(raw string) (string, string, string, bool) + parsePullRequestURL func(raw string) (gitprovider.PRRef, bool) + normalizePullRequestURL func(raw string) string + buildBranchURL func(owner, repo, branch string) string + buildRepositoryURL func(owner, repo string) string + buildPullRequestURL func(ref gitprovider.PRRef) string +} + +func (m *mockProvider) FetchPullRequestStatus(ctx context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) { + if m.fetchPullRequestStatus == nil { + panic("unexpected call to FetchPullRequestStatus") + } + return m.fetchPullRequestStatus(ctx, token, ref) +} + +func (m *mockProvider) ResolveBranchPullRequest(ctx context.Context, token string, ref gitprovider.BranchRef) (*gitprovider.PRRef, error) { + if m.resolveBranchPR == nil { + panic("unexpected call to ResolveBranchPullRequest") + } + return m.resolveBranchPR(ctx, token, ref) +} + +func (m *mockProvider) FetchPullRequestDiff(ctx context.Context, token string, ref gitprovider.PRRef) (string, error) { + if m.fetchPullRequestDiff == nil { + panic("unexpected call to FetchPullRequestDiff") + } + return m.fetchPullRequestDiff(ctx, token, ref) +} + +func (m *mockProvider) FetchBranchDiff(ctx context.Context, token string, ref gitprovider.BranchRef) (string, error) { + if m.fetchBranchDiff == nil { + panic("unexpected call to FetchBranchDiff") + } + return m.fetchBranchDiff(ctx, token, ref) +} + +func (m *mockProvider) ParseRepositoryOrigin(raw string) (string, string, string, bool) { + if m.parseRepositoryOrigin == nil { + panic("unexpected call to ParseRepositoryOrigin") + } + return m.parseRepositoryOrigin(raw) +} + +func (m *mockProvider) ParsePullRequestURL(raw string) (gitprovider.PRRef, bool) { + if m.parsePullRequestURL == nil { + panic("unexpected call to ParsePullRequestURL") + } + return m.parsePullRequestURL(raw) +} + +func (m *mockProvider) NormalizePullRequestURL(raw string) string { + if m.normalizePullRequestURL == nil { + panic("unexpected call to NormalizePullRequestURL") + } + return m.normalizePullRequestURL(raw) +} + +func (m *mockProvider) BuildBranchURL(owner, repo, branch string) string { + if m.buildBranchURL == nil { + panic("unexpected call to BuildBranchURL") + } + return m.buildBranchURL(owner, repo, branch) +} + +func (m *mockProvider) BuildRepositoryURL(owner, repo string) string { + if m.buildRepositoryURL == nil { + panic("unexpected call to BuildRepositoryURL") + } + return m.buildRepositoryURL(owner, repo) +} + +func (m *mockProvider) BuildPullRequestURL(ref gitprovider.PRRef) string { + if m.buildPullRequestURL == nil { + panic("unexpected call to BuildPullRequestURL") + } + return m.buildPullRequestURL(ref) +} + +func TestRefresher_WithPRURL(t *testing.T) { + t.Parallel() + + mp := &mockProvider{ + parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) { + return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true + }, + fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) { + return &gitprovider.PRStatus{ + State: gitprovider.PRStateOpen, + DiffStats: gitprovider.DiffStats{ + Additions: 10, + Deletions: 5, + ChangedFiles: 3, + }, + }, nil + }, + } + + providers := func(_ string) gitprovider.Provider { return mp } + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + return ptr.Ref("test-token"), nil + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + chatID := uuid.New() + row := database.ChatDiffStatus{ + ChatID: chatID, + Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true}, + GitRemoteOrigin: "https://github.com/org/repo", + GitBranch: "feature", + } + + ownerID := uuid.New() + results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{ + {Row: row, OwnerID: ownerID}, + }) + require.NoError(t, err) + require.Len(t, results, 1) + res := results[0] + + require.NoError(t, res.Error) + require.NotNil(t, res.Params) + + assert.Equal(t, chatID, res.Params.ChatID) + assert.Equal(t, "open", res.Params.PullRequestState.String) + assert.True(t, res.Params.PullRequestState.Valid) + assert.Equal(t, int32(10), res.Params.Additions) + assert.Equal(t, int32(5), res.Params.Deletions) + assert.Equal(t, int32(3), res.Params.ChangedFiles) + + // StaleAt should be ~120s after RefreshedAt. + diff := res.Params.StaleAt.Sub(res.Params.RefreshedAt) + assert.InDelta(t, 120, diff.Seconds(), 5) +} + +func TestRefresher_BranchResolvesToPR(t *testing.T) { + t.Parallel() + + mp := &mockProvider{ + parseRepositoryOrigin: func(_ string) (string, string, string, bool) { + return "org", "repo", "https://github.com/org/repo", true + }, + resolveBranchPR: func(_ context.Context, _ string, _ gitprovider.BranchRef) (*gitprovider.PRRef, error) { + return &gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 7}, nil + }, + fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) { + return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil + }, + buildPullRequestURL: func(_ gitprovider.PRRef) string { + return "https://github.com/org/repo/pull/7" + }, + } + + providers := func(_ string) gitprovider.Provider { return mp } + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + return ptr.Ref("test-token"), nil + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + row := database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{}, + GitRemoteOrigin: "https://github.com/org/repo", + GitBranch: "feature", + } + + ownerID := uuid.New() + results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{ + {Row: row, OwnerID: ownerID}, + }) + require.NoError(t, err) + require.Len(t, results, 1) + res := results[0] + + require.NoError(t, res.Error) + require.NotNil(t, res.Params) + + assert.Contains(t, res.Params.Url.String, "pull/7") + assert.True(t, res.Params.Url.Valid) + assert.Equal(t, "open", res.Params.PullRequestState.String) +} + +func TestRefresher_BranchNoPRYet(t *testing.T) { + t.Parallel() + + mp := &mockProvider{ + parseRepositoryOrigin: func(_ string) (string, string, string, bool) { + return "org", "repo", "https://github.com/org/repo", true + }, + resolveBranchPR: func(_ context.Context, _ string, _ gitprovider.BranchRef) (*gitprovider.PRRef, error) { + return nil, nil + }, + } + + providers := func(_ string) gitprovider.Provider { return mp } + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + return ptr.Ref("test-token"), nil + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + row := database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{}, + GitRemoteOrigin: "https://github.com/org/repo", + GitBranch: "feature", + } + + ownerID := uuid.New() + results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{ + {Row: row, OwnerID: ownerID}, + }) + require.NoError(t, err) + require.Len(t, results, 1) + res := results[0] + + assert.NoError(t, res.Error) + assert.Nil(t, res.Params) +} + +func TestRefresher_NoProviderForOrigin(t *testing.T) { + t.Parallel() + + providers := func(_ string) gitprovider.Provider { return nil } + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + return ptr.Ref("test-token"), nil + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + row := database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://example.com/pr/1", Valid: true}, + GitRemoteOrigin: "https://example.com/org/repo", + GitBranch: "feature", + } + + ownerID := uuid.New() + results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{ + {Row: row, OwnerID: ownerID}, + }) + require.NoError(t, err) + require.Len(t, results, 1) + res := results[0] + + assert.Nil(t, res.Params) + require.Error(t, res.Error) + assert.Contains(t, res.Error.Error(), "no provider") +} + +func TestRefresher_TokenResolutionFails(t *testing.T) { + t.Parallel() + + var fetchCalled atomic.Bool + mp := &mockProvider{ + fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) { + fetchCalled.Store(true) + return nil, errors.New("should not be called") + }, + parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) { + return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, true + }, + } + + providers := func(_ string) gitprovider.Provider { return mp } + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + return nil, errors.New("token lookup failed") + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + row := database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true}, + GitRemoteOrigin: "https://github.com/org/repo", + GitBranch: "feature", + } + + ownerID := uuid.New() + results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{ + {Row: row, OwnerID: ownerID}, + }) + require.NoError(t, err) + require.Len(t, results, 1) + res := results[0] + + assert.Nil(t, res.Params) + require.Error(t, res.Error) + assert.False(t, fetchCalled.Load(), "FetchPullRequestStatus should not be called when token resolution fails") +} + +func TestRefresher_EmptyToken(t *testing.T) { + t.Parallel() + + mp := &mockProvider{} + + providers := func(_ string) gitprovider.Provider { return mp } + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + return ptr.Ref(""), nil + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + row := database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true}, + GitRemoteOrigin: "https://github.com/org/repo", + GitBranch: "feature", + } + + ownerID := uuid.New() + results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{ + {Row: row, OwnerID: ownerID}, + }) + require.NoError(t, err) + require.Len(t, results, 1) + res := results[0] + + assert.Nil(t, res.Params) + require.ErrorIs(t, res.Error, gitsync.ErrNoTokenAvailable) +} + +func TestRefresher_ProviderFetchFails(t *testing.T) { + t.Parallel() + + mp := &mockProvider{ + parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) { + return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true + }, + fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) { + return nil, errors.New("api error") + }, + } + + providers := func(_ string) gitprovider.Provider { return mp } + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + return ptr.Ref("test-token"), nil + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + row := database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true}, + GitRemoteOrigin: "https://github.com/org/repo", + GitBranch: "feature", + } + + ownerID := uuid.New() + results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{ + {Row: row, OwnerID: ownerID}, + }) + require.NoError(t, err) + require.Len(t, results, 1) + res := results[0] + + assert.Nil(t, res.Params) + require.Error(t, res.Error) + assert.Contains(t, res.Error.Error(), "api error") +} + +func TestRefresher_PRURLParseFailure(t *testing.T) { + t.Parallel() + + mp := &mockProvider{ + parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) { + return gitprovider.PRRef{}, false + }, + } + + providers := func(_ string) gitprovider.Provider { return mp } + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + return ptr.Ref("test-token"), nil + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + row := database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/not-a-pr", Valid: true}, + GitRemoteOrigin: "https://github.com/org/repo", + GitBranch: "feature", + } + + ownerID := uuid.New() + results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{ + {Row: row, OwnerID: ownerID}, + }) + require.NoError(t, err) + require.Len(t, results, 1) + res := results[0] + + assert.Nil(t, res.Params) + require.Error(t, res.Error) +} + +func TestRefresher_BatchGroupsByOwnerAndOrigin(t *testing.T) { + t.Parallel() + + mp := &mockProvider{ + parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) { + return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, true + }, + fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) { + return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil + }, + } + + providers := func(_ string) gitprovider.Provider { return mp } + + var tokenCalls atomic.Int32 + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + tokenCalls.Add(1) + return ptr.Ref("test-token"), nil + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + ownerID := uuid.New() + originA := "https://github.com/org/repo" + originB := "https://gitlab.com/org/repo" + + requests := []gitsync.RefreshRequest{ + { + Row: database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true}, + GitRemoteOrigin: originA, + GitBranch: "feature-1", + }, + OwnerID: ownerID, + }, + { + Row: database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true}, + GitRemoteOrigin: originA, + GitBranch: "feature-2", + }, + OwnerID: ownerID, + }, + { + Row: database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://gitlab.com/org/repo/pull/1", Valid: true}, + GitRemoteOrigin: originB, + GitBranch: "feature-3", + }, + OwnerID: ownerID, + }, + } + + results, err := r.Refresh(context.Background(), requests) + require.NoError(t, err) + require.Len(t, results, 3) + + for i, res := range results { + require.NoError(t, res.Error, "result[%d] should not have an error", i) + require.NotNil(t, res.Params, "result[%d] should have params", i) + } + + // Two distinct (ownerID, origin) groups → exactly 2 token + // resolution calls. + assert.Equal(t, int32(2), tokenCalls.Load(), + "TokenResolver should be called once per (owner, origin) group") +} + +func TestRefresher_UsesInjectedClock(t *testing.T) { + t.Parallel() + + mClock := quartz.NewMock(t) + fixedTime := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC) + mClock.Set(fixedTime) + + mp := &mockProvider{ + parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) { + return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true + }, + fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) { + return &gitprovider.PRStatus{ + State: gitprovider.PRStateOpen, + DiffStats: gitprovider.DiffStats{ + Additions: 10, + Deletions: 5, + ChangedFiles: 3, + }, + }, nil + }, + } + + providers := func(_ string) gitprovider.Provider { return mp } + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + return ptr.Ref("test-token"), nil + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), mClock) + + chatID := uuid.New() + row := database.ChatDiffStatus{ + ChatID: chatID, + Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true}, + GitRemoteOrigin: "https://github.com/org/repo", + GitBranch: "feature", + } + + ownerID := uuid.New() + results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{ + {Row: row, OwnerID: ownerID}, + }) + require.NoError(t, err) + require.Len(t, results, 1) + res := results[0] + + require.NoError(t, res.Error) + require.NotNil(t, res.Params) + + // The mock clock is deterministic, so times must be exact. + assert.Equal(t, fixedTime, res.Params.RefreshedAt) + assert.Equal(t, fixedTime.Add(gitsync.DiffStatusTTL), res.Params.StaleAt) +} + +func TestRefresher_RateLimitSkipsRemainingInGroup(t *testing.T) { + t.Parallel() + + var callCount atomic.Int32 + + mp := &mockProvider{ + parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) { + var num int + switch { + case strings.HasSuffix(raw, "/pull/1"): + num = 1 + case strings.HasSuffix(raw, "/pull/2"): + num = 2 + case strings.HasSuffix(raw, "/pull/3"): + num = 3 + default: + return gitprovider.PRRef{}, false + } + return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: num}, true + }, + fetchPullRequestStatus: func(_ context.Context, _ string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) { + call := callCount.Add(1) + switch call { + case 1: + // First call succeeds. + return &gitprovider.PRStatus{ + State: gitprovider.PRStateOpen, + DiffStats: gitprovider.DiffStats{ + Additions: 5, + Deletions: 2, + ChangedFiles: 1, + }, + }, nil + case 2: + // Second call hits rate limit. + return nil, &gitprovider.RateLimitError{ + RetryAfter: time.Now().Add(60 * time.Second), + } + default: + // Third call should never happen. + t.Fatal("FetchPullRequestStatus called more than 2 times") + return nil, nil + } + }, + } + + providers := func(_ string) gitprovider.Provider { return mp } + tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) { + return ptr.Ref("test-token"), nil + } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + ownerID := uuid.New() + origin := "https://github.com/org/repo" + + requests := []gitsync.RefreshRequest{ + { + Row: database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true}, + GitRemoteOrigin: origin, + GitBranch: "feat-1", + }, + OwnerID: ownerID, + }, + { + Row: database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/pull/2", Valid: true}, + GitRemoteOrigin: origin, + GitBranch: "feat-2", + }, + OwnerID: ownerID, + }, + { + Row: database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/pull/3", Valid: true}, + GitRemoteOrigin: origin, + GitBranch: "feat-3", + }, + OwnerID: ownerID, + }, + } + + results, err := r.Refresh(context.Background(), requests) + require.NoError(t, err) + require.Len(t, results, 3) + + // Row 0: success. + assert.NoError(t, results[0].Error) + assert.NotNil(t, results[0].Params) + + // Row 1: rate-limited. + require.Error(t, results[1].Error) + var rlErr1 *gitprovider.RateLimitError + assert.True(t, errors.As(results[1].Error, &rlErr1), + "result[1] error should be *RateLimitError") + + // Row 2: skipped due to rate limit. + require.Error(t, results[2].Error) + var rlErr2 *gitprovider.RateLimitError + assert.True(t, errors.As(results[2].Error, &rlErr2), + "result[2] error should wrap *RateLimitError") + assert.Contains(t, results[2].Error.Error(), "skipped") + + // Provider should have been called exactly twice. + assert.Equal(t, int32(2), callCount.Load(), + "FetchPullRequestStatus should be called exactly 2 times") +} + +func TestRefresher_CorrectTokenPerOrigin(t *testing.T) { + t.Parallel() + + var tokenCalls atomic.Int32 + tokens := func(_ context.Context, _ uuid.UUID, origin string) (*string, error) { + tokenCalls.Add(1) + switch { + case strings.Contains(origin, "github.com"): + return ptr.Ref("gh-public-token"), nil + case strings.Contains(origin, "ghes.corp.com"): + return ptr.Ref("ghe-private-token"), nil + default: + return nil, fmt.Errorf("unexpected origin: %s", origin) + } + } + + // Track which token each FetchPullRequestStatus call received, + // keyed by chat ID. We pass the chat ID through the PRRef.Number + // field (unique per request) so FetchPullRequestStatus can + // identify which row it's processing. + var mu sync.Mutex + tokensByPR := make(map[int]string) + + mp := &mockProvider{ + parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) { + // Extract a unique PR number from the URL to identify + // each row inside FetchPullRequestStatus. + var num int + switch { + case strings.HasSuffix(raw, "/pull/1"): + num = 1 + case strings.HasSuffix(raw, "/pull/2"): + num = 2 + case strings.HasSuffix(raw, "/pull/10"): + num = 10 + default: + return gitprovider.PRRef{}, false + } + return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: num}, true + }, + fetchPullRequestStatus: func(_ context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) { + mu.Lock() + tokensByPR[ref.Number] = token + mu.Unlock() + return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil + }, + } + + providers := func(_ string) gitprovider.Provider { return mp } + + r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal()) + + ownerID := uuid.New() + + requests := []gitsync.RefreshRequest{ + { + Row: database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true}, + GitRemoteOrigin: "https://github.com/org/repo", + GitBranch: "feature-1", + }, + OwnerID: ownerID, + }, + { + Row: database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://github.com/org/repo/pull/2", Valid: true}, + GitRemoteOrigin: "https://github.com/org/repo", + GitBranch: "feature-2", + }, + OwnerID: ownerID, + }, + { + Row: database.ChatDiffStatus{ + ChatID: uuid.New(), + Url: sql.NullString{String: "https://ghes.corp.com/org/repo/pull/10", Valid: true}, + GitRemoteOrigin: "https://ghes.corp.com/org/repo", + GitBranch: "feature-3", + }, + OwnerID: ownerID, + }, + } + + results, err := r.Refresh(context.Background(), requests) + require.NoError(t, err) + require.Len(t, results, 3) + + for i, res := range results { + require.NoError(t, res.Error, "result[%d] should not have an error", i) + require.NotNil(t, res.Params, "result[%d] should have params", i) + } + + // github.com rows (PR #1 and #2) should use the public token. + assert.Equal(t, "gh-public-token", tokensByPR[1], + "github.com PR #1 should use gh-public-token") + assert.Equal(t, "gh-public-token", tokensByPR[2], + "github.com PR #2 should use gh-public-token") + + // ghes.corp.com row (PR #10) should use the GHE token. + assert.Equal(t, "ghe-private-token", tokensByPR[10], + "ghes.corp.com PR #10 should use ghe-private-token") + + // Token resolution should be called exactly twice — once per + // (owner, origin) group. + assert.Equal(t, int32(2), tokenCalls.Load(), + "TokenResolver should be called once per (owner, origin) group") +} diff --git a/coderd/gitsync/worker.go b/coderd/gitsync/worker.go new file mode 100644 index 0000000000..222b8dd074 --- /dev/null +++ b/coderd/gitsync/worker.go @@ -0,0 +1,255 @@ +package gitsync + +import ( + "context" + "database/sql" + "time" + + "github.com/google/uuid" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/quartz" +) + +const ( + // defaultBatchSize is the maximum number of stale rows fetched + // per tick. + defaultBatchSize int32 = 50 + + // defaultInterval is the polling interval between ticks. + defaultInterval = 10 * time.Second +) + +// Store is the narrow DB interface the Worker needs. +type Store interface { + AcquireStaleChatDiffStatuses( + ctx context.Context, limitVal int32, + ) ([]database.AcquireStaleChatDiffStatusesRow, error) + BackoffChatDiffStatus( + ctx context.Context, arg database.BackoffChatDiffStatusParams, + ) error + UpsertChatDiffStatus( + ctx context.Context, arg database.UpsertChatDiffStatusParams, + ) (database.ChatDiffStatus, error) + UpsertChatDiffStatusReference( + ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams, + ) (database.ChatDiffStatus, error) + GetChatsByOwnerID( + ctx context.Context, arg database.GetChatsByOwnerIDParams, + ) ([]database.Chat, error) +} + +// EventPublisher notifies the frontend of diff status changes. +type PublishDiffStatusChangeFunc func(ctx context.Context, chatID uuid.UUID) error + +// Worker is a background loop that periodically refreshes stale +// chat diff statuses by delegating to a Refresher. +type Worker struct { + store Store + refresher *Refresher + publishDiffStatusChangeFn PublishDiffStatusChangeFunc + clock quartz.Clock + logger slog.Logger + batchSize int32 + interval time.Duration + done chan struct{} +} + +// NewWorker creates a Worker with default batch size and interval. +func NewWorker( + store Store, + refresher *Refresher, + publisher PublishDiffStatusChangeFunc, + clock quartz.Clock, + logger slog.Logger, +) *Worker { + return &Worker{ + store: store, + refresher: refresher, + publishDiffStatusChangeFn: publisher, + clock: clock, + logger: logger, + batchSize: defaultBatchSize, + interval: defaultInterval, + done: make(chan struct{}), + } +} + +// Start launches the background loop. It blocks until ctx is +// cancelled, then closes w.done. +func (w *Worker) Start(ctx context.Context) { + defer close(w.done) + + ticker := w.clock.NewTicker(w.interval, "gitsync", "worker") + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + w.tick(ctx) + } + } +} + +// Done returns a channel that is closed when the worker exits. +func (w *Worker) Done() <-chan struct{} { + return w.done +} + +func chatDiffStatusFromRow(row database.AcquireStaleChatDiffStatusesRow) database.ChatDiffStatus { + return database.ChatDiffStatus{ + ChatID: row.ChatID, + Url: row.Url, + PullRequestState: row.PullRequestState, + ChangesRequested: row.ChangesRequested, + Additions: row.Additions, + Deletions: row.Deletions, + ChangedFiles: row.ChangedFiles, + RefreshedAt: row.RefreshedAt, + StaleAt: row.StaleAt, + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + GitBranch: row.GitBranch, + GitRemoteOrigin: row.GitRemoteOrigin, + } +} + +func (w *Worker) tick(ctx context.Context) { + // Set a context equal to w.interval so that we do not hold up processing due to + // random unicorn-related events. + ctx, cancel := context.WithTimeout(ctx, w.interval) + defer cancel() + + acquiredRows, err := w.store.AcquireStaleChatDiffStatuses(ctx, w.batchSize) + if err != nil { + w.logger.Warn(ctx, "acquire stale chat diff statuses", + slog.Error(err)) + return + } + if len(acquiredRows) == 0 { + return + } + + // Build refresh requests directly from acquired rows. + requests := make([]RefreshRequest, 0, len(acquiredRows)) + for _, row := range acquiredRows { + requests = append(requests, RefreshRequest{ + Row: chatDiffStatusFromRow(row), + OwnerID: row.OwnerID, + }) + } + + results, err := w.refresher.Refresh(ctx, requests) + if err != nil { + w.logger.Warn(ctx, "batch refresh chat diff statuses", + slog.Error(err)) + return + } + + for _, res := range results { + if res.Error != nil { + w.logger.Debug(ctx, "refresh chat diff status", + slog.F("chat_id", res.Request.Row.ChatID), + slog.Error(res.Error)) + // Back off so the row isn't retried immediately. + if err := w.store.BackoffChatDiffStatus(ctx, + database.BackoffChatDiffStatusParams{ + ChatID: res.Request.Row.ChatID, + StaleAt: w.clock.Now().UTC().Add(DiffStatusTTL), + }, + ); err != nil { + w.logger.Warn(ctx, "backoff failed chat diff status", + slog.F("chat_id", res.Request.Row.ChatID), + slog.Error(err)) + } + continue + } + if res.Params == nil { + // No PR yet — skip. + continue + } + if _, err := w.store.UpsertChatDiffStatus(ctx, *res.Params); err != nil { + w.logger.Warn(ctx, "upsert refreshed chat diff status", + slog.F("chat_id", res.Request.Row.ChatID), + slog.Error(err)) + continue + } + if w.publishDiffStatusChangeFn != nil { + if err := w.publishDiffStatusChangeFn(ctx, res.Request.Row.ChatID); err != nil { + w.logger.Debug(ctx, "publish diff status change", + slog.F("chat_id", res.Request.Row.ChatID), + slog.Error(err)) + } + } + } +} + +// MarkStale persists the git ref on all chats for a workspace, +// setting stale_at to the past so the next tick picks them up. +// Publishes a diff status event for each affected chat. +// Called from workspaceagents handlers. No goroutines spawned. +func (w *Worker) MarkStale( + ctx context.Context, + workspaceID, ownerID uuid.UUID, + branch, origin string, +) { + if branch == "" || origin == "" { + return + } + + chats, err := w.store.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{ + OwnerID: ownerID, + }) + if err != nil { + w.logger.Warn(ctx, "list chats for git ref storage", + slog.F("workspace_id", workspaceID), + slog.Error(err)) + return + } + + for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) { + _, err := w.store.UpsertChatDiffStatusReference(ctx, + database.UpsertChatDiffStatusReferenceParams{ + ChatID: chat.ID, + GitBranch: branch, + GitRemoteOrigin: origin, + StaleAt: w.clock.Now().Add(-time.Second), + Url: sql.NullString{}, + }, + ) + if err != nil { + w.logger.Warn(ctx, "store git ref on chat diff status", + slog.F("chat_id", chat.ID), + slog.F("workspace_id", workspaceID), + slog.Error(err)) + continue + } + // Notify the frontend immediately so the UI shows the + // branch info even before the worker refreshes PR data. + if w.publishDiffStatusChangeFn != nil { + if pubErr := w.publishDiffStatusChangeFn(ctx, chat.ID); pubErr != nil { + w.logger.Debug(ctx, "publish diff status after mark stale", + slog.F("chat_id", chat.ID), slog.Error(pubErr)) + } + } + } +} + +// filterChatsByWorkspaceID returns only chats associated with +// the given workspace. +func filterChatsByWorkspaceID( + chats []database.Chat, + workspaceID uuid.UUID, +) []database.Chat { + filtered := make([]database.Chat, 0, len(chats)) + for _, chat := range chats { + if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID { + continue + } + filtered = append(filtered, chat) + } + return filtered +} diff --git a/coderd/gitsync/worker_test.go b/coderd/gitsync/worker_test.go new file mode 100644 index 0000000000..36e4d4fb54 --- /dev/null +++ b/coderd/gitsync/worker_test.go @@ -0,0 +1,744 @@ +package gitsync_test + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/externalauth/gitprovider" + "github.com/coder/coder/v2/coderd/gitsync" + "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +// testRefresherCfg configures newTestRefresher. +type testRefresherCfg struct { + resolveBranchPR func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) + fetchPRStatus func(context.Context, string, gitprovider.PRRef) (*gitprovider.PRStatus, error) +} + +type testRefresherOpt func(*testRefresherCfg) + +func withResolveBranchPR(f func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error)) testRefresherOpt { + return func(c *testRefresherCfg) { c.resolveBranchPR = f } +} + +// newTestRefresher creates a Refresher backed by mock +// provider/token resolvers. The provider recognises any origin, +// resolves branches to a canned PR, and returns a canned PRStatus. +func newTestRefresher(t *testing.T, clk quartz.Clock, opts ...testRefresherOpt) *gitsync.Refresher { + t.Helper() + + cfg := testRefresherCfg{ + resolveBranchPR: func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) { + return &gitprovider.PRRef{Owner: "o", Repo: "r", Number: 1}, nil + }, + fetchPRStatus: func(context.Context, string, gitprovider.PRRef) (*gitprovider.PRStatus, error) { + return &gitprovider.PRStatus{ + State: gitprovider.PRStateOpen, + DiffStats: gitprovider.DiffStats{ + Additions: 10, + Deletions: 3, + ChangedFiles: 2, + }, + }, nil + }, + } + for _, o := range opts { + o(&cfg) + } + + prov := &mockProvider{ + parseRepositoryOrigin: func(string) (string, string, string, bool) { + return "owner", "repo", "https://github.com/owner/repo", true + }, + parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) { + return gitprovider.PRRef{Owner: "owner", Repo: "repo", Number: 1}, raw != "" + }, + resolveBranchPR: cfg.resolveBranchPR, + fetchPullRequestStatus: cfg.fetchPRStatus, + buildPullRequestURL: func(ref gitprovider.PRRef) string { + return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number) + }, + } + + providers := func(string) gitprovider.Provider { return prov } + tokens := func(context.Context, uuid.UUID, string) (*string, error) { + return ptr.Ref("tok"), nil + } + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + return gitsync.NewRefresher(providers, tokens, logger, clk) +} + +// makeAcquiredRow returns an AcquireStaleChatDiffStatusesRow with +// a non-empty branch/origin so the Refresher goes through the +// branch-resolution path. +func makeAcquiredRow(chatID, ownerID uuid.UUID) database.AcquireStaleChatDiffStatusesRow { + return database.AcquireStaleChatDiffStatusesRow{ + ChatID: chatID, + GitBranch: "feature", + GitRemoteOrigin: "https://github.com/owner/repo", + StaleAt: time.Now().Add(-time.Minute), + OwnerID: ownerID, + } +} + +// tickOnce traps the worker's NewTicker call, starts the worker, +// fires one tick, waits for it to finish by observing the given +// tickDone channel, then shuts the worker down. The tickDone +// channel must be closed when the last expected operation in the +// tick completes. For tests where the tick does nothing (e.g. 0 +// stale rows or store error), tickDone should be closed inside +// acquireStaleChatDiffStatuses. +func tickOnce( + ctx context.Context, + t *testing.T, + mClock *quartz.Mock, + worker *gitsync.Worker, + tickDone <-chan struct{}, +) { + t.Helper() + + trap := mClock.Trap().NewTicker("gitsync", "worker") + defer trap.Close() + + workerCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go worker.Start(workerCtx) + + // Wait for the worker to create its ticker. + trap.MustWait(ctx).MustRelease(ctx) + + // Fire one tick. The waiter resolves when the channel receive + // completes, not when w.tick() returns, so we use tickDone to + // know when to proceed. + _, w := mClock.AdvanceNext() + w.MustWait(ctx) + + // Wait for the tick's business logic to finish. + select { + case <-tickDone: + case <-ctx.Done(): + t.Fatal("timed out waiting for tick to complete") + } + + cancel() + <-worker.Done() +} + +func TestWorker_SkipsFreshRows(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + tickDone := make(chan struct{}) + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()). + DoAndReturn(func(context.Context, int32) ([]database.AcquireStaleChatDiffStatusesRow, error) { + // No stale rows — tick returns immediately. + close(tickDone) + return nil, nil + }) + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, nil, mClock, logger) + + tickOnce(ctx, t, mClock, worker, tickDone) +} + +func TestWorker_LimitsToNRows(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + var capturedLimit atomic.Int32 + var upsertCount atomic.Int32 + ownerID := uuid.New() + const numRows = 5 + tickDone := make(chan struct{}) + + rows := make([]database.AcquireStaleChatDiffStatusesRow, numRows) + for i := range rows { + rows[i] = makeAcquiredRow(uuid.New(), ownerID) + } + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) { + capturedLimit.Store(limitVal) + return rows, nil + }) + store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) { + upsertCount.Add(1) + return database.ChatDiffStatus{ChatID: arg.ChatID}, nil + }).Times(numRows) + + pub := func(_ context.Context, _ uuid.UUID) error { + if upsertCount.Load() == numRows { + close(tickDone) + } + return nil + } + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, pub, mClock, logger) + + tickOnce(ctx, t, mClock, worker, tickDone) + + // The default batch size is 50. + assert.Equal(t, int32(50), capturedLimit.Load()) + assert.Equal(t, int32(numRows), upsertCount.Load()) +} + +func TestWorker_RefresherReturnsNilNil_SkipsUpsert(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + chatID := uuid.New() + ownerID := uuid.New() + + // When the Refresher returns (nil, nil) the worker skips the + // upsert and publish. We signal tickDone from the refresher + // mock since that is the last operation before the tick + // returns. + tickDone := make(chan struct{}) + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()). + Return([]database.AcquireStaleChatDiffStatusesRow{makeAcquiredRow(chatID, ownerID)}, nil) + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + // ResolveBranchPullRequest returns nil → Refresher returns + // (nil, nil). + refresher := newTestRefresher(t, mClock, withResolveBranchPR( + func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) { + close(tickDone) + return nil, nil + }, + )) + + worker := gitsync.NewWorker(store, refresher, nil, mClock, logger) + + tickOnce(ctx, t, mClock, worker, tickDone) +} + +func TestWorker_RefresherError_BacksOffRow(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + chat1 := uuid.New() + chat2 := uuid.New() + ownerID := uuid.New() + + var upsertCount atomic.Int32 + var publishCount atomic.Int32 + var backoffCount atomic.Int32 + var mu sync.Mutex + var backoffArgs []database.BackoffChatDiffStatusParams + tickDone := make(chan struct{}) + var closeOnce sync.Once + + // Two rows processed: one fails (backoff), one succeeds + // (upsert+publish). Both must finish before we close tickDone. + var terminalOps atomic.Int32 + signalIfDone := func() { + if terminalOps.Add(1) == 2 { + closeOnce.Do(func() { close(tickDone) }) + } + } + + mClock := quartz.NewMock(t) + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()). + Return([]database.AcquireStaleChatDiffStatusesRow{ + makeAcquiredRow(chat1, ownerID), + makeAcquiredRow(chat2, ownerID), + }, nil) + store.EXPECT().BackoffChatDiffStatus(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, arg database.BackoffChatDiffStatusParams) error { + backoffCount.Add(1) + mu.Lock() + backoffArgs = append(backoffArgs, arg) + mu.Unlock() + signalIfDone() + return nil + }) + store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) { + upsertCount.Add(1) + return database.ChatDiffStatus{ChatID: arg.ChatID}, nil + }) + + pub := func(_ context.Context, _ uuid.UUID) error { + // Only the successful row publishes. + publishCount.Add(1) + signalIfDone() + return nil + } + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + // Fail ResolveBranchPullRequest for the first call, succeed + // for the second. + var callCount atomic.Int32 + refresher := newTestRefresher(t, mClock, withResolveBranchPR( + func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) { + n := callCount.Add(1) + if n == 1 { + return nil, fmt.Errorf("simulated provider error") + } + return &gitprovider.PRRef{Owner: "o", Repo: "r", Number: 1}, nil + }, + )) + + worker := gitsync.NewWorker(store, refresher, pub, mClock, logger) + + tickOnce(ctx, t, mClock, worker, tickDone) + + // BackoffChatDiffStatus was called for the failed row. + assert.Equal(t, int32(1), backoffCount.Load()) + mu.Lock() + require.Len(t, backoffArgs, 1) + assert.Equal(t, chat1, backoffArgs[0].ChatID) + // stale_at should be approximately clock.Now() + DiffStatusTTL (120s). + expectedStaleAt := mClock.Now().UTC().Add(gitsync.DiffStatusTTL) + assert.WithinDuration(t, expectedStaleAt, backoffArgs[0].StaleAt, time.Second) + mu.Unlock() + + // UpsertChatDiffStatus was called for the successful row. + assert.Equal(t, int32(1), upsertCount.Load()) + // PublishDiffStatusChange was called only for the successful row. + assert.Equal(t, int32(1), publishCount.Load()) +} + +func TestWorker_UpsertError_ContinuesNextRow(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + chat1 := uuid.New() + chat2 := uuid.New() + ownerID := uuid.New() + + var publishCount atomic.Int32 + tickDone := make(chan struct{}) + var closeOnce sync.Once + var mu sync.Mutex + upsertedChatIDs := make(map[uuid.UUID]struct{}) + + // We have 2 rows. The upsert for chat1 fails; the upsert + // for chat2 succeeds and publishes. Because goroutines run + // concurrently we don't know which finishes last, so we + // track the total number of "terminal" events (upsert error + // + publish success) and close tickDone when both have + // occurred. + var terminalOps atomic.Int32 + signalIfDone := func() { + if terminalOps.Add(1) == 2 { + closeOnce.Do(func() { close(tickDone) }) + } + } + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()). + Return([]database.AcquireStaleChatDiffStatusesRow{ + makeAcquiredRow(chat1, ownerID), + makeAcquiredRow(chat2, ownerID), + }, nil) + store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) { + if arg.ChatID == chat1 { + // Terminal event for the failing row. + signalIfDone() + return database.ChatDiffStatus{}, fmt.Errorf("db write error") + } + mu.Lock() + upsertedChatIDs[arg.ChatID] = struct{}{} + mu.Unlock() + return database.ChatDiffStatus{ChatID: arg.ChatID}, nil + }).Times(2) + + pub := func(_ context.Context, _ uuid.UUID) error { + publishCount.Add(1) + // Terminal event for the successful row. + signalIfDone() + return nil + } + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, pub, mClock, logger) + + tickOnce(ctx, t, mClock, worker, tickDone) + + mu.Lock() + _, gotChat2 := upsertedChatIDs[chat2] + mu.Unlock() + assert.True(t, gotChat2, "chat2 should have been upserted") + assert.Equal(t, int32(1), publishCount.Load()) +} + +func TestWorker_RespectsShutdown(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()). + Return(nil, nil).AnyTimes() + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, nil, mClock, logger) + + trap := mClock.Trap().NewTicker("gitsync", "worker") + defer trap.Close() + + workerCtx, cancel := context.WithCancel(ctx) + go worker.Start(workerCtx) + + // Wait for ticker creation so the worker is running. + trap.MustWait(ctx).MustRelease(ctx) + + // Cancel immediately. + cancel() + + select { + case <-worker.Done(): + // Success — worker shut down. + case <-ctx.Done(): + t.Fatal("timed out waiting for worker to shut down") + } +} + +func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + workspaceID := uuid.New() + ownerID := uuid.New() + chat1 := uuid.New() + chat2 := uuid.New() + chatOther := uuid.New() + + var mu sync.Mutex + var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams + var publishedIDs []uuid.UUID + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) { + require.Equal(t, ownerID, arg.OwnerID) + return []database.Chat{ + {ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}, + {ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}, + {ID: chatOther, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}}, + }, nil + }) + store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) { + mu.Lock() + upsertRefCalls = append(upsertRefCalls, arg) + mu.Unlock() + return database.ChatDiffStatus{ChatID: arg.ChatID}, nil + }).Times(2) + + pub := func(_ context.Context, chatID uuid.UUID) error { + mu.Lock() + publishedIDs = append(publishedIDs, chatID) + mu.Unlock() + return nil + } + + mClock := quartz.NewMock(t) + now := mClock.Now() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, pub, mClock, logger) + + worker.MarkStale(ctx, workspaceID, ownerID, "feature", "https://github.com/owner/repo") + + mu.Lock() + defer mu.Unlock() + + require.Len(t, upsertRefCalls, 2) + for _, call := range upsertRefCalls { + assert.Equal(t, "feature", call.GitBranch) + assert.Equal(t, "https://github.com/owner/repo", call.GitRemoteOrigin) + assert.True(t, call.StaleAt.Before(now), + "stale_at should be in the past, got %v vs now %v", call.StaleAt, now) + assert.Equal(t, sql.NullString{}, call.Url) + } + + require.Len(t, publishedIDs, 2) + assert.ElementsMatch(t, []uuid.UUID{chat1, chat2}, publishedIDs) +} + +func TestWorker_MarkStale_NoMatchingChats(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + workspaceID := uuid.New() + ownerID := uuid.New() + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()). + Return([]database.Chat{ + {ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}}, + {ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}}, + }, nil) + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, nil, mClock, logger) + + worker.MarkStale(ctx, workspaceID, ownerID, "main", "https://github.com/x/y") +} + +func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + workspaceID := uuid.New() + ownerID := uuid.New() + chat1 := uuid.New() + chat2 := uuid.New() + + var publishCount atomic.Int32 + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()). + Return([]database.Chat{ + {ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}, + {ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}, + }, nil) + store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) { + if arg.ChatID == chat1 { + return database.ChatDiffStatus{}, fmt.Errorf("upsert ref error") + } + return database.ChatDiffStatus{ChatID: arg.ChatID}, nil + }).Times(2) + + pub := func(_ context.Context, _ uuid.UUID) error { + publishCount.Add(1) + return nil + } + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, pub, mClock, logger) + + worker.MarkStale(ctx, workspaceID, ownerID, "dev", "https://github.com/a/b") + + assert.Equal(t, int32(1), publishCount.Load()) +} + +func TestWorker_MarkStale_GetChatsFails(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("db error")) + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, nil, mClock, logger) + + worker.MarkStale(ctx, uuid.New(), uuid.New(), "main", "https://github.com/x/y") +} + +func TestWorker_TickStoreError(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + tickDone := make(chan struct{}) + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()). + DoAndReturn(func(context.Context, int32) ([]database.AcquireStaleChatDiffStatusesRow, error) { + close(tickDone) + return nil, fmt.Errorf("database unavailable") + }) + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, nil, mClock, logger) + + tickOnce(ctx, t, mClock, worker, tickDone) +} + +func TestWorker_MarkStale_EmptyBranchOrOrigin(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + branch string + origin string + }{ + {"both empty", "", ""}, + {"branch empty", "", "https://github.com/x/y"}, + {"origin empty", "main", ""}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, nil, mClock, logger) + + worker.MarkStale(ctx, uuid.New(), uuid.New(), tc.branch, tc.origin) + }) + } +} + +// TestWorker exercises the worker tick against a +// real PostgreSQL database to verify that the SQL queries, foreign key +// constraints, and upsert logic work end-to-end. +func TestWorker(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + // 1. Real database store. + db, _ := dbtestutil.NewDB(t) + + // 2. Create a user (FK for chats). + user := dbgen.User(t, db, database.User{}) + + // 3. Set up FK chain: chat_providers -> chat_model_configs -> chats. + _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + Provider: "openai", + DisplayName: "OpenAI", + Enabled: true, + }) + require.NoError(t, err) + + modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + Provider: "openai", + Model: "test-model", + DisplayName: "Test Model", + Enabled: true, + ContextLimit: 100000, + CompressionThreshold: 70, + Options: json.RawMessage("{}"), + }) + require.NoError(t, err) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OwnerID: user.ID, + LastModelConfigID: modelCfg.ID, + Title: "integration-test", + }) + require.NoError(t, err) + + // 4. Seed a stale diff status row so the worker picks it up. + _, err = db.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{ + ChatID: chat.ID, + GitBranch: "feature", + GitRemoteOrigin: "https://github.com/o/r", + StaleAt: time.Now().Add(-time.Minute), + Url: sql.NullString{}, + }) + require.NoError(t, err) + + // 5. Mock refresher returns a canned PR status. + mClock := quartz.NewMock(t) + refresher := newTestRefresher(t, mClock) + + // 6. Track publish calls. + var publishCount atomic.Int32 + tickDone := make(chan struct{}) + pub := func(_ context.Context, chatID uuid.UUID) error { + assert.Equal(t, chat.ID, chatID) + if publishCount.Add(1) == 1 { + close(tickDone) + } + return nil + } + + // 7. Create and run the worker for one tick. + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + worker := gitsync.NewWorker(db, refresher, pub, mClock, logger) + + tickOnce(ctx, t, mClock, worker, tickDone) + + // 8. Assert publisher was called. + require.Equal(t, int32(1), publishCount.Load()) + + // 9. Read back and verify persisted fields. + status, err := db.GetChatDiffStatusByChatID(ctx, chat.ID) + require.NoError(t, err) + + // The mock resolveBranchPR returns PRRef{Owner: "o", Repo: "r", Number: 1} + // and buildPullRequestURL formats it as https://github.com/o/r/pull/1. + assert.Equal(t, "https://github.com/o/r/pull/1", status.Url.String) + assert.True(t, status.Url.Valid) + assert.Equal(t, string(gitprovider.PRStateOpen), status.PullRequestState.String) + assert.True(t, status.PullRequestState.Valid) + assert.Equal(t, int32(10), status.Additions) + assert.Equal(t, int32(3), status.Deletions) + assert.Equal(t, int32(2), status.ChangedFiles) + assert.True(t, status.RefreshedAt.Valid, "refreshed_at should be set") + // The mock clock's Now() + DiffStatusTTL determines stale_at. + expectedStaleAt := mClock.Now().Add(gitsync.DiffStatusTTL) + assert.WithinDuration(t, expectedStaleAt, status.StaleAt, time.Second) +} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index c92076ede2..27719cfdea 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1835,18 +1835,6 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ Branch: strings.TrimSpace(query.Get("git_branch")), RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")), } - var chatID uuid.NullUUID - if rawChatID := query.Get("chat_id"); rawChatID != "" { - parsed, err := uuid.Parse(rawChatID) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid chat_id.", - Detail: err.Error(), - }) - return - } - chatID = uuid.NullUUID{UUID: parsed, Valid: true} - } // Either match or configID must be provided! match := query.Get("match") if match == "" { @@ -1940,11 +1928,12 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ return } - // Persist git refs as soon as the agent requests external auth so branch + // MarkStale will trigger a refresh by coderd/gitsync. This allows us to + // persist git refs as soon as the agent requests external auth so branch // context is retained even if the flow requires an out-of-band login. - if gitRef.Branch != "" || gitRef.RemoteOrigin != "" { - //nolint:gocritic // System context required to persist chat git refs. - api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, chatID, gitRef) + if gitRef.Branch != "" && gitRef.RemoteOrigin != "" { + //nolint:gocritic // Chat processor context required for cross-user chat lookup + api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin) } var previousToken *database.ExternalAuthLink @@ -1960,7 +1949,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ return } - api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, chatID, gitRef) + api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, gitRef) } // This is the URL that will redirect the user with a state token. @@ -2018,11 +2007,10 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ }) return } - api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef) httpapi.Write(ctx, rw, http.StatusOK, resp) } -func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) { +func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, gitRef chatGitRef) { // Since we're ticking frequently and this sign-in operation is rare, // we are OK with polling to avoid the complexity of pubsub. ticker, done := api.NewTicker(time.Second) @@ -2092,7 +2080,9 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R }) return } - api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef) + // MarkStale will trigger a refresh by coderd/gitsync. + //nolint:gocritic // Chat processor context required for cross-user chat lookup + api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin) httpapi.Write(ctx, rw, http.StatusOK, resp) return } diff --git a/codersdk/deployment.go b/codersdk/deployment.go index 782ef7e383..248e82685e 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -980,6 +980,10 @@ type ExternalAuthConfig struct { // 'Username for "https://github.com":' // And sending it to the Coder server to match against the Regex. Regex string `json:"regex" yaml:"regex"` + // APIBaseURL is the base URL for provider REST API calls + // (e.g., "https://api.github.com" for GitHub). Derived from + // defaults when not explicitly configured. + APIBaseURL string `json:"api_base_url" yaml:"api_base_url"` // DisplayName is shown in the UI to identify the auth config. DisplayName string `json:"display_name" yaml:"display_name"` // DisplayIcon is a URL to an icon to display in the UI. diff --git a/codersdk/testdata/githubcfg.yaml b/codersdk/testdata/githubcfg.yaml index 838d8f0c2e..86bfaf4eb1 100644 --- a/codersdk/testdata/githubcfg.yaml +++ b/codersdk/testdata/githubcfg.yaml @@ -22,6 +22,7 @@ externalAuthProviders: mcp_tool_allow_regex: .* mcp_tool_deny_regex: create_gist regex: ^https://example.com/.*$ + api_base_url: "" display_name: GitHub display_icon: /static/icons/github.svg code_challenge_methods_supported: diff --git a/docs/reference/api/general.md b/docs/reference/api/general.md index 98a4bde250..921aeaef1c 100644 --- a/docs/reference/api/general.md +++ b/docs/reference/api/general.md @@ -279,6 +279,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \ "external_auth": { "value": [ { + "api_base_url": "string", "app_install_url": "string", "app_installations_url": "string", "auth_url": "string", diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index cd67ea783d..6fa3bb121c 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -2786,6 +2786,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o "external_auth": { "value": [ { + "api_base_url": "string", "app_install_url": "string", "app_installations_url": "string", "auth_url": "string", @@ -3357,6 +3358,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o "external_auth": { "value": [ { + "api_base_url": "string", "app_install_url": "string", "app_installations_url": "string", "auth_url": "string", @@ -4104,6 +4106,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o ```json { + "api_base_url": "string", "app_install_url": "string", "app_installations_url": "string", "auth_url": "string", @@ -4133,22 +4136,23 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o ### Properties -| Name | Type | Required | Restrictions | Description | -|------------------------------------|-----------------|----------|--------------|-------------------------------------------------------------------------------------------------------------------| -| `app_install_url` | string | false | | | -| `app_installations_url` | string | false | | | -| `auth_url` | string | false | | | -| `client_id` | string | false | | | -| `code_challenge_methods_supported` | array of string | false | | Code challenge methods supported lists the PKCE code challenge methods The only one supported by Coder is "S256". | -| `device_code_url` | string | false | | | -| `device_flow` | boolean | false | | | -| `display_icon` | string | false | | Display icon is a URL to an icon to display in the UI. | -| `display_name` | string | false | | Display name is shown in the UI to identify the auth config. | -| `id` | string | false | | ID is a unique identifier for the auth config. It defaults to `type` when not provided. | -| `mcp_tool_allow_regex` | string | false | | | -| `mcp_tool_deny_regex` | string | false | | | -| `mcp_url` | string | false | | | -| `no_refresh` | boolean | false | | | +| Name | Type | Required | Restrictions | Description | +|------------------------------------|-----------------|----------|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `api_base_url` | string | false | | Api base URL is the base URL for provider REST API calls (e.g., "https://api.github.com" for GitHub). Derived from defaults when not explicitly configured. | +| `app_install_url` | string | false | | | +| `app_installations_url` | string | false | | | +| `auth_url` | string | false | | | +| `client_id` | string | false | | | +| `code_challenge_methods_supported` | array of string | false | | Code challenge methods supported lists the PKCE code challenge methods The only one supported by Coder is "S256". | +| `device_code_url` | string | false | | | +| `device_flow` | boolean | false | | | +| `display_icon` | string | false | | Display icon is a URL to an icon to display in the UI. | +| `display_name` | string | false | | Display name is shown in the UI to identify the auth config. | +| `id` | string | false | | ID is a unique identifier for the auth config. It defaults to `type` when not provided. | +| `mcp_tool_allow_regex` | string | false | | | +| `mcp_tool_deny_regex` | string | false | | | +| `mcp_url` | string | false | | | +| `no_refresh` | boolean | false | | | |`regex`|string|false||Regex allows API requesters to match an auth config by a string (e.g. coder.com) instead of by it's type. Git clone makes use of this by parsing the URL from: 'Username for "https://github.com":' And sending it to the Coder server to match against the Regex.| |`revoke_url`|string|false||| @@ -14182,6 +14186,7 @@ None { "value": [ { + "api_base_url": "string", "app_install_url": "string", "app_installations_url": "string", "auth_url": "string", diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 7f6aadcf40..53c30e2ebe 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -2690,6 +2690,12 @@ export interface ExternalAuthConfig { * And sending it to the Coder server to match against the Regex. */ readonly regex: string; + /** + * APIBaseURL is the base URL for provider REST API calls + * (e.g., "https://api.github.com" for GitHub). Derived from + * defaults when not explicitly configured. + */ + readonly api_base_url: string; /** * DisplayName is shown in the UI to identify the auth config. */ diff --git a/site/src/pages/DeploymentSettingsPage/ExternalAuthSettingsPage/ExternalAuthSettingsPageView.stories.tsx b/site/src/pages/DeploymentSettingsPage/ExternalAuthSettingsPage/ExternalAuthSettingsPageView.stories.tsx index a37c490fc0..d4d8059148 100644 --- a/site/src/pages/DeploymentSettingsPage/ExternalAuthSettingsPage/ExternalAuthSettingsPageView.stories.tsx +++ b/site/src/pages/DeploymentSettingsPage/ExternalAuthSettingsPage/ExternalAuthSettingsPageView.stories.tsx @@ -12,6 +12,7 @@ const meta: Meta = { type: "GitHub", client_id: "client_id", regex: "regex", + api_base_url: "", auth_url: "", token_url: "", validate_url: "",