feat(coderd): refactors github pr sync functionality (#22715)

- Adds `_API_BASE_URL` to `CODER_EXTERNAL_AUTH_CONFIG_`
- Extracts and refactors existing GitHub PR sync logic to new packages
`coderd/gitsync` and `coderd/externalauth/gitprovider`
- Associated wiring and tests

Created using Opus 4.6
This commit is contained in:
Cian Johnston
2026-03-10 18:46:01 +00:00
committed by GitHub
parent cbe46c816e
commit bc27274aba
31 changed files with 4311 additions and 708 deletions
+2
View File
@@ -2909,6 +2909,8 @@ func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]coder
provider.MCPToolDenyRegex = v.Value provider.MCPToolDenyRegex = v.Value
case "PKCE_METHODS": case "PKCE_METHODS":
provider.CodeChallengeMethodsSupported = strings.Split(v.Value, " ") provider.CodeChallengeMethodsSupported = strings.Split(v.Value, " ")
case "API_BASE_URL":
provider.APIBaseURL = v.Value
} }
providers[providerNum] = provider providers[providerNum] = provider
} }
+23
View File
@@ -108,6 +108,29 @@ func TestReadExternalAuthProvidersFromEnv(t *testing.T) {
}) })
} }
func TestReadExternalAuthProvidersFromEnv_APIBaseURL(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"CODER_EXTERNAL_AUTH_0_TYPE=github",
"CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx",
"CODER_EXTERNAL_AUTH_0_API_BASE_URL=https://ghes.corp.com/api/v3",
})
require.NoError(t, err)
require.Len(t, providers, 1)
assert.Equal(t, "https://ghes.corp.com/api/v3", providers[0].APIBaseURL)
}
func TestReadExternalAuthProvidersFromEnv_APIBaseURLDefault(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"CODER_EXTERNAL_AUTH_0_TYPE=github",
"CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx",
})
require.NoError(t, err)
require.Len(t, providers, 1)
assert.Equal(t, "", providers[0].APIBaseURL)
}
// TestReadGitAuthProvidersFromEnv ensures that the deprecated `CODER_GITAUTH_` // TestReadGitAuthProvidersFromEnv ensures that the deprecated `CODER_GITAUTH_`
// environment variables are still supported. // environment variables are still supported.
func TestReadGitAuthProvidersFromEnv(t *testing.T) { func TestReadGitAuthProvidersFromEnv(t *testing.T) {
+4
View File
@@ -15269,6 +15269,10 @@ const docTemplate = `{
"codersdk.ExternalAuthConfig": { "codersdk.ExternalAuthConfig": {
"type": "object", "type": "object",
"properties": { "properties": {
"api_base_url": {
"description": "APIBaseURL is the base URL for provider REST API calls\n(e.g., \"https://api.github.com\" for GitHub). Derived from\ndefaults when not explicitly configured.",
"type": "string"
},
"app_install_url": { "app_install_url": {
"type": "string" "type": "string"
}, },
+4
View File
@@ -13792,6 +13792,10 @@
"codersdk.ExternalAuthConfig": { "codersdk.ExternalAuthConfig": {
"type": "object", "type": "object",
"properties": { "properties": {
"api_base_url": {
"description": "APIBaseURL is the base URL for provider REST API calls\n(e.g., \"https://api.github.com\" for GitHub). Derived from\ndefaults when not explicitly configured.",
"type": "string"
},
"app_install_url": { "app_install_url": {
"type": "string" "type": "string"
}, },
+177 -668
View File
@@ -13,7 +13,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"regexp"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -32,6 +31,8 @@ import (
"github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/gitsync"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpapi/httperror" "github.com/coder/coder/v2/coderd/httpapi/httperror"
"github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/httpmw"
@@ -39,16 +40,15 @@ import (
"github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/codersdk/wsjson"
"github.com/coder/websocket" "github.com/coder/websocket"
) )
const ( const (
chatDiffStatusTTL = 120 * time.Second chatDiffStatusTTL = gitsync.DiffStatusTTL
chatDiffBackgroundRefreshTimeout = 20 * time.Second chatStreamBatchSize = 256
githubAPIBaseURL = "https://api.github.com"
chatStreamBatchSize = 256
chatContextLimitModelConfigKey = "context_limit" chatContextLimitModelConfigKey = "context_limit"
chatContextCompressionThresholdModelConfigKey = "context_compression_threshold" chatContextCompressionThresholdModelConfigKey = "context_compression_threshold"
@@ -58,19 +58,6 @@ const (
maxSystemPromptLenBytes = 131072 // 128 KiB maxSystemPromptLenBytes = 131072 // 128 KiB
) )
// chatDiffRefreshBackoffSchedule defines the delays between successive
// background diff refresh attempts. The trigger fires when the agent
// obtains a GitHub token, which is typically right before a git push
// or PR creation. The backoff gives progressively more time for the
// push and any PR workflow to complete before querying the GitHub API.
var chatDiffRefreshBackoffSchedule = []time.Duration{
1 * time.Second,
3 * time.Second,
5 * time.Second,
10 * time.Second,
20 * time.Second,
}
// chatGitRef holds the branch and remote origin reported by the // chatGitRef holds the branch and remote origin reported by the
// workspace agent during a git operation. // workspace agent during a git operation.
type chatGitRef struct { type chatGitRef struct {
@@ -78,32 +65,6 @@ type chatGitRef struct {
RemoteOrigin string RemoteOrigin string
} }
var (
githubPullRequestPathPattern = regexp.MustCompile(
`^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)/pull/([0-9]+)(?:[/?#].*)?$`,
)
githubRepositoryHTTPSPattern = regexp.MustCompile(
`^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
)
githubRepositorySSHPathPattern = regexp.MustCompile(
`^(?:ssh://)?git@github\.com[:/]([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
)
)
type githubPullRequestRef struct {
Owner string
Repo string
Number int
}
type githubPullRequestStatus struct {
PullRequestState string
ChangesRequested bool
Additions int32
Deletions int32
ChangedFiles int32
}
type chatRepositoryRef struct { type chatRepositoryRef struct {
Provider string Provider string
RemoteOrigin string RemoteOrigin string
@@ -1249,193 +1210,6 @@ func shouldRefreshChatDiffStatus(status database.ChatDiffStatus, now time.Time,
return chatDiffStatusIsStale(status, now) return chatDiffStatusIsStale(status, now)
} }
func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) {
if workspace.ID == uuid.Nil || workspace.OwnerID == uuid.Nil {
return
}
go func(workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
ctx := api.ctx
if ctx == nil {
ctx = context.Background()
}
//nolint:gocritic // Background goroutine for diff status refresh has no user context.
ctx = dbauthz.AsSystemRestricted(ctx)
// Always store the git ref so the data is persisted even
// before a PR exists. The frontend can show branch info
// and the refresh loop can resolve a PR later.
api.storeChatGitRef(ctx, workspaceID, workspaceOwnerID, chatID, gitRef)
for _, delay := range chatDiffRefreshBackoffSchedule {
t := api.Clock.NewTimer(delay, "chat_diff_refresh")
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
}
// Refresh and publish status on every iteration.
// Stop the loop once a PR is discovered — there's
// nothing more to wait for after that.
if api.refreshWorkspaceChatDiffStatuses(ctx, workspaceID, workspaceOwnerID, chatID) {
return
}
}
}(workspace.ID, workspace.OwnerID, chatID, gitRef)
}
// storeChatGitRef persists the git branch and remote origin reported
// by the workspace agent on the chat that initiated the git operation.
// When chatID is set, only that specific chat is updated; otherwise all
// chats associated with the workspace are updated (legacy fallback).
func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
var chatsToUpdate []database.Chat
if chatID.Valid {
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
if err != nil {
api.Logger.Warn(ctx, "failed to get chat for git ref storage",
slog.F("chat_id", chatID.UUID),
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
return
}
chatsToUpdate = []database.Chat{chat}
} else {
chats, err := api.Database.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
OwnerID: workspaceOwnerID,
})
if err != nil {
api.Logger.Warn(ctx, "failed to list chats for git ref storage",
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
return
}
chatsToUpdate = filterChatsByWorkspaceID(chats, workspaceID)
}
for _, chat := range chatsToUpdate {
_, err := api.Database.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
GitBranch: gitRef.Branch,
GitRemoteOrigin: gitRef.RemoteOrigin,
StaleAt: time.Now().UTC().Add(-time.Second),
Url: sql.NullString{},
})
if err != nil {
api.Logger.Warn(ctx, "failed to store git ref on chat diff status",
slog.F("chat_id", chat.ID),
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
continue
}
api.publishChatDiffStatusEvent(ctx, chat.ID)
}
}
// refreshWorkspaceChatDiffStatuses refreshes the diff status for chats
// associated with the given workspace. When chatID is set, only that
// specific chat is refreshed; otherwise all chats for the workspace
// are refreshed (legacy fallback). It returns true when every
// refreshed chat has a PR URL resolved, signaling that the caller
// can stop polling.
func (api *API) refreshWorkspaceChatDiffStatuses(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID) bool {
var filtered []database.Chat
if chatID.Valid {
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
if err != nil {
api.Logger.Warn(ctx, "failed to get chat for diff refresh",
slog.F("chat_id", chatID.UUID),
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
return false
}
filtered = []database.Chat{chat}
} else {
chats, err := api.Database.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
OwnerID: workspaceOwnerID,
})
if err != nil {
api.Logger.Warn(ctx, "failed to list workspace owner chats for diff refresh",
slog.F("workspace_id", workspaceID),
slog.F("workspace_owner_id", workspaceOwnerID),
slog.Error(err),
)
return false
}
filtered = filterChatsByWorkspaceID(chats, workspaceID)
}
if len(filtered) == 0 {
return false
}
allHavePR := true
for _, chat := range filtered {
refreshCtx, cancel := context.WithTimeout(ctx, chatDiffBackgroundRefreshTimeout)
status, err := api.resolveChatDiffStatusWithOptions(refreshCtx, chat, true)
cancel()
if err != nil {
api.Logger.Warn(ctx, "failed to refresh chat diff status after workspace external auth",
slog.F("workspace_id", workspaceID),
slog.F("chat_id", chat.ID),
slog.Error(err),
)
allHavePR = false
} else if status == nil || !status.Url.Valid || strings.TrimSpace(status.Url.String) == "" {
allHavePR = false
}
api.publishChatStatusEvent(ctx, chat.ID)
api.publishChatDiffStatusEvent(ctx, chat.ID)
}
return allHavePR
}
func filterChatsByWorkspaceID(chats []database.Chat, workspaceID uuid.UUID) []database.Chat {
filteredChats := make([]database.Chat, 0, len(chats))
for _, chat := range chats {
if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID {
continue
}
filteredChats = append(filteredChats, chat)
}
return filteredChats
}
func (api *API) publishChatStatusEvent(ctx context.Context, chatID uuid.UUID) {
if api.chatDaemon == nil {
return
}
if err := api.chatDaemon.RefreshStatus(ctx, chatID); err != nil {
api.Logger.Debug(ctx, "failed to refresh published chat status",
slog.F("chat_id", chatID),
slog.Error(err),
)
}
}
func (api *API) publishChatDiffStatusEvent(ctx context.Context, chatID uuid.UUID) {
if api.chatDaemon == nil {
return
}
if err := api.chatDaemon.PublishDiffStatusChange(ctx, chatID); err != nil {
api.Logger.Debug(ctx, "failed to publish chat diff status change",
slog.F("chat_id", chatID),
slog.Error(err),
)
}
}
func (api *API) resolveChatDiffContents( func (api *API) resolveChatDiffContents(
ctx context.Context, ctx context.Context,
chat database.Chat, chat database.Chat,
@@ -1483,22 +1257,36 @@ func (api *API) resolveChatDiffContents(
if reference.RepositoryRef == nil { if reference.RepositoryRef == nil {
return result, nil return result, nil
} }
if !strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
gp := api.resolveGitProvider(reference.RepositoryRef.RemoteOrigin)
if gp == nil {
return result, nil return result, nil
} }
token := api.resolveChatGitHubAccessToken(ctx, chat.OwnerID) token, err := api.resolveChatGitAccessToken(ctx, chat.OwnerID, reference.RepositoryRef.RemoteOrigin)
if err != nil {
return result, xerrors.Errorf("resolve git access token: %w", err)
} else if token == nil {
return result, xerrors.New("nil git access token")
}
if reference.PullRequestURL != "" { if reference.PullRequestURL != "" {
diff, err := api.fetchGitHubPullRequestDiff(ctx, reference.PullRequestURL, token) ref, ok := gp.ParsePullRequestURL(reference.PullRequestURL)
if !ok {
return result, xerrors.Errorf("invalid pull request URL %q", reference.PullRequestURL)
}
diff, err := gp.FetchPullRequestDiff(ctx, *token, ref)
if err != nil { if err != nil {
return result, err return result, err
} }
result.Diff = diff result.Diff = diff
return result, nil return result, nil
} }
diff, err := gp.FetchBranchDiff(ctx, *token, gitprovider.BranchRef{
diff, err := api.fetchGitHubCompareDiff(ctx, *reference.RepositoryRef, token) Owner: reference.RepositoryRef.Owner,
Repo: reference.RepositoryRef.Repo,
Branch: reference.RepositoryRef.Branch,
})
if err != nil { if err != nil {
return result, err return result, err
} }
@@ -1532,34 +1320,53 @@ func (api *API) resolveChatDiffReference(
// If we have a repo ref with a branch, try to resolve the // If we have a repo ref with a branch, try to resolve the
// current open PR. This picks up new PRs after the previous // current open PR. This picks up new PRs after the previous
// one was closed. // one was closed.
if reference.RepositoryRef != nil && if reference.RepositoryRef != nil && reference.RepositoryRef.Owner != "" {
strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) { gp := api.resolveGitProvider(reference.RepositoryRef.RemoteOrigin)
pullRequestURL, lookupErr := api.resolveGitHubPullRequestURLFromRepositoryRef(ctx, chat.OwnerID, *reference.RepositoryRef) if gp != nil {
if lookupErr != nil { token, err := api.resolveChatGitAccessToken(ctx, chat.OwnerID, reference.RepositoryRef.RemoteOrigin)
api.Logger.Debug(ctx, "failed to resolve pull request from repository reference", if token == nil || errors.Is(err, gitsync.ErrNoTokenAvailable) {
slog.F("chat_id", chat.ID), // No token available yet.
slog.F("provider", reference.RepositoryRef.Provider), return reference, nil
slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin), } else if err != nil {
slog.F("branch", reference.RepositoryRef.Branch), return chatDiffReference{}, xerrors.Errorf("resolve git access token: %w", err)
slog.Error(lookupErr), }
) prRef, lookupErr := gp.ResolveBranchPullRequest(ctx, *token, gitprovider.BranchRef{
} else if pullRequestURL != "" { Owner: reference.RepositoryRef.Owner,
reference.PullRequestURL = pullRequestURL Repo: reference.RepositoryRef.Repo,
Branch: reference.RepositoryRef.Branch,
})
if lookupErr != nil {
api.Logger.Debug(ctx, "failed to resolve pull request from repository reference",
slog.F("chat_id", chat.ID),
slog.F("provider", reference.RepositoryRef.Provider),
slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin),
slog.F("branch", reference.RepositoryRef.Branch),
slog.Error(lookupErr),
)
} else if prRef != nil {
reference.PullRequestURL = gp.BuildPullRequestURL(*prRef)
}
reference.PullRequestURL = gp.NormalizePullRequestURL(reference.PullRequestURL)
} }
} }
reference.PullRequestURL = normalizeGitHubPullRequestURL(reference.PullRequestURL)
// If we have a PR URL but no repo ref (e.g. the agent hasn't // If we have a PR URL but no repo ref (e.g. the agent hasn't
// reported branch/origin yet), derive a partial ref from the // reported branch/origin yet), derive a partial ref from the
// PR URL so the caller can still show provider/owner/repo. // PR URL so the caller can still show provider/owner/repo.
if reference.RepositoryRef == nil && reference.PullRequestURL != "" { if reference.RepositoryRef == nil && reference.PullRequestURL != "" {
if parsed, ok := parseGitHubPullRequestURL(reference.PullRequestURL); ok { for _, extAuth := range api.ExternalAuthConfigs {
reference.RepositoryRef = &chatRepositoryRef{ gp := extAuth.Git(api.HTTPClient)
Provider: string(codersdk.EnhancedExternalAuthProviderGitHub), if gp == nil {
RemoteOrigin: fmt.Sprintf("https://github.com/%s/%s", parsed.Owner, parsed.Repo), continue
Owner: parsed.Owner, }
Repo: parsed.Repo, if parsed, ok := gp.ParsePullRequestURL(reference.PullRequestURL); ok {
reference.RepositoryRef = &chatRepositoryRef{
Provider: strings.ToLower(extAuth.Type),
Owner: parsed.Owner,
Repo: parsed.Repo,
RemoteOrigin: gp.BuildRepositoryURL(parsed.Owner, parsed.Repo),
}
break
} }
} }
} }
@@ -1577,19 +1384,18 @@ func (api *API) buildChatRepositoryRefFromStatus(status database.ChatDiffStatus)
return nil return nil
} }
providerType, gp := api.resolveExternalAuth(origin)
repoRef := &chatRepositoryRef{ repoRef := &chatRepositoryRef{
Provider: strings.TrimSpace(api.resolveExternalAuthProviderType(origin)), Provider: providerType,
RemoteOrigin: origin, RemoteOrigin: origin,
Branch: branch, Branch: branch,
} }
if gp != nil {
if owner, repo, normalizedOrigin, ok := parseGitHubRepositoryOrigin(repoRef.RemoteOrigin); ok { if owner, repo, normalizedOrigin, ok := gp.ParseRepositoryOrigin(repoRef.RemoteOrigin); ok {
if repoRef.Provider == "" { repoRef.RemoteOrigin = normalizedOrigin
repoRef.Provider = string(codersdk.EnhancedExternalAuthProviderGitHub) repoRef.Owner = owner
repoRef.Repo = repo
} }
repoRef.RemoteOrigin = normalizedOrigin
repoRef.Owner = owner
repoRef.Repo = repo
} }
if repoRef.Provider == "" { if repoRef.Provider == "" {
@@ -1643,60 +1449,31 @@ func (api *API) getCachedChatDiffStatus(
) )
} }
func (api *API) resolveExternalAuthProviderType(match string) string { // resolveExternalAuth finds the external auth config matching the
match = strings.TrimSpace(match) // given remote origin URL and returns both the provider type string
if match == "" { // (e.g. "github") and the gitprovider.Provider. Returns ("", nil)
return "" // if no matching config is found.
func (api *API) resolveExternalAuth(origin string) (providerType string, gp gitprovider.Provider) {
origin = strings.TrimSpace(origin)
if origin == "" {
return "", nil
} }
for _, extAuth := range api.ExternalAuthConfigs { for _, extAuth := range api.ExternalAuthConfigs {
if extAuth.Regex == nil || !extAuth.Regex.MatchString(match) { if extAuth.Regex == nil || !extAuth.Regex.MatchString(origin) {
continue continue
} }
return strings.ToLower(strings.TrimSpace(extAuth.Type)) return strings.ToLower(strings.TrimSpace(extAuth.Type)),
extAuth.Git(api.HTTPClient)
} }
return "", nil
return ""
} }
func parseGitHubRepositoryOrigin(raw string) (owner string, repo string, normalizedOrigin string, ok bool) { // resolveGitProvider finds the external auth config matching the
raw = strings.TrimSpace(raw) // given remote origin URL and returns its git provider. Returns
if raw == "" { // nil if no matching git provider is configured.
return "", "", "", false func (api *API) resolveGitProvider(origin string) gitprovider.Provider {
} _, gp := api.resolveExternalAuth(origin)
return gp
matches := githubRepositoryHTTPSPattern.FindStringSubmatch(raw)
if len(matches) != 3 {
matches = githubRepositorySSHPathPattern.FindStringSubmatch(raw)
}
if len(matches) != 3 {
return "", "", "", false
}
owner = strings.TrimSpace(matches[1])
repo = strings.TrimSpace(matches[2])
repo = strings.TrimSuffix(repo, ".git")
if owner == "" || repo == "" {
return "", "", "", false
}
return owner, repo, fmt.Sprintf("https://github.com/%s/%s", owner, repo), true
}
func buildGitHubBranchURL(owner string, repo string, branch string) string {
owner = strings.TrimSpace(owner)
repo = strings.TrimSpace(repo)
branch = strings.TrimSpace(branch)
if owner == "" || repo == "" || branch == "" {
return ""
}
return fmt.Sprintf(
"https://github.com/%s/%s/tree/%s",
owner,
repo,
url.PathEscape(branch),
)
} }
func chatDiffStatusIsStale(status database.ChatDiffStatus, now time.Time) bool { func chatDiffStatusIsStale(status database.ChatDiffStatus, now time.Time) bool {
@@ -1712,11 +1489,32 @@ func (api *API) refreshChatDiffStatus(
chatID uuid.UUID, chatID uuid.UUID,
pullRequestURL string, pullRequestURL string,
) (database.ChatDiffStatus, error) { ) (database.ChatDiffStatus, error) {
status, err := api.fetchGitHubPullRequestStatus( // Find a provider that can handle this PR URL.
ctx, var gp gitprovider.Provider
pullRequestURL, var ref gitprovider.PRRef
api.resolveChatGitHubAccessToken(ctx, chatOwnerID), for _, extAuth := range api.ExternalAuthConfigs {
) p := extAuth.Git(api.HTTPClient)
if p == nil {
continue
}
if parsed, ok := p.ParsePullRequestURL(pullRequestURL); ok {
gp = p
ref = parsed
break
}
}
if gp == nil {
return database.ChatDiffStatus{}, xerrors.Errorf("no git provider found for PR URL %q", pullRequestURL)
}
origin := gp.BuildRepositoryURL(ref.Owner, ref.Repo)
token, err := api.resolveChatGitAccessToken(ctx, chatOwnerID, origin)
if err != nil {
return database.ChatDiffStatus{}, xerrors.Errorf("resolve git access token: %w", err)
} else if token == nil {
return database.ChatDiffStatus{}, xerrors.New("nil git access token")
}
status, err := gp.FetchPullRequestStatus(ctx, *token, ref)
if err != nil { if err != nil {
return database.ChatDiffStatus{}, err return database.ChatDiffStatus{}, err
} }
@@ -1728,13 +1526,13 @@ func (api *API) refreshChatDiffStatus(
ChatID: chatID, ChatID: chatID,
Url: sql.NullString{String: pullRequestURL, Valid: true}, Url: sql.NullString{String: pullRequestURL, Valid: true},
PullRequestState: sql.NullString{ PullRequestState: sql.NullString{
String: status.PullRequestState, String: string(status.State),
Valid: status.PullRequestState != "", Valid: status.State != "",
}, },
ChangesRequested: status.ChangesRequested, ChangesRequested: status.ChangesRequested,
Additions: status.Additions, Additions: status.DiffStats.Additions,
Deletions: status.Deletions, Deletions: status.DiffStats.Deletions,
ChangedFiles: status.ChangedFiles, ChangedFiles: status.DiffStats.ChangedFiles,
RefreshedAt: refreshedAt, RefreshedAt: refreshedAt,
StaleAt: refreshedAt.Add(chatDiffStatusTTL), StaleAt: refreshedAt.Add(chatDiffStatusTTL),
}, },
@@ -1745,23 +1543,49 @@ func (api *API) refreshChatDiffStatus(
return refreshedStatus, nil return refreshedStatus, nil
} }
func (api *API) resolveChatGitHubAccessToken( func (api *API) resolveChatGitAccessToken(
ctx context.Context, ctx context.Context,
userID uuid.UUID, userID uuid.UUID,
) string { origin string,
// Build a map of provider ID -> config so we can refresh tokens ) (*string, error) {
// using the same code path as provisionerdserver. origin = strings.TrimSpace(origin)
ghConfigs := make(map[string]*externalauth.Config)
providerIDs := []string{"github"} // If we have an origin, find the specific matching config first.
for _, config := range api.ExternalAuthConfigs { // This ensures multi-provider setups (github.com + GHE) get the
if !strings.EqualFold( // correct token.
config.Type, if origin != "" {
string(codersdk.EnhancedExternalAuthProviderGitHub), for _, config := range api.ExternalAuthConfigs {
) { if config.Regex == nil || !config.Regex.MatchString(origin) {
continue continue
}
link, err := api.Database.GetExternalAuthLink(ctx,
database.GetExternalAuthLinkParams{
ProviderID: config.ID,
UserID: userID,
},
)
if err != nil {
continue
}
refreshed, refreshErr := config.RefreshToken(ctx, api.Database, link)
if refreshErr == nil {
link = refreshed
}
token := strings.TrimSpace(link.OAuthAccessToken)
if token != "" {
return ptr.Ref(token), nil
}
} }
}
// Fallback: iterate all external auth configs.
// Used when origin is empty (inline refresh from HTTP handler)
// or when the origin-specific lookup above failed.
configs := make(map[string]*externalauth.Config)
providerIDs := []string{}
for _, config := range api.ExternalAuthConfigs {
providerIDs = append(providerIDs, config.ID) providerIDs = append(providerIDs, config.ID)
ghConfigs[config.ID] = config configs[config.ID] = config
} }
seen := map[string]struct{}{} seen := map[string]struct{}{}
@@ -1785,7 +1609,7 @@ func (api *API) resolveChatGitHubAccessToken(
// Refresh the token if there is a matching config, mirroring // Refresh the token if there is a matching config, mirroring
// the same code path used by provisionerdserver when handing // the same code path used by provisionerdserver when handing
// tokens to provisioners. // tokens to provisioners.
if cfg, ok := ghConfigs[providerID]; ok { if cfg, ok := configs[providerID]; ok {
refreshed, refreshErr := cfg.RefreshToken(ctx, api.Database, link) refreshed, refreshErr := cfg.RefreshToken(ctx, api.Database, link)
if refreshErr != nil { if refreshErr != nil {
api.Logger.Debug(ctx, "failed to refresh external auth token for chat diff", api.Logger.Debug(ctx, "failed to refresh external auth token for chat diff",
@@ -1802,336 +1626,11 @@ func (api *API) resolveChatGitHubAccessToken(
token := strings.TrimSpace(link.OAuthAccessToken) token := strings.TrimSpace(link.OAuthAccessToken)
if token != "" { if token != "" {
return token return ptr.Ref(token), nil
} }
} }
return "" return nil, gitsync.ErrNoTokenAvailable
}
func (api *API) resolveGitHubPullRequestURLFromRepositoryRef(
ctx context.Context,
userID uuid.UUID,
repositoryRef chatRepositoryRef,
) (string, error) {
if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" {
return "", nil
}
query := url.Values{}
query.Set("state", "open")
query.Set("head", fmt.Sprintf("%s:%s", repositoryRef.Owner, repositoryRef.Branch))
query.Set("sort", "updated")
query.Set("direction", "desc")
query.Set("per_page", "1")
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/pulls?%s",
githubAPIBaseURL,
repositoryRef.Owner,
repositoryRef.Repo,
query.Encode(),
)
var pulls []struct {
HTMLURL string `json:"html_url"`
}
token := api.resolveChatGitHubAccessToken(ctx, userID)
if err := api.decodeGitHubJSON(ctx, requestURL, token, &pulls); err != nil {
return "", err
}
if len(pulls) == 0 {
return "", nil
}
return normalizeGitHubPullRequestURL(pulls[0].HTMLURL), nil
}
func (api *API) fetchGitHubPullRequestDiff(
ctx context.Context,
pullRequestURL string,
token string,
) (string, error) {
ref, ok := parseGitHubPullRequestURL(pullRequestURL)
if !ok {
return "", xerrors.Errorf("invalid GitHub pull request URL %q", pullRequestURL)
}
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/pulls/%d",
githubAPIBaseURL,
ref.Owner,
ref.Repo,
ref.Number,
)
return api.fetchGitHubDiff(ctx, requestURL, token)
}
func (api *API) fetchGitHubCompareDiff(
ctx context.Context,
repositoryRef chatRepositoryRef,
token string,
) (string, error) {
if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" {
return "", nil
}
var repository struct {
DefaultBranch string `json:"default_branch"`
}
repositoryURL := fmt.Sprintf(
"%s/repos/%s/%s",
githubAPIBaseURL,
repositoryRef.Owner,
repositoryRef.Repo,
)
if err := api.decodeGitHubJSON(ctx, repositoryURL, token, &repository); err != nil {
return "", err
}
defaultBranch := strings.TrimSpace(repository.DefaultBranch)
if defaultBranch == "" {
return "", xerrors.New("github repository default branch is empty")
}
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/compare/%s...%s",
githubAPIBaseURL,
repositoryRef.Owner,
repositoryRef.Repo,
url.PathEscape(defaultBranch),
url.PathEscape(repositoryRef.Branch),
)
return api.fetchGitHubDiff(ctx, requestURL, token)
}
func (api *API) fetchGitHubDiff(
ctx context.Context,
requestURL string,
token string,
) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return "", xerrors.Errorf("create github diff request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github.diff")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
req.Header.Set("User-Agent", "coder-chat-diff")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
httpClient := api.HTTPClient
if httpClient == nil {
httpClient = http.DefaultClient
}
resp, err := httpClient.Do(req)
if err != nil {
return "", xerrors.Errorf("execute github diff request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return "", xerrors.Errorf("github diff request failed with status %d", resp.StatusCode)
}
return "", xerrors.Errorf(
"github diff request failed with status %d: %s",
resp.StatusCode,
strings.TrimSpace(string(body)),
)
}
diff, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return "", xerrors.Errorf("read github diff response: %w", err)
}
return string(diff), nil
}
func (api *API) fetchGitHubPullRequestStatus(
ctx context.Context,
pullRequestURL string,
token string,
) (githubPullRequestStatus, error) {
ref, ok := parseGitHubPullRequestURL(pullRequestURL)
if !ok {
return githubPullRequestStatus{}, xerrors.Errorf(
"invalid GitHub pull request URL %q",
pullRequestURL,
)
}
pullEndpoint := fmt.Sprintf(
"%s/repos/%s/%s/pulls/%d",
githubAPIBaseURL,
ref.Owner,
ref.Repo,
ref.Number,
)
var pull struct {
State string `json:"state"`
Additions int32 `json:"additions"`
Deletions int32 `json:"deletions"`
ChangedFiles int32 `json:"changed_files"`
}
if err := api.decodeGitHubJSON(ctx, pullEndpoint, token, &pull); err != nil {
return githubPullRequestStatus{}, err
}
var reviews []struct {
ID int64 `json:"id"`
State string `json:"state"`
User struct {
Login string `json:"login"`
} `json:"user"`
}
if err := api.decodeGitHubJSON(
ctx,
pullEndpoint+"/reviews?per_page=100",
token,
&reviews,
); err != nil {
return githubPullRequestStatus{}, err
}
return githubPullRequestStatus{
PullRequestState: strings.ToLower(strings.TrimSpace(pull.State)),
ChangesRequested: hasOutstandingGitHubChangesRequested(reviews),
Additions: pull.Additions,
Deletions: pull.Deletions,
ChangedFiles: pull.ChangedFiles,
}, nil
}
func (api *API) decodeGitHubJSON(
ctx context.Context,
requestURL string,
token string,
dest any,
) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return xerrors.Errorf("create github request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github+json")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
req.Header.Set("User-Agent", "coder-chat-diff-status")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
httpClient := api.HTTPClient
if httpClient == nil {
httpClient = http.DefaultClient
}
resp, err := httpClient.Do(req)
if err != nil {
return xerrors.Errorf("execute github request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return xerrors.Errorf(
"github request failed with status %d",
resp.StatusCode,
)
}
return xerrors.Errorf(
"github request failed with status %d: %s",
resp.StatusCode,
strings.TrimSpace(string(body)),
)
}
if err := json.NewDecoder(resp.Body).Decode(dest); err != nil {
return xerrors.Errorf("decode github response: %w", err)
}
return nil
}
func hasOutstandingGitHubChangesRequested(
reviews []struct {
ID int64 `json:"id"`
State string `json:"state"`
User struct {
Login string `json:"login"`
} `json:"user"`
},
) bool {
type reviewerState struct {
reviewID int64
state string
}
statesByReviewer := make(map[string]reviewerState)
for _, review := range reviews {
login := strings.ToLower(strings.TrimSpace(review.User.Login))
if login == "" {
continue
}
state := strings.ToUpper(strings.TrimSpace(review.State))
switch state {
case "CHANGES_REQUESTED", "APPROVED", "DISMISSED":
default:
continue
}
current, exists := statesByReviewer[login]
if exists && current.reviewID > review.ID {
continue
}
statesByReviewer[login] = reviewerState{
reviewID: review.ID,
state: state,
}
}
for _, state := range statesByReviewer {
if state.state == "CHANGES_REQUESTED" {
return true
}
}
return false
}
func normalizeGitHubPullRequestURL(raw string) string {
ref, ok := parseGitHubPullRequestURL(strings.TrimRight(
strings.TrimSpace(raw),
"),.;",
))
if !ok {
return ""
}
return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number)
}
func parseGitHubPullRequestURL(raw string) (githubPullRequestRef, bool) {
matches := githubPullRequestPathPattern.FindStringSubmatch(strings.TrimSpace(raw))
if len(matches) != 4 {
return githubPullRequestRef{}, false
}
number, err := strconv.Atoi(matches[3])
if err != nil {
return githubPullRequestRef{}, false
}
return githubPullRequestRef{
Owner: matches[1],
Repo: matches[2],
Number: number,
}, true
} }
type createChatWorkspaceSelection struct { type createChatWorkspaceSelection struct {
@@ -2786,11 +2285,21 @@ func convertChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) co
} }
} }
if result.URL == nil { if result.URL == nil {
owner, repo, _, ok := parseGitHubRepositoryOrigin(status.GitRemoteOrigin) // Try to build a branch URL from the stored origin.
if ok { // Since convertChatDiffStatus does not have access to
branchURL := buildGitHubBranchURL(owner, repo, status.GitBranch) // the API instance, we construct a GitHub provider
if branchURL != "" { // directly as a best-effort fallback.
result.URL = &branchURL // TODO: This uses the default github.com API base URL,
// so branch URLs for GitHub Enterprise instances will
// be incorrect. To fix this, convertChatDiffStatus
// would need access to the external auth configs.
gp := gitprovider.New("github", "", nil)
if gp != nil {
if owner, repo, _, ok := gp.ParseRepositoryOrigin(status.GitRemoteOrigin); ok {
branchURL := gp.BuildBranchURL(owner, repo, status.GitBranch)
if branchURL != "" {
result.URL = &branchURL
}
} }
} }
} }
+1 -1
View File
@@ -2605,7 +2605,7 @@ func TestGetChatDiffStatus(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, cachedStatusChat.ID, cachedStatus.ChatID) require.Equal(t, cachedStatusChat.ID, cachedStatus.ChatID)
require.NotNil(t, cachedStatus.URL) require.NotNil(t, cachedStatus.URL)
require.Equal(t, "https://github.com/coder/coder/tree/feature%2Fdiff-status", *cachedStatus.URL) require.Equal(t, "https://github.com/coder/coder/tree/feature/diff-status", *cachedStatus.URL)
require.NotNil(t, cachedStatus.PullRequestState) require.NotNil(t, cachedStatus.PullRequestState)
require.Equal(t, "open", *cachedStatus.PullRequestState) require.Equal(t, "open", *cachedStatus.PullRequestState)
require.True(t, cachedStatus.ChangesRequested) require.True(t, cachedStatus.ChangesRequested)
+26
View File
@@ -61,6 +61,7 @@ import (
"github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/files" "github.com/coder/coder/v2/coderd/files"
"github.com/coder/coder/v2/coderd/gitsshkey" "github.com/coder/coder/v2/coderd/gitsshkey"
"github.com/coder/coder/v2/coderd/gitsync"
"github.com/coder/coder/v2/coderd/healthcheck" "github.com/coder/coder/v2/coderd/healthcheck"
"github.com/coder/coder/v2/coderd/healthcheck/derphealth" "github.com/coder/coder/v2/coderd/healthcheck/derphealth"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
@@ -773,6 +774,21 @@ func New(options *Options) *API {
Pubsub: options.Pubsub, Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher, WebpushDispatcher: options.WebPushDispatcher,
}) })
gitSyncLogger := options.Logger.Named("gitsync")
refresher := gitsync.NewRefresher(
api.resolveGitProvider,
api.resolveChatGitAccessToken,
gitSyncLogger.Named("refresher"),
quartz.NewReal(),
)
api.gitSyncWorker = gitsync.NewWorker(options.Database,
refresher,
api.chatDaemon.PublishDiffStatusChange,
quartz.NewReal(),
gitSyncLogger,
)
// nolint:gocritic // chat diff worker needs to be able to CRUD chats.
go api.gitSyncWorker.Start(dbauthz.AsChatd(api.ctx))
if options.DeploymentValues.Prometheus.Enable { if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(stn) options.PrometheusRegistry.MustRegister(stn)
api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry) api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
@@ -1999,6 +2015,9 @@ type API struct {
dbRolluper *dbrollup.Rolluper dbRolluper *dbrollup.Rolluper
// chatDaemon handles background processing of pending chats. // chatDaemon handles background processing of pending chats.
chatDaemon *chatd.Server chatDaemon *chatd.Server
// gitSyncWorker refreshes stale chat diff statuses in the
// background.
gitSyncWorker *gitsync.Worker
} }
// Close waits for all WebSocket connections to drain before returning. // Close waits for all WebSocket connections to drain before returning.
@@ -2028,6 +2047,13 @@ func (api *API) Close() error {
api.Logger.Warn(api.ctx, "websocket shutdown timed out after 10 seconds") api.Logger.Warn(api.ctx, "websocket shutdown timed out after 10 seconds")
} }
api.dbRolluper.Close() api.dbRolluper.Close()
// chatDiffWorker is unconditionally initialized in New().
select {
case <-api.gitSyncWorker.Done():
case <-time.After(10 * time.Second):
api.Logger.Warn(context.Background(),
"chat diff refresh worker did not exit in time")
}
if err := api.chatDaemon.Close(); err != nil { if err := api.chatDaemon.Close(); err != nil {
api.Logger.Warn(api.ctx, "close chat processor", slog.Error(err)) api.Logger.Warn(api.ctx, "close chat processor", slog.Error(err))
} }
+21
View File
@@ -1539,6 +1539,17 @@ func (q *querier) AcquireProvisionerJob(ctx context.Context, arg database.Acquir
return q.db.AcquireProvisionerJob(ctx, arg) return q.db.AcquireProvisionerJob(ctx, arg)
} }
func (q *querier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
// This is a system-level batch operation used by the gitsync
// background worker. Per-object authorization is impractical
// for a SKIP LOCKED acquisition query; callers must use
// AsChatd context.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return nil, err
}
return q.db.AcquireStaleChatDiffStatuses(ctx, limitVal)
}
func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error { func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error {
fetch := func(ctx context.Context, arg database.ActivityBumpWorkspaceParams) (database.Workspace, error) { fetch := func(ctx context.Context, arg database.ActivityBumpWorkspaceParams) (database.Workspace, error) {
return q.db.GetWorkspaceByID(ctx, arg.WorkspaceID) return q.db.GetWorkspaceByID(ctx, arg.WorkspaceID)
@@ -1577,6 +1588,16 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas
return q.db.ArchiveUnusedTemplateVersions(ctx, arg) return q.db.ArchiveUnusedTemplateVersions(ctx, arg)
} }
func (q *querier) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
// This is a system-level operation used by the gitsync
// background worker to reschedule failed refreshes. Same
// authorization pattern as AcquireStaleChatDiffStatuses.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return err
}
return q.db.BackoffChatDiffStatus(ctx, arg)
}
func (q *querier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error { func (q *querier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
// Could be any workspace agent and checking auth to each workspace agent is overkill for // Could be any workspace agent and checking auth to each workspace agent is overkill for
// the purpose of this function. // the purpose of this function.
+12
View File
@@ -770,6 +770,18 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes() dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus) check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
})) }))
s.Run("AcquireStaleChatDiffStatuses", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), int32(10)).Return([]database.AcquireStaleChatDiffStatusesRow{}, nil).AnyTimes()
check.Args(int32(10)).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]database.AcquireStaleChatDiffStatusesRow{})
}))
s.Run("BackoffChatDiffStatus", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.BackoffChatDiffStatusParams{
ChatID: uuid.New(),
StaleAt: dbtime.Now(),
}
dbm.EXPECT().BackoffChatDiffStatus(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns()
}))
s.Run("UpsertChatSystemPrompt", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { s.Run("UpsertChatSystemPrompt", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes() dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
+16
View File
@@ -136,6 +136,14 @@ func (m queryMetricsStore) AcquireProvisionerJob(ctx context.Context, arg databa
return r0, r1 return r0, r1
} }
func (m queryMetricsStore) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
start := time.Now()
r0, r1 := m.s.AcquireStaleChatDiffStatuses(ctx, limitVal)
m.queryLatencies.WithLabelValues("AcquireStaleChatDiffStatuses").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireStaleChatDiffStatuses").Inc()
return r0, r1
}
func (m queryMetricsStore) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error { func (m queryMetricsStore) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error {
start := time.Now() start := time.Now()
r0 := m.s.ActivityBumpWorkspace(ctx, arg) r0 := m.s.ActivityBumpWorkspace(ctx, arg)
@@ -168,6 +176,14 @@ func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, ar
return r0, r1 return r0, r1
} }
func (m queryMetricsStore) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
start := time.Now()
r0 := m.s.BackoffChatDiffStatus(ctx, arg)
m.queryLatencies.WithLabelValues("BackoffChatDiffStatus").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BackoffChatDiffStatus").Inc()
return r0
}
func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error { func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
start := time.Now() start := time.Now()
r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg) r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
+29
View File
@@ -103,6 +103,21 @@ func (mr *MockStoreMockRecorder) AcquireProvisionerJob(ctx, arg any) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireProvisionerJob", reflect.TypeOf((*MockStore)(nil).AcquireProvisionerJob), ctx, arg) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireProvisionerJob", reflect.TypeOf((*MockStore)(nil).AcquireProvisionerJob), ctx, arg)
} }
// AcquireStaleChatDiffStatuses mocks base method.
func (m *MockStore) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcquireStaleChatDiffStatuses", ctx, limitVal)
ret0, _ := ret[0].([]database.AcquireStaleChatDiffStatusesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcquireStaleChatDiffStatuses indicates an expected call of AcquireStaleChatDiffStatuses.
func (mr *MockStoreMockRecorder) AcquireStaleChatDiffStatuses(ctx, limitVal any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireStaleChatDiffStatuses", reflect.TypeOf((*MockStore)(nil).AcquireStaleChatDiffStatuses), ctx, limitVal)
}
// ActivityBumpWorkspace mocks base method. // ActivityBumpWorkspace mocks base method.
func (m *MockStore) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error { func (m *MockStore) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -161,6 +176,20 @@ func (mr *MockStoreMockRecorder) ArchiveUnusedTemplateVersions(ctx, arg any) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArchiveUnusedTemplateVersions", reflect.TypeOf((*MockStore)(nil).ArchiveUnusedTemplateVersions), ctx, arg) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArchiveUnusedTemplateVersions", reflect.TypeOf((*MockStore)(nil).ArchiveUnusedTemplateVersions), ctx, arg)
} }
// BackoffChatDiffStatus mocks base method.
func (m *MockStore) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BackoffChatDiffStatus", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// BackoffChatDiffStatus indicates an expected call of BackoffChatDiffStatus.
func (mr *MockStoreMockRecorder) BackoffChatDiffStatus(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackoffChatDiffStatus", reflect.TypeOf((*MockStore)(nil).BackoffChatDiffStatus), ctx, arg)
}
// BatchUpdateWorkspaceAgentMetadata mocks base method. // BatchUpdateWorkspaceAgentMetadata mocks base method.
func (m *MockStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error { func (m *MockStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
+2
View File
@@ -39,6 +39,7 @@ type sqlcQuerier interface {
// multiple provisioners from acquiring the same jobs. See: // multiple provisioners from acquiring the same jobs. See:
// https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE // https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE
AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error) AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error)
AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]AcquireStaleChatDiffStatusesRow, error)
// Bumps the workspace deadline by the template's configured "activity_bump" // Bumps the workspace deadline by the template's configured "activity_bump"
// duration (default 1h). If the workspace bump will cross an autostart // duration (default 1h). If the workspace bump will cross an autostart
// threshold, then the bump is autostart + TTL. This is the deadline behavior if // threshold, then the bump is autostart + TTL. This is the deadline behavior if
@@ -60,6 +61,7 @@ type sqlcQuerier interface {
// Only unused template versions will be archived, which are any versions not // Only unused template versions will be archived, which are any versions not
// referenced by the latest build of a workspace. // referenced by the latest build of a workspace.
ArchiveUnusedTemplateVersions(ctx context.Context, arg ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) ArchiveUnusedTemplateVersions(ctx context.Context, arg ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error)
BackoffChatDiffStatus(ctx context.Context, arg BackoffChatDiffStatusParams) error
BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error
BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
+116
View File
@@ -3026,6 +3026,102 @@ func (q *sqlQuerier) AcquireChat(ctx context.Context, arg AcquireChatParams) (Ch
return i, err return i, err
} }
const acquireStaleChatDiffStatuses = `-- name: AcquireStaleChatDiffStatuses :many
WITH acquired AS (
UPDATE
chat_diff_statuses
SET
-- Claim for 5 minutes. The worker sets the real stale_at
-- after refresh. If the worker crashes, rows become eligible
-- again after this interval.
stale_at = NOW() + INTERVAL '5 minutes',
updated_at = NOW()
WHERE
chat_id IN (
SELECT
cds.chat_id
FROM
chat_diff_statuses cds
INNER JOIN
chats c ON c.id = cds.chat_id
WHERE
cds.stale_at <= NOW()
AND cds.git_remote_origin != ''
AND cds.git_branch != ''
AND c.archived = FALSE
ORDER BY
cds.stale_at ASC
FOR UPDATE OF cds
SKIP LOCKED
LIMIT
$1::int
)
RETURNING chat_id, url, pull_request_state, changes_requested, additions, deletions, changed_files, refreshed_at, stale_at, created_at, updated_at, git_branch, git_remote_origin
)
SELECT
acquired.chat_id, acquired.url, acquired.pull_request_state, acquired.changes_requested, acquired.additions, acquired.deletions, acquired.changed_files, acquired.refreshed_at, acquired.stale_at, acquired.created_at, acquired.updated_at, acquired.git_branch, acquired.git_remote_origin,
c.owner_id
FROM
acquired
INNER JOIN
chats c ON c.id = acquired.chat_id
`
type AcquireStaleChatDiffStatusesRow struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
Url sql.NullString `db:"url" json:"url"`
PullRequestState sql.NullString `db:"pull_request_state" json:"pull_request_state"`
ChangesRequested bool `db:"changes_requested" json:"changes_requested"`
Additions int32 `db:"additions" json:"additions"`
Deletions int32 `db:"deletions" json:"deletions"`
ChangedFiles int32 `db:"changed_files" json:"changed_files"`
RefreshedAt sql.NullTime `db:"refreshed_at" json:"refreshed_at"`
StaleAt time.Time `db:"stale_at" json:"stale_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
GitBranch string `db:"git_branch" json:"git_branch"`
GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"`
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
}
func (q *sqlQuerier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]AcquireStaleChatDiffStatusesRow, error) {
rows, err := q.db.QueryContext(ctx, acquireStaleChatDiffStatuses, limitVal)
if err != nil {
return nil, err
}
defer rows.Close()
var items []AcquireStaleChatDiffStatusesRow
for rows.Next() {
var i AcquireStaleChatDiffStatusesRow
if err := rows.Scan(
&i.ChatID,
&i.Url,
&i.PullRequestState,
&i.ChangesRequested,
&i.Additions,
&i.Deletions,
&i.ChangedFiles,
&i.RefreshedAt,
&i.StaleAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.GitBranch,
&i.GitRemoteOrigin,
&i.OwnerID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const archiveChatByID = `-- name: ArchiveChatByID :exec const archiveChatByID = `-- name: ArchiveChatByID :exec
UPDATE chats SET archived = true, updated_at = NOW() UPDATE chats SET archived = true, updated_at = NOW()
WHERE id = $1 OR root_chat_id = $1 WHERE id = $1 OR root_chat_id = $1
@@ -3036,6 +3132,26 @@ func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
return err return err
} }
const backoffChatDiffStatus = `-- name: BackoffChatDiffStatus :exec
UPDATE
chat_diff_statuses
SET
stale_at = $1::timestamptz,
updated_at = NOW()
WHERE
chat_id = $2::uuid
`
type BackoffChatDiffStatusParams struct {
StaleAt time.Time `db:"stale_at" json:"stale_at"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
}
func (q *sqlQuerier) BackoffChatDiffStatus(ctx context.Context, arg BackoffChatDiffStatusParams) error {
_, err := q.db.ExecContext(ctx, backoffChatDiffStatus, arg.StaleAt, arg.ChatID)
return err
}
const deleteAllChatQueuedMessages = `-- name: DeleteAllChatQueuedMessages :exec const deleteAllChatQueuedMessages = `-- name: DeleteAllChatQueuedMessages :exec
DELETE FROM chat_queued_messages WHERE chat_id = $1 DELETE FROM chat_queued_messages WHERE chat_id = $1
` `
+49
View File
@@ -448,3 +448,52 @@ LIMIT
-- name: GetChatByIDForUpdate :one -- name: GetChatByIDForUpdate :one
SELECT * FROM chats WHERE id = @id::uuid FOR UPDATE; SELECT * FROM chats WHERE id = @id::uuid FOR UPDATE;
-- name: AcquireStaleChatDiffStatuses :many
WITH acquired AS (
UPDATE
chat_diff_statuses
SET
-- Claim for 5 minutes. The worker sets the real stale_at
-- after refresh. If the worker crashes, rows become eligible
-- again after this interval.
stale_at = NOW() + INTERVAL '5 minutes',
updated_at = NOW()
WHERE
chat_id IN (
SELECT
cds.chat_id
FROM
chat_diff_statuses cds
INNER JOIN
chats c ON c.id = cds.chat_id
WHERE
cds.stale_at <= NOW()
AND cds.git_remote_origin != ''
AND cds.git_branch != ''
AND c.archived = FALSE
ORDER BY
cds.stale_at ASC
FOR UPDATE OF cds
SKIP LOCKED
LIMIT
@limit_val::int
)
RETURNING *
)
SELECT
acquired.*,
c.owner_id
FROM
acquired
INNER JOIN
chats c ON c.id = acquired.chat_id;
-- name: BackoffChatDiffStatus :exec
UPDATE
chat_diff_statuses
SET
stale_at = @stale_at::timestamptz,
updated_at = NOW()
WHERE
chat_id = @chat_id::uuid;
+33 -3
View File
@@ -23,6 +23,7 @@ import (
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/promoauth"
"github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
@@ -82,6 +83,10 @@ type Config struct {
// a Git clone. e.g. "Username for 'https://github.com':" // a Git clone. e.g. "Username for 'https://github.com':"
// The regex would be `github\.com`.. // The regex would be `github\.com`..
Regex *regexp.Regexp Regex *regexp.Regexp
// APIBaseURL is the base URL for provider REST API calls
// (e.g., "https://api.github.com" for GitHub). Derived from
// defaults when not explicitly configured.
APIBaseURL string
// AppInstallURL is for GitHub App's (and hopefully others eventually) // AppInstallURL is for GitHub App's (and hopefully others eventually)
// to provide a link to install the app. There's installation // to provide a link to install the app. There's installation
// of the application, and user authentication. It's possible // of the application, and user authentication. It's possible
@@ -106,12 +111,23 @@ type Config struct {
CodeChallengeMethodsSupported []promoauth.Oauth2PKCEChallengeMethod CodeChallengeMethodsSupported []promoauth.Oauth2PKCEChallengeMethod
} }
// Git returns a Provider for this config if the provider type
// is a supported git hosting provider. Returns nil for non-git
// providers (e.g. Slack, JFrog).
func (c *Config) Git(client *http.Client) gitprovider.Provider {
norm := strings.ToLower(c.Type)
if !codersdk.EnhancedExternalAuthProvider(norm).Git() {
return nil
}
return gitprovider.New(norm, c.APIBaseURL, client)
}
// GenerateTokenExtra generates the extra token data to store in the database. // GenerateTokenExtra generates the extra token data to store in the database.
func (c *Config) GenerateTokenExtra(token *oauth2.Token) (pqtype.NullRawMessage, error) { func (c *Config) GenerateTokenExtra(token *oauth2.Token) (pqtype.NullRawMessage, error) {
if len(c.ExtraTokenKeys) == 0 { if len(c.ExtraTokenKeys) == 0 {
return pqtype.NullRawMessage{}, nil return pqtype.NullRawMessage{}, nil
} }
extraMap := map[string]interface{}{} extraMap := map[string]any{}
for _, key := range c.ExtraTokenKeys { for _, key := range c.ExtraTokenKeys {
extraMap[key] = token.Extra(key) extraMap[key] = token.Extra(key)
} }
@@ -730,6 +746,7 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut
ClientID: entry.ClientID, ClientID: entry.ClientID,
ClientSecret: entry.ClientSecret, ClientSecret: entry.ClientSecret,
Regex: regex, Regex: regex,
APIBaseURL: entry.APIBaseURL,
Type: entry.Type, Type: entry.Type,
NoRefresh: entry.NoRefresh, NoRefresh: entry.NoRefresh,
ValidateURL: entry.ValidateURL, ValidateURL: entry.ValidateURL,
@@ -766,7 +783,7 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut
// applyDefaultsToConfig applies defaults to the config entry. // applyDefaultsToConfig applies defaults to the config entry.
func applyDefaultsToConfig(config *codersdk.ExternalAuthConfig) { func applyDefaultsToConfig(config *codersdk.ExternalAuthConfig) {
configType := codersdk.EnhancedExternalAuthProvider(config.Type) configType := codersdk.EnhancedExternalAuthProvider(strings.ToLower(config.Type))
if configType == "bitbucket" { if configType == "bitbucket" {
// For backwards compatibility, we need to support the "bitbucket" string. // For backwards compatibility, we need to support the "bitbucket" string.
configType = codersdk.EnhancedExternalAuthProviderBitBucketCloud configType = codersdk.EnhancedExternalAuthProviderBitBucketCloud
@@ -783,7 +800,7 @@ func applyDefaultsToConfig(config *codersdk.ExternalAuthConfig) {
} }
// Dynamic defaults // Dynamic defaults
switch codersdk.EnhancedExternalAuthProvider(config.Type) { switch configType {
case codersdk.EnhancedExternalAuthProviderGitHub: case codersdk.EnhancedExternalAuthProviderGitHub:
copyDefaultSettings(config, gitHubDefaults(config)) copyDefaultSettings(config, gitHubDefaults(config))
return return
@@ -864,6 +881,19 @@ func copyDefaultSettings(config *codersdk.ExternalAuthConfig, defaults codersdk.
if config.CodeChallengeMethodsSupported == nil { if config.CodeChallengeMethodsSupported == nil {
config.CodeChallengeMethodsSupported = []string{string(promoauth.PKCEChallengeMethodSha256)} config.CodeChallengeMethodsSupported = []string{string(promoauth.PKCEChallengeMethodSha256)}
} }
// Set default API base URL for providers that need one.
if config.APIBaseURL == "" {
normType := strings.ToLower(config.Type)
switch codersdk.EnhancedExternalAuthProvider(normType) {
case codersdk.EnhancedExternalAuthProviderGitHub:
config.APIBaseURL = "https://api.github.com"
case codersdk.EnhancedExternalAuthProviderGitLab:
config.APIBaseURL = "https://gitlab.com/api/v4"
case codersdk.EnhancedExternalAuthProviderGitea:
config.APIBaseURL = "https://gitea.com/api/v1"
}
}
} }
// gitHubDefaults returns default config values for GitHub. // gitHubDefaults returns default config values for GitHub.
@@ -25,6 +25,7 @@ func TestGitlabDefaults(t *testing.T) {
DisplayName: "GitLab", DisplayName: "GitLab",
DisplayIcon: "/icon/gitlab.svg", DisplayIcon: "/icon/gitlab.svg",
Regex: `^(https?://)?gitlab\.com(/.*)?$`, Regex: `^(https?://)?gitlab\.com(/.*)?$`,
APIBaseURL: "https://gitlab.com/api/v4",
Scopes: []string{"write_repository"}, Scopes: []string{"write_repository"},
CodeChallengeMethodsSupported: []string{string(promoauth.PKCEChallengeMethodSha256)}, CodeChallengeMethodsSupported: []string{string(promoauth.PKCEChallengeMethodSha256)},
} }
+34
View File
@@ -845,6 +845,40 @@ func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *ext
return fake, config, link return fake, config, link
} }
func TestApplyDefaultsToConfig_CaseInsensitive(t *testing.T) {
t.Parallel()
instrument := promoauth.NewFactory(prometheus.NewRegistry())
accessURL, err := url.Parse("https://coder.example.com")
require.NoError(t, err)
for _, tc := range []struct {
Name string
Type string
}{
{Name: "GitHub", Type: "GitHub"},
{Name: "GITLAB", Type: "GITLAB"},
{Name: "Gitea", Type: "Gitea"},
} {
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
configs, err := externalauth.ConvertConfig(
instrument,
[]codersdk.ExternalAuthConfig{{
Type: tc.Type,
ClientID: "test-id",
ClientSecret: "test-secret",
}},
accessURL,
)
require.NoError(t, err)
require.Len(t, configs, 1)
// Defaults should have been applied despite mixed-case Type.
assert.NotEmpty(t, configs[0].AuthCodeURL("state"), "auth URL should be populated from defaults")
})
}
}
type roundTripper func(req *http.Request) (*http.Response, error) type roundTripper func(req *http.Request) (*http.Response, error)
func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+540
View File
@@ -0,0 +1,540 @@
package gitprovider
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/xerrors"
"github.com/coder/quartz"
)
const (
defaultGitHubAPIBaseURL = "https://api.github.com"
// Adding padding to our retry times to guard against over-consumption of request quotas.
RateLimitPadding = 5 * time.Minute
)
type githubProvider struct {
apiBaseURL string
webBaseURL string
httpClient *http.Client
clock quartz.Clock
// Compiled per-instance to support GitHub Enterprise hosts.
pullRequestPathPattern *regexp.Regexp
repositoryHTTPSPattern *regexp.Regexp
repositorySSHPathPattern *regexp.Regexp
}
func newGitHub(apiBaseURL string, httpClient *http.Client, clock quartz.Clock) *githubProvider {
if apiBaseURL == "" {
apiBaseURL = defaultGitHubAPIBaseURL
}
apiBaseURL = strings.TrimRight(apiBaseURL, "/")
if httpClient == nil {
httpClient = http.DefaultClient
}
// Derive the web base URL from the API base URL.
// github.com: api.github.com → github.com
// GHE: ghes.corp.com/api/v3 → ghes.corp.com
webBaseURL := deriveWebBaseURL(apiBaseURL)
// Parse the host for regex construction.
host := extractHost(webBaseURL)
// Escape the host for use in regex patterns.
escapedHost := regexp.QuoteMeta(host)
return &githubProvider{
apiBaseURL: apiBaseURL,
webBaseURL: webBaseURL,
httpClient: httpClient,
clock: clock,
pullRequestPathPattern: regexp.MustCompile(
`^https://` + escapedHost + `/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)/pull/([0-9]+)(?:[/?#].*)?$`,
),
repositoryHTTPSPattern: regexp.MustCompile(
`^https://` + escapedHost + `/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
),
repositorySSHPathPattern: regexp.MustCompile(
`^(?:ssh://)?git@` + escapedHost + `[:/]([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
),
}
}
// deriveWebBaseURL converts a GitHub API base URL to the
// corresponding web base URL.
//
// github.com: https://api.github.com → https://github.com
// GHE: https://ghes.corp.com/api/v3 → https://ghes.corp.com
func deriveWebBaseURL(apiBaseURL string) string {
u, err := url.Parse(apiBaseURL)
if err != nil {
return "https://github.com"
}
// Standard github.com: API host is api.github.com.
if strings.EqualFold(u.Host, "api.github.com") {
return "https://github.com"
}
// GHE: strip /api/v3 path suffix.
u.Path = strings.TrimSuffix(u.Path, "/api/v3")
u.Path = strings.TrimSuffix(u.Path, "/")
return u.String()
}
// extractHost returns the host portion of a URL.
func extractHost(rawURL string) string {
u, err := url.Parse(rawURL)
if err != nil {
return "github.com"
}
return u.Host
}
func (g *githubProvider) ParseRepositoryOrigin(raw string) (owner string, repo string, normalizedOrigin string, ok bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", "", "", false
}
matches := g.repositoryHTTPSPattern.FindStringSubmatch(raw)
if len(matches) != 3 {
matches = g.repositorySSHPathPattern.FindStringSubmatch(raw)
}
if len(matches) != 3 {
return "", "", "", false
}
owner = strings.TrimSpace(matches[1])
repo = strings.TrimSpace(matches[2])
repo = strings.TrimSuffix(repo, ".git")
if owner == "" || repo == "" {
return "", "", "", false
}
return owner, repo, fmt.Sprintf("%s/%s/%s", g.webBaseURL, url.PathEscape(owner), url.PathEscape(repo)), true
}
func (g *githubProvider) ParsePullRequestURL(raw string) (PRRef, bool) {
matches := g.pullRequestPathPattern.FindStringSubmatch(strings.TrimSpace(raw))
if len(matches) != 4 {
return PRRef{}, false
}
number, err := strconv.Atoi(matches[3])
if err != nil {
return PRRef{}, false
}
return PRRef{
Owner: matches[1],
Repo: matches[2],
Number: number,
}, true
}
func (g *githubProvider) NormalizePullRequestURL(raw string) string {
ref, ok := g.ParsePullRequestURL(strings.TrimRight(
strings.TrimSpace(raw),
"),.;",
))
if !ok {
return ""
}
return fmt.Sprintf("%s/%s/%s/pull/%d", g.webBaseURL, url.PathEscape(ref.Owner), url.PathEscape(ref.Repo), ref.Number)
}
// escapePathPreserveSlashes escapes each segment of a path
// individually, preserving `/` separators. This is needed for
// web URLs where GitHub expects literal slashes (e.g.
// /tree/feat/new-thing).
func escapePathPreserveSlashes(s string) string {
segments := strings.Split(s, "/")
for i, seg := range segments {
segments[i] = url.PathEscape(seg)
}
return strings.Join(segments, "/")
}
func (g *githubProvider) BuildBranchURL(owner string, repo string, branch string) string {
owner = strings.TrimSpace(owner)
repo = strings.TrimSpace(repo)
branch = strings.TrimSpace(branch)
if owner == "" || repo == "" || branch == "" {
return ""
}
return fmt.Sprintf(
"%s/%s/%s/tree/%s",
g.webBaseURL,
url.PathEscape(owner),
url.PathEscape(repo),
escapePathPreserveSlashes(branch),
)
}
func (g *githubProvider) BuildRepositoryURL(owner string, repo string) string {
owner = strings.TrimSpace(owner)
repo = strings.TrimSpace(repo)
if owner == "" || repo == "" {
return ""
}
return fmt.Sprintf("%s/%s/%s", g.webBaseURL, url.PathEscape(owner), url.PathEscape(repo))
}
func (g *githubProvider) BuildPullRequestURL(ref PRRef) string {
if ref.Owner == "" || ref.Repo == "" || ref.Number <= 0 {
return ""
}
return fmt.Sprintf("%s/%s/%s/pull/%d", g.webBaseURL, url.PathEscape(ref.Owner), url.PathEscape(ref.Repo), ref.Number)
}
func (g *githubProvider) ResolveBranchPullRequest(
ctx context.Context,
token string,
ref BranchRef,
) (*PRRef, error) {
if ref.Owner == "" || ref.Repo == "" || ref.Branch == "" {
return nil, nil
}
query := url.Values{}
query.Set("state", "open")
query.Set("head", fmt.Sprintf("%s:%s", ref.Owner, ref.Branch))
query.Set("sort", "updated")
query.Set("direction", "desc")
query.Set("per_page", "1")
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/pulls?%s",
g.apiBaseURL,
url.PathEscape(ref.Owner),
url.PathEscape(ref.Repo),
query.Encode(),
)
var pulls []struct {
HTMLURL string `json:"html_url"`
Number int `json:"number"`
}
if err := g.decodeJSON(ctx, requestURL, token, &pulls); err != nil {
return nil, err
}
if len(pulls) == 0 {
return nil, nil
}
prRef, ok := g.ParsePullRequestURL(pulls[0].HTMLURL)
if !ok {
return nil, nil
}
return &prRef, nil
}
func (g *githubProvider) FetchPullRequestStatus(
ctx context.Context,
token string,
ref PRRef,
) (*PRStatus, error) {
pullEndpoint := fmt.Sprintf(
"%s/repos/%s/%s/pulls/%d",
g.apiBaseURL,
url.PathEscape(ref.Owner),
url.PathEscape(ref.Repo),
ref.Number,
)
var pull struct {
State string `json:"state"`
Merged bool `json:"merged"`
Draft bool `json:"draft"`
Additions int32 `json:"additions"`
Deletions int32 `json:"deletions"`
ChangedFiles int32 `json:"changed_files"`
Head struct {
SHA string `json:"sha"`
} `json:"head"`
}
if err := g.decodeJSON(ctx, pullEndpoint, token, &pull); err != nil {
return nil, err
}
var reviews []struct {
ID int64 `json:"id"`
State string `json:"state"`
User struct {
Login string `json:"login"`
} `json:"user"`
}
// GitHub returns at most 100 reviews per page. We do not
// paginate because PRs with >100 reviews are extremely rare,
// and the cost of multiple API calls per refresh is not
// justified. If needed, pagination can be added later.
if err := g.decodeJSON(
ctx,
pullEndpoint+"/reviews?per_page=100",
token,
&reviews,
); err != nil {
return nil, err
}
state := PRState(strings.ToLower(strings.TrimSpace(pull.State)))
if pull.Merged {
state = PRStateMerged
}
return &PRStatus{
State: state,
Draft: pull.Draft,
HeadSHA: pull.Head.SHA,
DiffStats: DiffStats{
Additions: pull.Additions,
Deletions: pull.Deletions,
ChangedFiles: pull.ChangedFiles,
},
ChangesRequested: hasOutstandingChangesRequested(reviews),
FetchedAt: g.clock.Now().UTC(),
}, nil
}
func (g *githubProvider) FetchPullRequestDiff(
ctx context.Context,
token string,
ref PRRef,
) (string, error) {
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/pulls/%d",
g.apiBaseURL,
url.PathEscape(ref.Owner),
url.PathEscape(ref.Repo),
ref.Number,
)
return g.fetchDiff(ctx, requestURL, token)
}
func (g *githubProvider) FetchBranchDiff(
ctx context.Context,
token string,
ref BranchRef,
) (string, error) {
if ref.Owner == "" || ref.Repo == "" || ref.Branch == "" {
return "", nil
}
var repository struct {
DefaultBranch string `json:"default_branch"`
}
repositoryURL := fmt.Sprintf(
"%s/repos/%s/%s",
g.apiBaseURL,
url.PathEscape(ref.Owner),
url.PathEscape(ref.Repo),
)
if err := g.decodeJSON(ctx, repositoryURL, token, &repository); err != nil {
return "", err
}
defaultBranch := strings.TrimSpace(repository.DefaultBranch)
if defaultBranch == "" {
return "", xerrors.New("github repository default branch is empty")
}
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/compare/%s...%s",
g.apiBaseURL,
url.PathEscape(ref.Owner),
url.PathEscape(ref.Repo),
url.PathEscape(defaultBranch),
url.PathEscape(ref.Branch),
)
return g.fetchDiff(ctx, requestURL, token)
}
func (g *githubProvider) decodeJSON(
ctx context.Context,
requestURL string,
token string,
dest any,
) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return xerrors.Errorf("create github request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github+json")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
req.Header.Set("User-Agent", "coder-chat-diff-status")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
resp, err := g.httpClient.Do(req)
if err != nil {
return xerrors.Errorf("execute github request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests {
retryAfter := ParseRetryAfter(resp.Header, g.clock)
if retryAfter > 0 {
return &RateLimitError{RetryAfter: g.clock.Now().Add(retryAfter + RateLimitPadding)}
}
// No rate-limit headers — fall through to generic error.
}
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return xerrors.Errorf(
"github request failed with status %d",
resp.StatusCode,
)
}
return xerrors.Errorf(
"github request failed with status %d: %s",
resp.StatusCode,
strings.TrimSpace(string(body)),
)
}
if err := json.NewDecoder(resp.Body).Decode(dest); err != nil {
return xerrors.Errorf("decode github response: %w", err)
}
return nil
}
func (g *githubProvider) fetchDiff(
ctx context.Context,
requestURL string,
token string,
) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return "", xerrors.Errorf("create github diff request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github.diff")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
req.Header.Set("User-Agent", "coder-chat-diff")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
resp, err := g.httpClient.Do(req)
if err != nil {
return "", xerrors.Errorf("execute github diff request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests {
retryAfter := ParseRetryAfter(resp.Header, g.clock)
if retryAfter > 0 {
return "", &RateLimitError{RetryAfter: g.clock.Now().Add(retryAfter + RateLimitPadding)}
}
}
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return "", xerrors.Errorf("github diff request failed with status %d", resp.StatusCode)
}
return "", xerrors.Errorf(
"github diff request failed with status %d: %s",
resp.StatusCode,
strings.TrimSpace(string(body)),
)
}
// Read one extra byte beyond MaxDiffSize so we can detect
// whether the diff exceeds the limit. LimitReader stops us
// allocating an arbitrarily large buffer by accident.
buf, err := io.ReadAll(io.LimitReader(resp.Body, MaxDiffSize+1))
if err != nil {
return "", xerrors.Errorf("read github diff response: %w", err)
}
if len(buf) > MaxDiffSize {
return "", ErrDiffTooLarge
}
return string(buf), nil
}
// ParseRetryAfter extracts a retry-after time from GitHub
// rate-limit headers. Returns zero value if no recognizable header is
// present.
func ParseRetryAfter(h http.Header, clk quartz.Clock) time.Duration {
if clk == nil {
clk = quartz.NewReal()
}
// Retry-After header: seconds until retry.
if ra := h.Get("Retry-After"); ra != "" {
if secs, err := strconv.Atoi(ra); err == nil {
return time.Duration(secs) * time.Second
}
}
// X-Ratelimit-Reset header: unix timestamp. We compute the
// duration from now according to the caller's clock.
if reset := h.Get("X-Ratelimit-Reset"); reset != "" {
if ts, err := strconv.ParseInt(reset, 10, 64); err == nil {
d := time.Unix(ts, 0).Sub(clk.Now())
return d
}
}
return 0
}
func hasOutstandingChangesRequested(
reviews []struct {
ID int64 `json:"id"`
State string `json:"state"`
User struct {
Login string `json:"login"`
} `json:"user"`
},
) bool {
type reviewerState struct {
reviewID int64
state string
}
statesByReviewer := make(map[string]reviewerState)
for _, review := range reviews {
login := strings.ToLower(strings.TrimSpace(review.User.Login))
if login == "" {
continue
}
state := strings.ToUpper(strings.TrimSpace(review.State))
switch state {
case "CHANGES_REQUESTED", "APPROVED", "DISMISSED":
default:
continue
}
current, exists := statesByReviewer[login]
if exists && current.reviewID > review.ID {
continue
}
statesByReviewer[login] = reviewerState{
reviewID: review.ID,
state: state,
}
}
for _, state := range statesByReviewer {
if state.state == "CHANGES_REQUESTED" {
return true
}
}
return false
}
@@ -0,0 +1,994 @@
package gitprovider_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/quartz"
)
func TestGitHubParseRepositoryOrigin(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
tests := []struct {
name string
raw string
expectOK bool
expectOwner string
expectRepo string
expectNormalized string
}{
{
name: "HTTPS URL",
raw: "https://github.com/coder/coder",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "HTTPS URL with .git",
raw: "https://github.com/coder/coder.git",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "HTTPS URL with trailing slash",
raw: "https://github.com/coder/coder/",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "SSH URL",
raw: "git@github.com:coder/coder.git",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "SSH URL without .git",
raw: "git@github.com:coder/coder",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "SSH URL with ssh:// prefix",
raw: "ssh://git@github.com/coder/coder.git",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "GitLab URL does not match",
raw: "https://gitlab.com/coder/coder",
expectOK: false,
},
{
name: "Empty string",
raw: "",
expectOK: false,
},
{
name: "Not a URL",
raw: "not-a-url",
expectOK: false,
},
{
name: "Hyphenated owner and repo",
raw: "https://github.com/my-org/my-repo.git",
expectOK: true,
expectOwner: "my-org",
expectRepo: "my-repo",
expectNormalized: "https://github.com/my-org/my-repo",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
owner, repo, normalized, ok := gp.ParseRepositoryOrigin(tt.raw)
assert.Equal(t, tt.expectOK, ok)
if tt.expectOK {
assert.Equal(t, tt.expectOwner, owner)
assert.Equal(t, tt.expectRepo, repo)
assert.Equal(t, tt.expectNormalized, normalized)
}
})
}
}
func TestGitHubParsePullRequestURL(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
tests := []struct {
name string
raw string
expectOK bool
expectOwner string
expectRepo string
expectNumber int
}{
{
name: "Standard PR URL",
raw: "https://github.com/coder/coder/pull/123",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNumber: 123,
},
{
name: "PR URL with query string",
raw: "https://github.com/coder/coder/pull/456?diff=split",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNumber: 456,
},
{
name: "PR URL with fragment",
raw: "https://github.com/coder/coder/pull/789#discussion",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNumber: 789,
},
{
name: "Not a PR URL",
raw: "https://github.com/coder/coder",
expectOK: false,
},
{
name: "Issue URL (not PR)",
raw: "https://github.com/coder/coder/issues/123",
expectOK: false,
},
{
name: "GitLab MR URL",
raw: "https://gitlab.com/coder/coder/-/merge_requests/123",
expectOK: false,
},
{
name: "Empty string",
raw: "",
expectOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ref, ok := gp.ParsePullRequestURL(tt.raw)
assert.Equal(t, tt.expectOK, ok)
if tt.expectOK {
assert.Equal(t, tt.expectOwner, ref.Owner)
assert.Equal(t, tt.expectRepo, ref.Repo)
assert.Equal(t, tt.expectNumber, ref.Number)
}
})
}
}
func TestGitHubNormalizePullRequestURL(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
tests := []struct {
name string
raw string
expected string
}{
{
name: "Already normalized",
raw: "https://github.com/coder/coder/pull/123",
expected: "https://github.com/coder/coder/pull/123",
},
{
name: "With trailing punctuation",
raw: "https://github.com/coder/coder/pull/123).",
expected: "https://github.com/coder/coder/pull/123",
},
{
name: "With query string",
raw: "https://github.com/coder/coder/pull/123?diff=split",
expected: "https://github.com/coder/coder/pull/123",
},
{
name: "With whitespace",
raw: " https://github.com/coder/coder/pull/123 ",
expected: "https://github.com/coder/coder/pull/123",
},
{
name: "Not a PR URL",
raw: "https://example.com",
expected: "",
},
{
name: "Empty string",
raw: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := gp.NormalizePullRequestURL(tt.raw)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGitHubBuildBranchURL(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
tests := []struct {
name string
owner string
repo string
branch string
expected string
}{
{
name: "Simple branch",
owner: "coder",
repo: "coder",
branch: "main",
expected: "https://github.com/coder/coder/tree/main",
},
{
name: "Branch with slash",
owner: "coder",
repo: "coder",
branch: "feat/new-thing",
expected: "https://github.com/coder/coder/tree/feat/new-thing",
},
{
name: "Empty owner",
owner: "",
repo: "coder",
branch: "main",
expected: "",
},
{
name: "Empty repo",
owner: "coder",
repo: "",
branch: "main",
expected: "",
},
{
name: "Empty branch",
owner: "coder",
repo: "coder",
branch: "",
expected: "",
},
{
name: "Branch with slashes",
owner: "my-org",
repo: "my-repo",
branch: "feat/new-thing",
expected: "https://github.com/my-org/my-repo/tree/feat/new-thing",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := gp.BuildBranchURL(tt.owner, tt.repo, tt.branch)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGitHubBuildPullRequestURL(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
tests := []struct {
name string
ref gitprovider.PRRef
expected string
}{
{
name: "Valid PR ref",
ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: 123},
expected: "https://github.com/coder/coder/pull/123",
},
{
name: "Empty owner",
ref: gitprovider.PRRef{Owner: "", Repo: "coder", Number: 123},
expected: "",
},
{
name: "Empty repo",
ref: gitprovider.PRRef{Owner: "coder", Repo: "", Number: 123},
expected: "",
},
{
name: "Zero number",
ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: 0},
expected: "",
},
{
name: "Negative number",
ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: -1},
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := gp.BuildPullRequestURL(tt.ref)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGitHubEnterpriseURLs(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "https://ghes.corp.com/api/v3", nil)
require.NotNil(t, gp)
t.Run("ParseRepositoryOrigin HTTPS", func(t *testing.T) {
t.Parallel()
owner, repo, normalized, ok := gp.ParseRepositoryOrigin("https://ghes.corp.com/org/repo.git")
assert.True(t, ok)
assert.Equal(t, "org", owner)
assert.Equal(t, "repo", repo)
assert.Equal(t, "https://ghes.corp.com/org/repo", normalized)
})
t.Run("ParseRepositoryOrigin SSH", func(t *testing.T) {
t.Parallel()
owner, repo, normalized, ok := gp.ParseRepositoryOrigin("git@ghes.corp.com:org/repo.git")
assert.True(t, ok)
assert.Equal(t, "org", owner)
assert.Equal(t, "repo", repo)
assert.Equal(t, "https://ghes.corp.com/org/repo", normalized)
})
t.Run("ParsePullRequestURL", func(t *testing.T) {
t.Parallel()
ref, ok := gp.ParsePullRequestURL("https://ghes.corp.com/org/repo/pull/42")
assert.True(t, ok)
assert.Equal(t, "org", ref.Owner)
assert.Equal(t, "repo", ref.Repo)
assert.Equal(t, 42, ref.Number)
})
t.Run("NormalizePullRequestURL", func(t *testing.T) {
t.Parallel()
result := gp.NormalizePullRequestURL("https://ghes.corp.com/org/repo/pull/42?x=y")
assert.Equal(t, "https://ghes.corp.com/org/repo/pull/42", result)
})
t.Run("BuildBranchURL", func(t *testing.T) {
t.Parallel()
result := gp.BuildBranchURL("org", "repo", "main")
assert.Equal(t, "https://ghes.corp.com/org/repo/tree/main", result)
})
t.Run("BuildPullRequestURL", func(t *testing.T) {
t.Parallel()
result := gp.BuildPullRequestURL(gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42})
assert.Equal(t, "https://ghes.corp.com/org/repo/pull/42", result)
})
t.Run("github.com URLs do not match GHE instance", func(t *testing.T) {
t.Parallel()
_, _, _, ok := gp.ParseRepositoryOrigin("https://github.com/coder/coder")
assert.False(t, ok, "github.com HTTPS URL should not match GHE instance")
_, _, _, ok = gp.ParseRepositoryOrigin("git@github.com:coder/coder.git")
assert.False(t, ok, "github.com SSH URL should not match GHE instance")
_, ok = gp.ParsePullRequestURL("https://github.com/coder/coder/pull/123")
assert.False(t, ok, "github.com PR URL should not match GHE instance")
})
}
func TestNewUnsupportedProvider(t *testing.T) {
t.Parallel()
gp := gitprovider.New("unsupported", "", nil)
assert.Nil(t, gp, "unsupported provider type should return nil")
}
func TestGitHubRatelimit_403WithResetHeader(t *testing.T) {
t.Parallel()
resetTime := time.Now().Add(60 * time.Second)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("X-Ratelimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"message": "API rate limit exceeded"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchPullRequestStatus(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.Error(t, err)
var rlErr *gitprovider.RateLimitError
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
assert.WithinDuration(t, resetTime.Add(gitprovider.RateLimitPadding), rlErr.RetryAfter, 2*time.Second)
}
func TestGitHubRatelimit_429WithRetryAfter(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Retry-After", "120")
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"message": "secondary rate limit"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchPullRequestStatus(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.Error(t, err)
var rlErr *gitprovider.RateLimitError
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
// Retry-After: 120 means ~120s from now.
expected := time.Now().Add(120 * time.Second)
assert.WithinDuration(t, expected.Add(gitprovider.RateLimitPadding), rlErr.RetryAfter, 5*time.Second)
}
func TestGitHubRatelimit_403NormalError(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"message": "Bad credentials"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchPullRequestStatus(
context.Background(),
"bad-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.Error(t, err)
var rlErr *gitprovider.RateLimitError
assert.False(t, errors.As(err, &rlErr), "error should NOT be *RateLimitError")
assert.Contains(t, err.Error(), "403")
}
func TestGitHubFetchPullRequestDiff(t *testing.T) {
t.Parallel()
const smallDiff = "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1 +1 @@\n-old\n+new\n"
t.Run("OK", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(smallDiff))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
diff, err := gp.FetchPullRequestDiff(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.NoError(t, err)
assert.Equal(t, smallDiff, diff)
})
t.Run("ExactlyMaxSize", func(t *testing.T) {
t.Parallel()
exactDiff := string(make([]byte, gitprovider.MaxDiffSize))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(exactDiff))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
diff, err := gp.FetchPullRequestDiff(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.NoError(t, err)
assert.Len(t, diff, gitprovider.MaxDiffSize)
})
t.Run("TooLarge", func(t *testing.T) {
t.Parallel()
oversizeDiff := string(make([]byte, gitprovider.MaxDiffSize+1024))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(oversizeDiff))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchPullRequestDiff(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
assert.ErrorIs(t, err, gitprovider.ErrDiffTooLarge)
})
}
func TestFetchPullRequestDiff_Ratelimit(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Retry-After", "60")
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"message": "rate limit"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchPullRequestDiff(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.Error(t, err)
var rlErr *gitprovider.RateLimitError
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
expected := time.Now().Add(60 * time.Second)
assert.WithinDuration(t, expected.Add(gitprovider.RateLimitPadding), rlErr.RetryAfter, 5*time.Second)
}
func TestFetchBranchDiff_Ratelimit(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/compare/") {
// Second request: compare endpoint returns 429.
w.Header().Set("Retry-After", "60")
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"message": "rate limit"}`))
return
}
// First request: repo metadata.
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"default_branch":"main"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchBranchDiff(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
)
require.Error(t, err)
var rlErr *gitprovider.RateLimitError
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
expected := time.Now().Add(60 * time.Second)
assert.WithinDuration(t, expected.Add(gitprovider.RateLimitPadding), rlErr.RetryAfter, 5*time.Second)
}
func TestFetchPullRequestStatus(t *testing.T) {
t.Parallel()
type review struct {
ID int64 `json:"id"`
State string `json:"state"`
User struct {
Login string `json:"login"`
} `json:"user"`
}
makeReview := func(id int64, state, login string) review {
r := review{ID: id, State: state}
r.User.Login = login
return r
}
tests := []struct {
name string
pullJSON string
reviews []review
expectedState gitprovider.PRState
expectedDraft bool
changesRequested bool
}{
{
name: "OpenPR/NoReviews",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{},
expectedState: gitprovider.PRStateOpen,
expectedDraft: false,
changesRequested: false,
},
{
name: "OpenPR/SingleChangesRequested",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{makeReview(1, "CHANGES_REQUESTED", "alice")},
expectedState: gitprovider.PRStateOpen,
changesRequested: true,
},
{
name: "OpenPR/ChangesRequestedThenApproved",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{
makeReview(1, "CHANGES_REQUESTED", "alice"),
makeReview(2, "APPROVED", "alice"),
},
expectedState: gitprovider.PRStateOpen,
changesRequested: false,
},
{
name: "OpenPR/ChangesRequestedThenDismissed",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{
makeReview(1, "CHANGES_REQUESTED", "alice"),
makeReview(2, "DISMISSED", "alice"),
},
expectedState: gitprovider.PRStateOpen,
changesRequested: false,
},
{
name: "OpenPR/MultipleReviewersMixed",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{
makeReview(1, "APPROVED", "alice"),
makeReview(2, "CHANGES_REQUESTED", "bob"),
},
expectedState: gitprovider.PRStateOpen,
changesRequested: true,
},
{
name: "OpenPR/CommentedDoesNotAffect",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{
makeReview(1, "COMMENTED", "alice"),
},
expectedState: gitprovider.PRStateOpen,
changesRequested: false,
},
{
name: "MergedPR",
pullJSON: `{"state":"closed","merged":true,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{},
expectedState: gitprovider.PRStateMerged,
changesRequested: false,
},
{
name: "DraftPR",
pullJSON: `{"state":"open","merged":false,"draft":true,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{},
expectedState: gitprovider.PRStateOpen,
expectedDraft: true,
changesRequested: false,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
reviewsJSON, err := json.Marshal(tc.reviews)
require.NoError(t, err)
mux := http.NewServeMux()
mux.HandleFunc("/api/v3/repos/owner/repo/pulls/1/reviews", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(reviewsJSON)
})
mux.HandleFunc("/api/v3/repos/owner/repo/pulls/1", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(tc.pullJSON))
})
srv := httptest.NewServer(mux)
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
before := time.Now().UTC()
status, err := gp.FetchPullRequestStatus(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "owner", Repo: "repo", Number: 1},
)
require.NoError(t, err)
assert.Equal(t, tc.expectedState, status.State)
assert.Equal(t, tc.expectedDraft, status.Draft)
assert.Equal(t, tc.changesRequested, status.ChangesRequested)
assert.Equal(t, "abc123", status.HeadSHA)
assert.Equal(t, int32(10), status.DiffStats.Additions)
assert.Equal(t, int32(5), status.DiffStats.Deletions)
assert.Equal(t, int32(3), status.DiffStats.ChangedFiles)
assert.False(t, status.FetchedAt.IsZero())
assert.True(t, !status.FetchedAt.Before(before), "FetchedAt should be >= test start time")
})
}
}
func TestResolveBranchPullRequest(t *testing.T) {
t.Parallel()
t.Run("Found", func(t *testing.T) {
t.Parallel()
var srvURL string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify query parameters.
assert.Equal(t, "open", r.URL.Query().Get("state"))
assert.Equal(t, "owner:feat", r.URL.Query().Get("head"))
w.Header().Set("Content-Type", "application/json")
// Use the test server's URL so ParsePullRequestURL
// matches the provider's derived web host.
htmlURL := fmt.Sprintf("https://%s/owner/repo/pull/42",
strings.TrimPrefix(strings.TrimPrefix(srvURL, "http://"), "https://"))
_, _ = w.Write([]byte(fmt.Sprintf(`[{"html_url":%q,"number":42}]`, htmlURL)))
}))
defer srv.Close()
srvURL = srv.URL
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
prRef, err := gp.ResolveBranchPullRequest(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"},
)
require.NoError(t, err)
require.NotNil(t, prRef)
assert.Equal(t, "owner", prRef.Owner)
assert.Equal(t, "repo", prRef.Repo)
assert.Equal(t, 42, prRef.Number)
})
t.Run("NoneOpen", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`[]`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
prRef, err := gp.ResolveBranchPullRequest(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"},
)
require.NoError(t, err)
assert.Nil(t, prRef)
})
t.Run("InvalidHTMLURL", func(t *testing.T) {
t.Parallel()
// If html_url can't be parsed as a PR URL, ResolveBranchPullRequest
// returns nil, nil.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`[{"html_url":"not-a-valid-url","number":42}]`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
prRef, err := gp.ResolveBranchPullRequest(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"},
)
require.NoError(t, err)
assert.Nil(t, prRef)
})
}
func TestFetchBranchDiff(t *testing.T) {
t.Parallel()
const smallDiff = "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1 +1 @@\n-old\n+new\n"
t.Run("OK", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/compare/") {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(smallDiff))
return
}
// Repo metadata.
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"default_branch":"main"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
diff, err := gp.FetchBranchDiff(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
)
require.NoError(t, err)
assert.Equal(t, smallDiff, diff)
})
t.Run("EmptyDefaultBranch", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"default_branch":""}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchBranchDiff(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
)
require.Error(t, err)
assert.Contains(t, err.Error(), "default branch is empty")
})
t.Run("DiffTooLarge", func(t *testing.T) {
t.Parallel()
oversizeDiff := string(make([]byte, gitprovider.MaxDiffSize+1024))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/compare/") {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(oversizeDiff))
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"default_branch":"main"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchBranchDiff(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
)
assert.ErrorIs(t, err, gitprovider.ErrDiffTooLarge)
})
}
func TestEscapePathPreserveSlashes(t *testing.T) {
t.Parallel()
// The function is unexported, so test it indirectly via BuildBranchURL.
// A branch with a space in a segment should be escaped, but slashes preserved.
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
got := gp.BuildBranchURL("owner", "repo", "feat/my thing")
assert.Equal(t, "https://github.com/owner/repo/tree/feat/my%20thing", got)
}
func TestParseRetryAfter(t *testing.T) {
t.Parallel()
clk := quartz.NewMock(t)
clk.Set(time.Now())
t.Run("RetryAfterSeconds", func(t *testing.T) {
t.Parallel()
h := http.Header{}
h.Set("Retry-After", "120")
d := gitprovider.ParseRetryAfter(h, clk)
assert.Equal(t, 120*time.Second, d)
})
t.Run("XRatelimitReset", func(t *testing.T) {
t.Parallel()
future := clk.Now().Add(90 * time.Second)
t.Logf("now: %d future: %d", clk.Now().Unix(), future.Unix())
h := http.Header{}
h.Set("X-Ratelimit-Reset", strconv.FormatInt(future.Unix(), 10))
d := gitprovider.ParseRetryAfter(h, clk)
assert.WithinDuration(t, future, clk.Now().Add(d), time.Second)
})
t.Run("NoHeaders", func(t *testing.T) {
t.Parallel()
h := http.Header{}
d := gitprovider.ParseRetryAfter(h, clk)
assert.Equal(t, time.Duration(0), d)
})
t.Run("InvalidValue", func(t *testing.T) {
t.Parallel()
h := http.Header{}
h.Set("Retry-After", "not-a-number")
d := gitprovider.ParseRetryAfter(h, clk)
assert.Equal(t, time.Duration(0), d)
})
t.Run("RetryAfterTakesPrecedence", func(t *testing.T) {
t.Parallel()
h := http.Header{}
h.Set("Retry-After", "60")
h.Set("X-Ratelimit-Reset", strconv.FormatInt(
clk.Now().Unix()+120, 10,
))
d := gitprovider.ParseRetryAfter(h, clk)
assert.Equal(t, 60*time.Second, d)
})
}
@@ -0,0 +1,179 @@
package gitprovider
import (
"context"
"fmt"
"net/http"
"time"
"golang.org/x/xerrors"
"github.com/coder/quartz"
)
// providerOptions holds optional configuration for provider
// construction.
type providerOptions struct {
clock quartz.Clock
}
// Option configures optional behavior for a Provider.
type Option func(*providerOptions)
// WithClock sets the clock used by the provider. Defaults to
// quartz.NewReal() if not provided.
func WithClock(c quartz.Clock) Option {
return func(o *providerOptions) {
o.clock = c
}
}
// PRState is the normalized state of a pull/merge request across
// all providers.
type PRState string
const (
PRStateOpen PRState = "open"
PRStateClosed PRState = "closed"
PRStateMerged PRState = "merged"
)
// PRRef identifies a pull request on any provider.
type PRRef struct {
// Owner is the repository owner / project / workspace.
Owner string
// Repo is the repository name or slug.
Repo string
// Number is the PR number / IID / index.
Number int
}
// BranchRef identifies a branch in a repository, used for
// branch-to-PR resolution.
type BranchRef struct {
Owner string
Repo string
Branch string
}
// DiffStats summarizes the size of a PR's changes.
type DiffStats struct {
Additions int32
Deletions int32
ChangedFiles int32
}
// PRStatus is the complete status of a pull/merge request.
// This is the universal return type that all providers populate.
type PRStatus struct {
// State is the PR's lifecycle state.
State PRState
// Draft indicates the PR is marked as draft/WIP.
Draft bool
// HeadSHA is the SHA of the head commit.
HeadSHA string
// DiffStats summarizes additions/deletions/files changed.
DiffStats DiffStats
// ChangesRequested is a convenience boolean: true if any
// reviewer's current state is "changes_requested".
ChangesRequested bool
// FetchedAt is when this status was fetched.
FetchedAt time.Time
}
// MaxDiffSize is the maximum number of bytes read from a diff
// response. Diffs exceeding this limit are rejected with
// ErrDiffTooLarge.
const MaxDiffSize = 4 << 20 // 4 MiB
// ErrDiffTooLarge is returned when a diff exceeds MaxDiffSize.
var ErrDiffTooLarge = xerrors.Errorf("diff exceeds maximum size of %d bytes", MaxDiffSize)
// Provider defines the interface that all Git hosting providers
// implement. Each method is designed to minimize API round-trips
// for the specific provider.
type Provider interface {
// FetchPullRequestStatus retrieves the complete status of a
// pull request in the minimum number of API calls for this
// provider.
FetchPullRequestStatus(ctx context.Context, token string, ref PRRef) (*PRStatus, error)
// ResolveBranchPullRequest finds the open PR (if any) for
// the given branch. Returns nil, nil if no open PR exists.
ResolveBranchPullRequest(ctx context.Context, token string, ref BranchRef) (*PRRef, error)
// FetchPullRequestDiff returns the raw unified diff for a
// pull request. This uses the PR's actual base branch (which
// may differ from the repo default branch, e.g. a PR
// targeting "staging" instead of "main"), so it matches what
// the provider shows on the PR's "Files changed" tab.
// Returns ErrDiffTooLarge if the diff exceeds MaxDiffSize.
FetchPullRequestDiff(ctx context.Context, token string, ref PRRef) (string, error)
// FetchBranchDiff returns the diff of a branch compared
// against the repository's default branch. This is the
// fallback when no pull request exists yet (e.g. the agent
// pushed a branch but hasn't opened a PR). Returns
// ErrDiffTooLarge if the diff exceeds MaxDiffSize.
FetchBranchDiff(ctx context.Context, token string, ref BranchRef) (string, error)
// ParseRepositoryOrigin parses a remote origin URL (HTTPS
// or SSH) into owner and repo components, returning the
// normalized HTTPS URL. Returns false if the URL does not
// match this provider.
ParseRepositoryOrigin(raw string) (owner, repo, normalizedOrigin string, ok bool)
// ParsePullRequestURL parses a pull request URL into a
// PRRef. Returns false if the URL does not match this
// provider.
ParsePullRequestURL(raw string) (PRRef, bool)
// NormalizePullRequestURL normalizes a pull request URL,
// stripping trailing punctuation, query strings, and
// fragments. Returns empty string if the URL does not
// match this provider.
NormalizePullRequestURL(raw string) string
// BuildBranchURL constructs a URL to view a branch on
// the provider's web UI.
BuildBranchURL(owner, repo, branch string) string
// BuildRepositoryURL constructs a URL to view a repository
// on the provider's web UI.
BuildRepositoryURL(owner, repo string) string
// BuildPullRequestURL constructs a URL to view a pull
// request on the provider's web UI.
BuildPullRequestURL(ref PRRef) string
}
// New creates a Provider for the given provider type and API base
// URL. Returns nil if the provider type is not a supported git
// provider.
func New(providerType string, apiBaseURL string, httpClient *http.Client, opts ...Option) Provider {
o := providerOptions{}
for _, opt := range opts {
opt(&o)
}
if o.clock == nil {
o.clock = quartz.NewReal()
}
switch providerType {
case "github":
return newGitHub(apiBaseURL, httpClient, o.clock)
default:
// Other providers (gitlab, bitbucket-cloud, etc.) will be
// added here as they are implemented.
return nil
}
}
// RateLimitError indicates the git provider's API rate limit was hit.
type RateLimitError struct {
RetryAfter time.Time
}
func (e *RateLimitError) Error() string {
return fmt.Sprintf("rate limited until %s", e.RetryAfter.Format(time.RFC3339))
}
+230
View File
@@ -0,0 +1,230 @@
package gitsync
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/quartz"
)
const (
// DiffStatusTTL is how long a successfully refreshed
// diff status remains fresh before becoming stale again.
DiffStatusTTL = 120 * time.Second
)
// ProviderResolver maps a git remote origin to the gitprovider
// that handles it. Returns nil if no provider matches.
type ProviderResolver func(origin string) gitprovider.Provider
var ErrNoTokenAvailable error = errors.New("no token available")
// TokenResolver obtains the user's git access token for a given
// remote origin. Should return nil if no token is available, in
// which case ErrNoTokenAvailable will be returned.
type TokenResolver func(
ctx context.Context,
userID uuid.UUID,
origin string,
) (*string, error)
// Refresher contains the stateless business logic for fetching
// fresh PR data from a git provider given a stale
// database.ChatDiffStatus row.
type Refresher struct {
providers ProviderResolver
tokens TokenResolver
logger slog.Logger
clock quartz.Clock
}
// NewRefresher creates a Refresher with the given dependency
// functions.
func NewRefresher(
providers ProviderResolver,
tokens TokenResolver,
logger slog.Logger,
clock quartz.Clock,
) *Refresher {
return &Refresher{
providers: providers,
tokens: tokens,
logger: logger,
clock: clock,
}
}
// RefreshRequest pairs a stale row with the chat owner who
// holds the git token needed for API calls.
type RefreshRequest struct {
Row database.ChatDiffStatus
OwnerID uuid.UUID
}
// RefreshResult is the outcome for a single row.
// - Params != nil, Error == nil → success, caller should upsert.
// - Params == nil, Error == nil → no PR yet, caller should skip.
// - Params == nil, Error != nil → row-level failure.
type RefreshResult struct {
Request RefreshRequest
Params *database.UpsertChatDiffStatusParams
Error error
}
// groupKey identifies a unique (owner, origin) pair so that
// provider and token resolution happen once per group.
type groupKey struct {
ownerID uuid.UUID
origin string
}
// Refresh fetches fresh PR data for a batch of stale rows.
// Rows are grouped internally by (ownerID, origin) so that
// provider and token resolution happen once per group. A
// top-level error is returned only when the entire batch
// fails catastrophically. Per-row outcomes are in the
// returned RefreshResult slice (one per input request, same
// order).
func (r *Refresher) Refresh(
ctx context.Context,
requests []RefreshRequest,
) ([]RefreshResult, error) {
results := make([]RefreshResult, len(requests))
for i, req := range requests {
results[i].Request = req
}
// Group request indices by (ownerID, origin).
groups := make(map[groupKey][]int)
for i, req := range requests {
key := groupKey{
ownerID: req.OwnerID,
origin: req.Row.GitRemoteOrigin,
}
groups[key] = append(groups[key], i)
}
for key, indices := range groups {
provider := r.providers(key.origin)
if provider == nil {
err := xerrors.Errorf("no provider for origin %q", key.origin)
for _, i := range indices {
results[i].Error = err
}
continue
}
token, err := r.tokens(ctx, key.ownerID, key.origin)
if err != nil {
err = xerrors.Errorf("resolve token: %w", err)
} else if token == nil || len(*token) == 0 {
err = ErrNoTokenAvailable
}
if err != nil {
for _, i := range indices {
results[i].Error = err
}
continue
}
// This is technically unnecessary but kept here as a future molly-guard.
if token == nil {
continue
}
for i, idx := range indices {
req := requests[idx]
params, err := r.refreshOne(ctx, provider, *token, req.Row)
results[idx] = RefreshResult{Request: req, Params: params, Error: err}
// If rate-limited, skip remaining rows in this group.
var rlErr *gitprovider.RateLimitError
if errors.As(err, &rlErr) {
for _, remaining := range indices[i+1:] {
results[remaining] = RefreshResult{
Request: requests[remaining],
Error: fmt.Errorf("skipped: %w", rlErr),
}
}
break
}
}
}
return results, nil
}
// refreshOne processes a single row using an already-resolved
// provider and token. This is the old Refresh logic, unchanged.
func (r *Refresher) refreshOne(
ctx context.Context,
provider gitprovider.Provider,
token string,
row database.ChatDiffStatus,
) (*database.UpsertChatDiffStatusParams, error) {
var ref gitprovider.PRRef
var prURL string
if row.Url.Valid && row.Url.String != "" {
// Row already has a PR URL — parse it directly.
parsed, ok := provider.ParsePullRequestURL(row.Url.String)
if !ok {
return nil, xerrors.Errorf("parse pull request URL %q", row.Url.String)
}
ref = parsed
prURL = row.Url.String
} else {
// No PR URL — resolve owner/repo from the remote origin,
// then look up the open PR for this branch.
owner, repo, _, ok := provider.ParseRepositoryOrigin(row.GitRemoteOrigin)
if !ok {
return nil, xerrors.Errorf("parse repository origin %q", row.GitRemoteOrigin)
}
resolved, err := provider.ResolveBranchPullRequest(ctx, token, gitprovider.BranchRef{
Owner: owner,
Repo: repo,
Branch: row.GitBranch,
})
if err != nil {
return nil, xerrors.Errorf("resolve branch pull request: %w", err)
}
if resolved == nil {
// No PR exists yet for this branch.
return nil, nil
}
ref = *resolved
prURL = provider.BuildPullRequestURL(ref)
}
status, err := provider.FetchPullRequestStatus(ctx, token, ref)
if err != nil {
return nil, xerrors.Errorf("fetch pull request status: %w", err)
}
now := r.clock.Now().UTC()
params := &database.UpsertChatDiffStatusParams{
ChatID: row.ChatID,
Url: sql.NullString{String: prURL, Valid: prURL != ""},
PullRequestState: sql.NullString{
String: string(status.State),
Valid: status.State != "",
},
ChangesRequested: status.ChangesRequested,
Additions: status.DiffStats.Additions,
Deletions: status.DiffStats.Deletions,
ChangedFiles: status.DiffStats.ChangedFiles,
RefreshedAt: now,
StaleAt: now.Add(DiffStatusTTL),
}
return params, nil
}
+775
View File
@@ -0,0 +1,775 @@
package gitsync_test
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/gitsync"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/quartz"
)
// mockProvider implements gitprovider.Provider with function fields
// so each test can wire only the methods it needs. Any method left
// nil panics with "unexpected call".
type mockProvider struct {
fetchPullRequestStatus func(ctx context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error)
resolveBranchPR func(ctx context.Context, token string, ref gitprovider.BranchRef) (*gitprovider.PRRef, error)
fetchPullRequestDiff func(ctx context.Context, token string, ref gitprovider.PRRef) (string, error)
fetchBranchDiff func(ctx context.Context, token string, ref gitprovider.BranchRef) (string, error)
parseRepositoryOrigin func(raw string) (string, string, string, bool)
parsePullRequestURL func(raw string) (gitprovider.PRRef, bool)
normalizePullRequestURL func(raw string) string
buildBranchURL func(owner, repo, branch string) string
buildRepositoryURL func(owner, repo string) string
buildPullRequestURL func(ref gitprovider.PRRef) string
}
func (m *mockProvider) FetchPullRequestStatus(ctx context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) {
if m.fetchPullRequestStatus == nil {
panic("unexpected call to FetchPullRequestStatus")
}
return m.fetchPullRequestStatus(ctx, token, ref)
}
func (m *mockProvider) ResolveBranchPullRequest(ctx context.Context, token string, ref gitprovider.BranchRef) (*gitprovider.PRRef, error) {
if m.resolveBranchPR == nil {
panic("unexpected call to ResolveBranchPullRequest")
}
return m.resolveBranchPR(ctx, token, ref)
}
func (m *mockProvider) FetchPullRequestDiff(ctx context.Context, token string, ref gitprovider.PRRef) (string, error) {
if m.fetchPullRequestDiff == nil {
panic("unexpected call to FetchPullRequestDiff")
}
return m.fetchPullRequestDiff(ctx, token, ref)
}
func (m *mockProvider) FetchBranchDiff(ctx context.Context, token string, ref gitprovider.BranchRef) (string, error) {
if m.fetchBranchDiff == nil {
panic("unexpected call to FetchBranchDiff")
}
return m.fetchBranchDiff(ctx, token, ref)
}
func (m *mockProvider) ParseRepositoryOrigin(raw string) (string, string, string, bool) {
if m.parseRepositoryOrigin == nil {
panic("unexpected call to ParseRepositoryOrigin")
}
return m.parseRepositoryOrigin(raw)
}
func (m *mockProvider) ParsePullRequestURL(raw string) (gitprovider.PRRef, bool) {
if m.parsePullRequestURL == nil {
panic("unexpected call to ParsePullRequestURL")
}
return m.parsePullRequestURL(raw)
}
func (m *mockProvider) NormalizePullRequestURL(raw string) string {
if m.normalizePullRequestURL == nil {
panic("unexpected call to NormalizePullRequestURL")
}
return m.normalizePullRequestURL(raw)
}
func (m *mockProvider) BuildBranchURL(owner, repo, branch string) string {
if m.buildBranchURL == nil {
panic("unexpected call to BuildBranchURL")
}
return m.buildBranchURL(owner, repo, branch)
}
func (m *mockProvider) BuildRepositoryURL(owner, repo string) string {
if m.buildRepositoryURL == nil {
panic("unexpected call to BuildRepositoryURL")
}
return m.buildRepositoryURL(owner, repo)
}
func (m *mockProvider) BuildPullRequestURL(ref gitprovider.PRRef) string {
if m.buildPullRequestURL == nil {
panic("unexpected call to BuildPullRequestURL")
}
return m.buildPullRequestURL(ref)
}
func TestRefresher_WithPRURL(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true
},
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return &gitprovider.PRStatus{
State: gitprovider.PRStateOpen,
DiffStats: gitprovider.DiffStats{
Additions: 10,
Deletions: 5,
ChangedFiles: 3,
},
}, nil
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
chatID := uuid.New()
row := database.ChatDiffStatus{
ChatID: chatID,
Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
require.NoError(t, res.Error)
require.NotNil(t, res.Params)
assert.Equal(t, chatID, res.Params.ChatID)
assert.Equal(t, "open", res.Params.PullRequestState.String)
assert.True(t, res.Params.PullRequestState.Valid)
assert.Equal(t, int32(10), res.Params.Additions)
assert.Equal(t, int32(5), res.Params.Deletions)
assert.Equal(t, int32(3), res.Params.ChangedFiles)
// StaleAt should be ~120s after RefreshedAt.
diff := res.Params.StaleAt.Sub(res.Params.RefreshedAt)
assert.InDelta(t, 120, diff.Seconds(), 5)
}
func TestRefresher_BranchResolvesToPR(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parseRepositoryOrigin: func(_ string) (string, string, string, bool) {
return "org", "repo", "https://github.com/org/repo", true
},
resolveBranchPR: func(_ context.Context, _ string, _ gitprovider.BranchRef) (*gitprovider.PRRef, error) {
return &gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 7}, nil
},
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil
},
buildPullRequestURL: func(_ gitprovider.PRRef) string {
return "https://github.com/org/repo/pull/7"
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
require.NoError(t, res.Error)
require.NotNil(t, res.Params)
assert.Contains(t, res.Params.Url.String, "pull/7")
assert.True(t, res.Params.Url.Valid)
assert.Equal(t, "open", res.Params.PullRequestState.String)
}
func TestRefresher_BranchNoPRYet(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parseRepositoryOrigin: func(_ string) (string, string, string, bool) {
return "org", "repo", "https://github.com/org/repo", true
},
resolveBranchPR: func(_ context.Context, _ string, _ gitprovider.BranchRef) (*gitprovider.PRRef, error) {
return nil, nil
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.NoError(t, res.Error)
assert.Nil(t, res.Params)
}
func TestRefresher_NoProviderForOrigin(t *testing.T) {
t.Parallel()
providers := func(_ string) gitprovider.Provider { return nil }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://example.com/pr/1", Valid: true},
GitRemoteOrigin: "https://example.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.Nil(t, res.Params)
require.Error(t, res.Error)
assert.Contains(t, res.Error.Error(), "no provider")
}
func TestRefresher_TokenResolutionFails(t *testing.T) {
t.Parallel()
var fetchCalled atomic.Bool
mp := &mockProvider{
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
fetchCalled.Store(true)
return nil, errors.New("should not be called")
},
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, true
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return nil, errors.New("token lookup failed")
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.Nil(t, res.Params)
require.Error(t, res.Error)
assert.False(t, fetchCalled.Load(), "FetchPullRequestStatus should not be called when token resolution fails")
}
func TestRefresher_EmptyToken(t *testing.T) {
t.Parallel()
mp := &mockProvider{}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref(""), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.Nil(t, res.Params)
require.ErrorIs(t, res.Error, gitsync.ErrNoTokenAvailable)
}
func TestRefresher_ProviderFetchFails(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true
},
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return nil, errors.New("api error")
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.Nil(t, res.Params)
require.Error(t, res.Error)
assert.Contains(t, res.Error.Error(), "api error")
}
func TestRefresher_PRURLParseFailure(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{}, false
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/not-a-pr", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.Nil(t, res.Params)
require.Error(t, res.Error)
}
func TestRefresher_BatchGroupsByOwnerAndOrigin(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, true
},
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil
},
}
providers := func(_ string) gitprovider.Provider { return mp }
var tokenCalls atomic.Int32
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
tokenCalls.Add(1)
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
ownerID := uuid.New()
originA := "https://github.com/org/repo"
originB := "https://gitlab.com/org/repo"
requests := []gitsync.RefreshRequest{
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: originA,
GitBranch: "feature-1",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: originA,
GitBranch: "feature-2",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://gitlab.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: originB,
GitBranch: "feature-3",
},
OwnerID: ownerID,
},
}
results, err := r.Refresh(context.Background(), requests)
require.NoError(t, err)
require.Len(t, results, 3)
for i, res := range results {
require.NoError(t, res.Error, "result[%d] should not have an error", i)
require.NotNil(t, res.Params, "result[%d] should have params", i)
}
// Two distinct (ownerID, origin) groups → exactly 2 token
// resolution calls.
assert.Equal(t, int32(2), tokenCalls.Load(),
"TokenResolver should be called once per (owner, origin) group")
}
func TestRefresher_UsesInjectedClock(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
fixedTime := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
mClock.Set(fixedTime)
mp := &mockProvider{
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true
},
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return &gitprovider.PRStatus{
State: gitprovider.PRStateOpen,
DiffStats: gitprovider.DiffStats{
Additions: 10,
Deletions: 5,
ChangedFiles: 3,
},
}, nil
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), mClock)
chatID := uuid.New()
row := database.ChatDiffStatus{
ChatID: chatID,
Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
require.NoError(t, res.Error)
require.NotNil(t, res.Params)
// The mock clock is deterministic, so times must be exact.
assert.Equal(t, fixedTime, res.Params.RefreshedAt)
assert.Equal(t, fixedTime.Add(gitsync.DiffStatusTTL), res.Params.StaleAt)
}
func TestRefresher_RateLimitSkipsRemainingInGroup(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
mp := &mockProvider{
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
var num int
switch {
case strings.HasSuffix(raw, "/pull/1"):
num = 1
case strings.HasSuffix(raw, "/pull/2"):
num = 2
case strings.HasSuffix(raw, "/pull/3"):
num = 3
default:
return gitprovider.PRRef{}, false
}
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: num}, true
},
fetchPullRequestStatus: func(_ context.Context, _ string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) {
call := callCount.Add(1)
switch call {
case 1:
// First call succeeds.
return &gitprovider.PRStatus{
State: gitprovider.PRStateOpen,
DiffStats: gitprovider.DiffStats{
Additions: 5,
Deletions: 2,
ChangedFiles: 1,
},
}, nil
case 2:
// Second call hits rate limit.
return nil, &gitprovider.RateLimitError{
RetryAfter: time.Now().Add(60 * time.Second),
}
default:
// Third call should never happen.
t.Fatal("FetchPullRequestStatus called more than 2 times")
return nil, nil
}
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
ownerID := uuid.New()
origin := "https://github.com/org/repo"
requests := []gitsync.RefreshRequest{
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: origin,
GitBranch: "feat-1",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/2", Valid: true},
GitRemoteOrigin: origin,
GitBranch: "feat-2",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/3", Valid: true},
GitRemoteOrigin: origin,
GitBranch: "feat-3",
},
OwnerID: ownerID,
},
}
results, err := r.Refresh(context.Background(), requests)
require.NoError(t, err)
require.Len(t, results, 3)
// Row 0: success.
assert.NoError(t, results[0].Error)
assert.NotNil(t, results[0].Params)
// Row 1: rate-limited.
require.Error(t, results[1].Error)
var rlErr1 *gitprovider.RateLimitError
assert.True(t, errors.As(results[1].Error, &rlErr1),
"result[1] error should be *RateLimitError")
// Row 2: skipped due to rate limit.
require.Error(t, results[2].Error)
var rlErr2 *gitprovider.RateLimitError
assert.True(t, errors.As(results[2].Error, &rlErr2),
"result[2] error should wrap *RateLimitError")
assert.Contains(t, results[2].Error.Error(), "skipped")
// Provider should have been called exactly twice.
assert.Equal(t, int32(2), callCount.Load(),
"FetchPullRequestStatus should be called exactly 2 times")
}
func TestRefresher_CorrectTokenPerOrigin(t *testing.T) {
t.Parallel()
var tokenCalls atomic.Int32
tokens := func(_ context.Context, _ uuid.UUID, origin string) (*string, error) {
tokenCalls.Add(1)
switch {
case strings.Contains(origin, "github.com"):
return ptr.Ref("gh-public-token"), nil
case strings.Contains(origin, "ghes.corp.com"):
return ptr.Ref("ghe-private-token"), nil
default:
return nil, fmt.Errorf("unexpected origin: %s", origin)
}
}
// Track which token each FetchPullRequestStatus call received,
// keyed by chat ID. We pass the chat ID through the PRRef.Number
// field (unique per request) so FetchPullRequestStatus can
// identify which row it's processing.
var mu sync.Mutex
tokensByPR := make(map[int]string)
mp := &mockProvider{
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
// Extract a unique PR number from the URL to identify
// each row inside FetchPullRequestStatus.
var num int
switch {
case strings.HasSuffix(raw, "/pull/1"):
num = 1
case strings.HasSuffix(raw, "/pull/2"):
num = 2
case strings.HasSuffix(raw, "/pull/10"):
num = 10
default:
return gitprovider.PRRef{}, false
}
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: num}, true
},
fetchPullRequestStatus: func(_ context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) {
mu.Lock()
tokensByPR[ref.Number] = token
mu.Unlock()
return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil
},
}
providers := func(_ string) gitprovider.Provider { return mp }
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
ownerID := uuid.New()
requests := []gitsync.RefreshRequest{
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature-1",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/2", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature-2",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://ghes.corp.com/org/repo/pull/10", Valid: true},
GitRemoteOrigin: "https://ghes.corp.com/org/repo",
GitBranch: "feature-3",
},
OwnerID: ownerID,
},
}
results, err := r.Refresh(context.Background(), requests)
require.NoError(t, err)
require.Len(t, results, 3)
for i, res := range results {
require.NoError(t, res.Error, "result[%d] should not have an error", i)
require.NotNil(t, res.Params, "result[%d] should have params", i)
}
// github.com rows (PR #1 and #2) should use the public token.
assert.Equal(t, "gh-public-token", tokensByPR[1],
"github.com PR #1 should use gh-public-token")
assert.Equal(t, "gh-public-token", tokensByPR[2],
"github.com PR #2 should use gh-public-token")
// ghes.corp.com row (PR #10) should use the GHE token.
assert.Equal(t, "ghe-private-token", tokensByPR[10],
"ghes.corp.com PR #10 should use ghe-private-token")
// Token resolution should be called exactly twice — once per
// (owner, origin) group.
assert.Equal(t, int32(2), tokenCalls.Load(),
"TokenResolver should be called once per (owner, origin) group")
}
+255
View File
@@ -0,0 +1,255 @@
package gitsync
import (
"context"
"database/sql"
"time"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/quartz"
)
const (
// defaultBatchSize is the maximum number of stale rows fetched
// per tick.
defaultBatchSize int32 = 50
// defaultInterval is the polling interval between ticks.
defaultInterval = 10 * time.Second
)
// Store is the narrow DB interface the Worker needs.
type Store interface {
AcquireStaleChatDiffStatuses(
ctx context.Context, limitVal int32,
) ([]database.AcquireStaleChatDiffStatusesRow, error)
BackoffChatDiffStatus(
ctx context.Context, arg database.BackoffChatDiffStatusParams,
) error
UpsertChatDiffStatus(
ctx context.Context, arg database.UpsertChatDiffStatusParams,
) (database.ChatDiffStatus, error)
UpsertChatDiffStatusReference(
ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams,
) (database.ChatDiffStatus, error)
GetChatsByOwnerID(
ctx context.Context, arg database.GetChatsByOwnerIDParams,
) ([]database.Chat, error)
}
// EventPublisher notifies the frontend of diff status changes.
type PublishDiffStatusChangeFunc func(ctx context.Context, chatID uuid.UUID) error
// Worker is a background loop that periodically refreshes stale
// chat diff statuses by delegating to a Refresher.
type Worker struct {
store Store
refresher *Refresher
publishDiffStatusChangeFn PublishDiffStatusChangeFunc
clock quartz.Clock
logger slog.Logger
batchSize int32
interval time.Duration
done chan struct{}
}
// NewWorker creates a Worker with default batch size and interval.
func NewWorker(
store Store,
refresher *Refresher,
publisher PublishDiffStatusChangeFunc,
clock quartz.Clock,
logger slog.Logger,
) *Worker {
return &Worker{
store: store,
refresher: refresher,
publishDiffStatusChangeFn: publisher,
clock: clock,
logger: logger,
batchSize: defaultBatchSize,
interval: defaultInterval,
done: make(chan struct{}),
}
}
// Start launches the background loop. It blocks until ctx is
// cancelled, then closes w.done.
func (w *Worker) Start(ctx context.Context) {
defer close(w.done)
ticker := w.clock.NewTicker(w.interval, "gitsync", "worker")
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
w.tick(ctx)
}
}
}
// Done returns a channel that is closed when the worker exits.
func (w *Worker) Done() <-chan struct{} {
return w.done
}
func chatDiffStatusFromRow(row database.AcquireStaleChatDiffStatusesRow) database.ChatDiffStatus {
return database.ChatDiffStatus{
ChatID: row.ChatID,
Url: row.Url,
PullRequestState: row.PullRequestState,
ChangesRequested: row.ChangesRequested,
Additions: row.Additions,
Deletions: row.Deletions,
ChangedFiles: row.ChangedFiles,
RefreshedAt: row.RefreshedAt,
StaleAt: row.StaleAt,
CreatedAt: row.CreatedAt,
UpdatedAt: row.UpdatedAt,
GitBranch: row.GitBranch,
GitRemoteOrigin: row.GitRemoteOrigin,
}
}
func (w *Worker) tick(ctx context.Context) {
// Set a context equal to w.interval so that we do not hold up processing due to
// random unicorn-related events.
ctx, cancel := context.WithTimeout(ctx, w.interval)
defer cancel()
acquiredRows, err := w.store.AcquireStaleChatDiffStatuses(ctx, w.batchSize)
if err != nil {
w.logger.Warn(ctx, "acquire stale chat diff statuses",
slog.Error(err))
return
}
if len(acquiredRows) == 0 {
return
}
// Build refresh requests directly from acquired rows.
requests := make([]RefreshRequest, 0, len(acquiredRows))
for _, row := range acquiredRows {
requests = append(requests, RefreshRequest{
Row: chatDiffStatusFromRow(row),
OwnerID: row.OwnerID,
})
}
results, err := w.refresher.Refresh(ctx, requests)
if err != nil {
w.logger.Warn(ctx, "batch refresh chat diff statuses",
slog.Error(err))
return
}
for _, res := range results {
if res.Error != nil {
w.logger.Debug(ctx, "refresh chat diff status",
slog.F("chat_id", res.Request.Row.ChatID),
slog.Error(res.Error))
// Back off so the row isn't retried immediately.
if err := w.store.BackoffChatDiffStatus(ctx,
database.BackoffChatDiffStatusParams{
ChatID: res.Request.Row.ChatID,
StaleAt: w.clock.Now().UTC().Add(DiffStatusTTL),
},
); err != nil {
w.logger.Warn(ctx, "backoff failed chat diff status",
slog.F("chat_id", res.Request.Row.ChatID),
slog.Error(err))
}
continue
}
if res.Params == nil {
// No PR yet — skip.
continue
}
if _, err := w.store.UpsertChatDiffStatus(ctx, *res.Params); err != nil {
w.logger.Warn(ctx, "upsert refreshed chat diff status",
slog.F("chat_id", res.Request.Row.ChatID),
slog.Error(err))
continue
}
if w.publishDiffStatusChangeFn != nil {
if err := w.publishDiffStatusChangeFn(ctx, res.Request.Row.ChatID); err != nil {
w.logger.Debug(ctx, "publish diff status change",
slog.F("chat_id", res.Request.Row.ChatID),
slog.Error(err))
}
}
}
}
// MarkStale persists the git ref on all chats for a workspace,
// setting stale_at to the past so the next tick picks them up.
// Publishes a diff status event for each affected chat.
// Called from workspaceagents handlers. No goroutines spawned.
func (w *Worker) MarkStale(
ctx context.Context,
workspaceID, ownerID uuid.UUID,
branch, origin string,
) {
if branch == "" || origin == "" {
return
}
chats, err := w.store.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
OwnerID: ownerID,
})
if err != nil {
w.logger.Warn(ctx, "list chats for git ref storage",
slog.F("workspace_id", workspaceID),
slog.Error(err))
return
}
for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) {
_, err := w.store.UpsertChatDiffStatusReference(ctx,
database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
GitBranch: branch,
GitRemoteOrigin: origin,
StaleAt: w.clock.Now().Add(-time.Second),
Url: sql.NullString{},
},
)
if err != nil {
w.logger.Warn(ctx, "store git ref on chat diff status",
slog.F("chat_id", chat.ID),
slog.F("workspace_id", workspaceID),
slog.Error(err))
continue
}
// Notify the frontend immediately so the UI shows the
// branch info even before the worker refreshes PR data.
if w.publishDiffStatusChangeFn != nil {
if pubErr := w.publishDiffStatusChangeFn(ctx, chat.ID); pubErr != nil {
w.logger.Debug(ctx, "publish diff status after mark stale",
slog.F("chat_id", chat.ID), slog.Error(pubErr))
}
}
}
}
// filterChatsByWorkspaceID returns only chats associated with
// the given workspace.
func filterChatsByWorkspaceID(
chats []database.Chat,
workspaceID uuid.UUID,
) []database.Chat {
filtered := make([]database.Chat, 0, len(chats))
for _, chat := range chats {
if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID {
continue
}
filtered = append(filtered, chat)
}
return filtered
}
+744
View File
@@ -0,0 +1,744 @@
package gitsync_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/gitsync"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// testRefresherCfg configures newTestRefresher.
type testRefresherCfg struct {
resolveBranchPR func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error)
fetchPRStatus func(context.Context, string, gitprovider.PRRef) (*gitprovider.PRStatus, error)
}
type testRefresherOpt func(*testRefresherCfg)
func withResolveBranchPR(f func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error)) testRefresherOpt {
return func(c *testRefresherCfg) { c.resolveBranchPR = f }
}
// newTestRefresher creates a Refresher backed by mock
// provider/token resolvers. The provider recognises any origin,
// resolves branches to a canned PR, and returns a canned PRStatus.
func newTestRefresher(t *testing.T, clk quartz.Clock, opts ...testRefresherOpt) *gitsync.Refresher {
t.Helper()
cfg := testRefresherCfg{
resolveBranchPR: func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
return &gitprovider.PRRef{Owner: "o", Repo: "r", Number: 1}, nil
},
fetchPRStatus: func(context.Context, string, gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return &gitprovider.PRStatus{
State: gitprovider.PRStateOpen,
DiffStats: gitprovider.DiffStats{
Additions: 10,
Deletions: 3,
ChangedFiles: 2,
},
}, nil
},
}
for _, o := range opts {
o(&cfg)
}
prov := &mockProvider{
parseRepositoryOrigin: func(string) (string, string, string, bool) {
return "owner", "repo", "https://github.com/owner/repo", true
},
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "owner", Repo: "repo", Number: 1}, raw != ""
},
resolveBranchPR: cfg.resolveBranchPR,
fetchPullRequestStatus: cfg.fetchPRStatus,
buildPullRequestURL: func(ref gitprovider.PRRef) string {
return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number)
},
}
providers := func(string) gitprovider.Provider { return prov }
tokens := func(context.Context, uuid.UUID, string) (*string, error) {
return ptr.Ref("tok"), nil
}
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
return gitsync.NewRefresher(providers, tokens, logger, clk)
}
// makeAcquiredRow returns an AcquireStaleChatDiffStatusesRow with
// a non-empty branch/origin so the Refresher goes through the
// branch-resolution path.
func makeAcquiredRow(chatID, ownerID uuid.UUID) database.AcquireStaleChatDiffStatusesRow {
return database.AcquireStaleChatDiffStatusesRow{
ChatID: chatID,
GitBranch: "feature",
GitRemoteOrigin: "https://github.com/owner/repo",
StaleAt: time.Now().Add(-time.Minute),
OwnerID: ownerID,
}
}
// tickOnce traps the worker's NewTicker call, starts the worker,
// fires one tick, waits for it to finish by observing the given
// tickDone channel, then shuts the worker down. The tickDone
// channel must be closed when the last expected operation in the
// tick completes. For tests where the tick does nothing (e.g. 0
// stale rows or store error), tickDone should be closed inside
// acquireStaleChatDiffStatuses.
func tickOnce(
ctx context.Context,
t *testing.T,
mClock *quartz.Mock,
worker *gitsync.Worker,
tickDone <-chan struct{},
) {
t.Helper()
trap := mClock.Trap().NewTicker("gitsync", "worker")
defer trap.Close()
workerCtx, cancel := context.WithCancel(ctx)
defer cancel()
go worker.Start(workerCtx)
// Wait for the worker to create its ticker.
trap.MustWait(ctx).MustRelease(ctx)
// Fire one tick. The waiter resolves when the channel receive
// completes, not when w.tick() returns, so we use tickDone to
// know when to proceed.
_, w := mClock.AdvanceNext()
w.MustWait(ctx)
// Wait for the tick's business logic to finish.
select {
case <-tickDone:
case <-ctx.Done():
t.Fatal("timed out waiting for tick to complete")
}
cancel()
<-worker.Done()
}
func TestWorker_SkipsFreshRows(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
tickDone := make(chan struct{})
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
// No stale rows — tick returns immediately.
close(tickDone)
return nil, nil
})
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
}
func TestWorker_LimitsToNRows(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
var capturedLimit atomic.Int32
var upsertCount atomic.Int32
ownerID := uuid.New()
const numRows = 5
tickDone := make(chan struct{})
rows := make([]database.AcquireStaleChatDiffStatusesRow, numRows)
for i := range rows {
rows[i] = makeAcquiredRow(uuid.New(), ownerID)
}
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
capturedLimit.Store(limitVal)
return rows, nil
})
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
upsertCount.Add(1)
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(numRows)
pub := func(_ context.Context, _ uuid.UUID) error {
if upsertCount.Load() == numRows {
close(tickDone)
}
return nil
}
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
// The default batch size is 50.
assert.Equal(t, int32(50), capturedLimit.Load())
assert.Equal(t, int32(numRows), upsertCount.Load())
}
func TestWorker_RefresherReturnsNilNil_SkipsUpsert(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
chatID := uuid.New()
ownerID := uuid.New()
// When the Refresher returns (nil, nil) the worker skips the
// upsert and publish. We signal tickDone from the refresher
// mock since that is the last operation before the tick
// returns.
tickDone := make(chan struct{})
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
Return([]database.AcquireStaleChatDiffStatusesRow{makeAcquiredRow(chatID, ownerID)}, nil)
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
// ResolveBranchPullRequest returns nil → Refresher returns
// (nil, nil).
refresher := newTestRefresher(t, mClock, withResolveBranchPR(
func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
close(tickDone)
return nil, nil
},
))
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
}
func TestWorker_RefresherError_BacksOffRow(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
chat1 := uuid.New()
chat2 := uuid.New()
ownerID := uuid.New()
var upsertCount atomic.Int32
var publishCount atomic.Int32
var backoffCount atomic.Int32
var mu sync.Mutex
var backoffArgs []database.BackoffChatDiffStatusParams
tickDone := make(chan struct{})
var closeOnce sync.Once
// Two rows processed: one fails (backoff), one succeeds
// (upsert+publish). Both must finish before we close tickDone.
var terminalOps atomic.Int32
signalIfDone := func() {
if terminalOps.Add(1) == 2 {
closeOnce.Do(func() { close(tickDone) })
}
}
mClock := quartz.NewMock(t)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
Return([]database.AcquireStaleChatDiffStatusesRow{
makeAcquiredRow(chat1, ownerID),
makeAcquiredRow(chat2, ownerID),
}, nil)
store.EXPECT().BackoffChatDiffStatus(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.BackoffChatDiffStatusParams) error {
backoffCount.Add(1)
mu.Lock()
backoffArgs = append(backoffArgs, arg)
mu.Unlock()
signalIfDone()
return nil
})
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
upsertCount.Add(1)
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
})
pub := func(_ context.Context, _ uuid.UUID) error {
// Only the successful row publishes.
publishCount.Add(1)
signalIfDone()
return nil
}
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
// Fail ResolveBranchPullRequest for the first call, succeed
// for the second.
var callCount atomic.Int32
refresher := newTestRefresher(t, mClock, withResolveBranchPR(
func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
n := callCount.Add(1)
if n == 1 {
return nil, fmt.Errorf("simulated provider error")
}
return &gitprovider.PRRef{Owner: "o", Repo: "r", Number: 1}, nil
},
))
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
// BackoffChatDiffStatus was called for the failed row.
assert.Equal(t, int32(1), backoffCount.Load())
mu.Lock()
require.Len(t, backoffArgs, 1)
assert.Equal(t, chat1, backoffArgs[0].ChatID)
// stale_at should be approximately clock.Now() + DiffStatusTTL (120s).
expectedStaleAt := mClock.Now().UTC().Add(gitsync.DiffStatusTTL)
assert.WithinDuration(t, expectedStaleAt, backoffArgs[0].StaleAt, time.Second)
mu.Unlock()
// UpsertChatDiffStatus was called for the successful row.
assert.Equal(t, int32(1), upsertCount.Load())
// PublishDiffStatusChange was called only for the successful row.
assert.Equal(t, int32(1), publishCount.Load())
}
func TestWorker_UpsertError_ContinuesNextRow(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
chat1 := uuid.New()
chat2 := uuid.New()
ownerID := uuid.New()
var publishCount atomic.Int32
tickDone := make(chan struct{})
var closeOnce sync.Once
var mu sync.Mutex
upsertedChatIDs := make(map[uuid.UUID]struct{})
// We have 2 rows. The upsert for chat1 fails; the upsert
// for chat2 succeeds and publishes. Because goroutines run
// concurrently we don't know which finishes last, so we
// track the total number of "terminal" events (upsert error
// + publish success) and close tickDone when both have
// occurred.
var terminalOps atomic.Int32
signalIfDone := func() {
if terminalOps.Add(1) == 2 {
closeOnce.Do(func() { close(tickDone) })
}
}
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
Return([]database.AcquireStaleChatDiffStatusesRow{
makeAcquiredRow(chat1, ownerID),
makeAcquiredRow(chat2, ownerID),
}, nil)
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
if arg.ChatID == chat1 {
// Terminal event for the failing row.
signalIfDone()
return database.ChatDiffStatus{}, fmt.Errorf("db write error")
}
mu.Lock()
upsertedChatIDs[arg.ChatID] = struct{}{}
mu.Unlock()
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
pub := func(_ context.Context, _ uuid.UUID) error {
publishCount.Add(1)
// Terminal event for the successful row.
signalIfDone()
return nil
}
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
mu.Lock()
_, gotChat2 := upsertedChatIDs[chat2]
mu.Unlock()
assert.True(t, gotChat2, "chat2 should have been upserted")
assert.Equal(t, int32(1), publishCount.Load())
}
func TestWorker_RespectsShutdown(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
Return(nil, nil).AnyTimes()
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
trap := mClock.Trap().NewTicker("gitsync", "worker")
defer trap.Close()
workerCtx, cancel := context.WithCancel(ctx)
go worker.Start(workerCtx)
// Wait for ticker creation so the worker is running.
trap.MustWait(ctx).MustRelease(ctx)
// Cancel immediately.
cancel()
select {
case <-worker.Done():
// Success — worker shut down.
case <-ctx.Done():
t.Fatal("timed out waiting for worker to shut down")
}
}
func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
workspaceID := uuid.New()
ownerID := uuid.New()
chat1 := uuid.New()
chat2 := uuid.New()
chatOther := uuid.New()
var mu sync.Mutex
var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams
var publishedIDs []uuid.UUID
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
require.Equal(t, ownerID, arg.OwnerID)
return []database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
{ID: chatOther, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
}, nil
})
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
mu.Lock()
upsertRefCalls = append(upsertRefCalls, arg)
mu.Unlock()
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
pub := func(_ context.Context, chatID uuid.UUID) error {
mu.Lock()
publishedIDs = append(publishedIDs, chatID)
mu.Unlock()
return nil
}
mClock := quartz.NewMock(t)
now := mClock.Now()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
worker.MarkStale(ctx, workspaceID, ownerID, "feature", "https://github.com/owner/repo")
mu.Lock()
defer mu.Unlock()
require.Len(t, upsertRefCalls, 2)
for _, call := range upsertRefCalls {
assert.Equal(t, "feature", call.GitBranch)
assert.Equal(t, "https://github.com/owner/repo", call.GitRemoteOrigin)
assert.True(t, call.StaleAt.Before(now),
"stale_at should be in the past, got %v vs now %v", call.StaleAt, now)
assert.Equal(t, sql.NullString{}, call.Url)
}
require.Len(t, publishedIDs, 2)
assert.ElementsMatch(t, []uuid.UUID{chat1, chat2}, publishedIDs)
}
func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
workspaceID := uuid.New()
ownerID := uuid.New()
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
Return([]database.Chat{
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
}, nil)
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
worker.MarkStale(ctx, workspaceID, ownerID, "main", "https://github.com/x/y")
}
func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
workspaceID := uuid.New()
ownerID := uuid.New()
chat1 := uuid.New()
chat2 := uuid.New()
var publishCount atomic.Int32
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
Return([]database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
}, nil)
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
if arg.ChatID == chat1 {
return database.ChatDiffStatus{}, fmt.Errorf("upsert ref error")
}
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
pub := func(_ context.Context, _ uuid.UUID) error {
publishCount.Add(1)
return nil
}
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
worker.MarkStale(ctx, workspaceID, ownerID, "dev", "https://github.com/a/b")
assert.Equal(t, int32(1), publishCount.Load())
}
func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
Return(nil, fmt.Errorf("db error"))
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
worker.MarkStale(ctx, uuid.New(), uuid.New(), "main", "https://github.com/x/y")
}
func TestWorker_TickStoreError(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
tickDone := make(chan struct{})
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
close(tickDone)
return nil, fmt.Errorf("database unavailable")
})
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
}
func TestWorker_MarkStale_EmptyBranchOrOrigin(t *testing.T) {
t.Parallel()
tests := []struct {
name string
branch string
origin string
}{
{"both empty", "", ""},
{"branch empty", "", "https://github.com/x/y"},
{"origin empty", "main", ""},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
worker.MarkStale(ctx, uuid.New(), uuid.New(), tc.branch, tc.origin)
})
}
}
// TestWorker exercises the worker tick against a
// real PostgreSQL database to verify that the SQL queries, foreign key
// constraints, and upsert logic work end-to-end.
func TestWorker(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
// 1. Real database store.
db, _ := dbtestutil.NewDB(t)
// 2. Create a user (FK for chats).
user := dbgen.User(t, db, database.User{})
// 3. Set up FK chain: chat_providers -> chat_model_configs -> chats.
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
Enabled: true,
})
require.NoError(t, err)
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model",
DisplayName: "Test Model",
Enabled: true,
ContextLimit: 100000,
CompressionThreshold: 70,
Options: json.RawMessage("{}"),
})
require.NoError(t, err)
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "integration-test",
})
require.NoError(t, err)
// 4. Seed a stale diff status row so the worker picks it up.
_, err = db.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
GitBranch: "feature",
GitRemoteOrigin: "https://github.com/o/r",
StaleAt: time.Now().Add(-time.Minute),
Url: sql.NullString{},
})
require.NoError(t, err)
// 5. Mock refresher returns a canned PR status.
mClock := quartz.NewMock(t)
refresher := newTestRefresher(t, mClock)
// 6. Track publish calls.
var publishCount atomic.Int32
tickDone := make(chan struct{})
pub := func(_ context.Context, chatID uuid.UUID) error {
assert.Equal(t, chat.ID, chatID)
if publishCount.Add(1) == 1 {
close(tickDone)
}
return nil
}
// 7. Create and run the worker for one tick.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
worker := gitsync.NewWorker(db, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
// 8. Assert publisher was called.
require.Equal(t, int32(1), publishCount.Load())
// 9. Read back and verify persisted fields.
status, err := db.GetChatDiffStatusByChatID(ctx, chat.ID)
require.NoError(t, err)
// The mock resolveBranchPR returns PRRef{Owner: "o", Repo: "r", Number: 1}
// and buildPullRequestURL formats it as https://github.com/o/r/pull/1.
assert.Equal(t, "https://github.com/o/r/pull/1", status.Url.String)
assert.True(t, status.Url.Valid)
assert.Equal(t, string(gitprovider.PRStateOpen), status.PullRequestState.String)
assert.True(t, status.PullRequestState.Valid)
assert.Equal(t, int32(10), status.Additions)
assert.Equal(t, int32(3), status.Deletions)
assert.Equal(t, int32(2), status.ChangedFiles)
assert.True(t, status.RefreshedAt.Valid, "refreshed_at should be set")
// The mock clock's Now() + DiffStatusTTL determines stale_at.
expectedStaleAt := mClock.Now().Add(gitsync.DiffStatusTTL)
assert.WithinDuration(t, expectedStaleAt, status.StaleAt, time.Second)
}
+10 -20
View File
@@ -1835,18 +1835,6 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
Branch: strings.TrimSpace(query.Get("git_branch")), Branch: strings.TrimSpace(query.Get("git_branch")),
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")), RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
} }
var chatID uuid.NullUUID
if rawChatID := query.Get("chat_id"); rawChatID != "" {
parsed, err := uuid.Parse(rawChatID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid chat_id.",
Detail: err.Error(),
})
return
}
chatID = uuid.NullUUID{UUID: parsed, Valid: true}
}
// Either match or configID must be provided! // Either match or configID must be provided!
match := query.Get("match") match := query.Get("match")
if match == "" { if match == "" {
@@ -1940,11 +1928,12 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
return return
} }
// Persist git refs as soon as the agent requests external auth so branch // MarkStale will trigger a refresh by coderd/gitsync. This allows us to
// persist git refs as soon as the agent requests external auth so branch
// context is retained even if the flow requires an out-of-band login. // context is retained even if the flow requires an out-of-band login.
if gitRef.Branch != "" || gitRef.RemoteOrigin != "" { if gitRef.Branch != "" && gitRef.RemoteOrigin != "" {
//nolint:gocritic // System context required to persist chat git refs. //nolint:gocritic // Chat processor context required for cross-user chat lookup
api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, chatID, gitRef) api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
} }
var previousToken *database.ExternalAuthLink var previousToken *database.ExternalAuthLink
@@ -1960,7 +1949,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
return return
} }
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, chatID, gitRef) api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, gitRef)
} }
// This is the URL that will redirect the user with a state token. // This is the URL that will redirect the user with a state token.
@@ -2018,11 +2007,10 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
}) })
return return
} }
api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef)
httpapi.Write(ctx, rw, http.StatusOK, resp) httpapi.Write(ctx, rw, http.StatusOK, resp)
} }
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) { func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, gitRef chatGitRef) {
// Since we're ticking frequently and this sign-in operation is rare, // Since we're ticking frequently and this sign-in operation is rare,
// we are OK with polling to avoid the complexity of pubsub. // we are OK with polling to avoid the complexity of pubsub.
ticker, done := api.NewTicker(time.Second) ticker, done := api.NewTicker(time.Second)
@@ -2092,7 +2080,9 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
}) })
return return
} }
api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef) // MarkStale will trigger a refresh by coderd/gitsync.
//nolint:gocritic // Chat processor context required for cross-user chat lookup
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
httpapi.Write(ctx, rw, http.StatusOK, resp) httpapi.Write(ctx, rw, http.StatusOK, resp)
return return
} }
+4
View File
@@ -980,6 +980,10 @@ type ExternalAuthConfig struct {
// 'Username for "https://github.com":' // 'Username for "https://github.com":'
// And sending it to the Coder server to match against the Regex. // And sending it to the Coder server to match against the Regex.
Regex string `json:"regex" yaml:"regex"` Regex string `json:"regex" yaml:"regex"`
// APIBaseURL is the base URL for provider REST API calls
// (e.g., "https://api.github.com" for GitHub). Derived from
// defaults when not explicitly configured.
APIBaseURL string `json:"api_base_url" yaml:"api_base_url"`
// DisplayName is shown in the UI to identify the auth config. // DisplayName is shown in the UI to identify the auth config.
DisplayName string `json:"display_name" yaml:"display_name"` DisplayName string `json:"display_name" yaml:"display_name"`
// DisplayIcon is a URL to an icon to display in the UI. // DisplayIcon is a URL to an icon to display in the UI.
+1
View File
@@ -22,6 +22,7 @@ externalAuthProviders:
mcp_tool_allow_regex: .* mcp_tool_allow_regex: .*
mcp_tool_deny_regex: create_gist mcp_tool_deny_regex: create_gist
regex: ^https://example.com/.*$ regex: ^https://example.com/.*$
api_base_url: ""
display_name: GitHub display_name: GitHub
display_icon: /static/icons/github.svg display_icon: /static/icons/github.svg
code_challenge_methods_supported: code_challenge_methods_supported:
+1
View File
@@ -279,6 +279,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
"external_auth": { "external_auth": {
"value": [ "value": [
{ {
"api_base_url": "string",
"app_install_url": "string", "app_install_url": "string",
"app_installations_url": "string", "app_installations_url": "string",
"auth_url": "string", "auth_url": "string",
+21 -16
View File
@@ -2786,6 +2786,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
"external_auth": { "external_auth": {
"value": [ "value": [
{ {
"api_base_url": "string",
"app_install_url": "string", "app_install_url": "string",
"app_installations_url": "string", "app_installations_url": "string",
"auth_url": "string", "auth_url": "string",
@@ -3357,6 +3358,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
"external_auth": { "external_auth": {
"value": [ "value": [
{ {
"api_base_url": "string",
"app_install_url": "string", "app_install_url": "string",
"app_installations_url": "string", "app_installations_url": "string",
"auth_url": "string", "auth_url": "string",
@@ -4104,6 +4106,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
```json ```json
{ {
"api_base_url": "string",
"app_install_url": "string", "app_install_url": "string",
"app_installations_url": "string", "app_installations_url": "string",
"auth_url": "string", "auth_url": "string",
@@ -4133,22 +4136,23 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
### Properties ### Properties
| Name | Type | Required | Restrictions | Description | | Name | Type | Required | Restrictions | Description |
|------------------------------------|-----------------|----------|--------------|-------------------------------------------------------------------------------------------------------------------| |------------------------------------|-----------------|----------|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `app_install_url` | string | false | | | | `api_base_url` | string | false | | Api base URL is the base URL for provider REST API calls (e.g., "https://api.github.com" for GitHub). Derived from defaults when not explicitly configured. |
| `app_installations_url` | string | false | | | | `app_install_url` | string | false | | |
| `auth_url` | string | false | | | | `app_installations_url` | string | false | | |
| `client_id` | string | false | | | | `auth_url` | string | false | | |
| `code_challenge_methods_supported` | array of string | false | | Code challenge methods supported lists the PKCE code challenge methods The only one supported by Coder is "S256". | | `client_id` | string | false | | |
| `device_code_url` | string | false | | | | `code_challenge_methods_supported` | array of string | false | | Code challenge methods supported lists the PKCE code challenge methods The only one supported by Coder is "S256". |
| `device_flow` | boolean | false | | | | `device_code_url` | string | false | | |
| `display_icon` | string | false | | Display icon is a URL to an icon to display in the UI. | | `device_flow` | boolean | false | | |
| `display_name` | string | false | | Display name is shown in the UI to identify the auth config. | | `display_icon` | string | false | | Display icon is a URL to an icon to display in the UI. |
| `id` | string | false | | ID is a unique identifier for the auth config. It defaults to `type` when not provided. | | `display_name` | string | false | | Display name is shown in the UI to identify the auth config. |
| `mcp_tool_allow_regex` | string | false | | | | `id` | string | false | | ID is a unique identifier for the auth config. It defaults to `type` when not provided. |
| `mcp_tool_deny_regex` | string | false | | | | `mcp_tool_allow_regex` | string | false | | |
| `mcp_url` | string | false | | | | `mcp_tool_deny_regex` | string | false | | |
| `no_refresh` | boolean | false | | | | `mcp_url` | string | false | | |
| `no_refresh` | boolean | false | | |
|`regex`|string|false||Regex allows API requesters to match an auth config by a string (e.g. coder.com) instead of by it's type. |`regex`|string|false||Regex allows API requesters to match an auth config by a string (e.g. coder.com) instead of by it's type.
Git clone makes use of this by parsing the URL from: 'Username for "https://github.com":' And sending it to the Coder server to match against the Regex.| Git clone makes use of this by parsing the URL from: 'Username for "https://github.com":' And sending it to the Coder server to match against the Regex.|
|`revoke_url`|string|false||| |`revoke_url`|string|false|||
@@ -14182,6 +14186,7 @@ None
{ {
"value": [ "value": [
{ {
"api_base_url": "string",
"app_install_url": "string", "app_install_url": "string",
"app_installations_url": "string", "app_installations_url": "string",
"auth_url": "string", "auth_url": "string",
+6
View File
@@ -2690,6 +2690,12 @@ export interface ExternalAuthConfig {
* And sending it to the Coder server to match against the Regex. * And sending it to the Coder server to match against the Regex.
*/ */
readonly regex: string; readonly regex: string;
/**
* APIBaseURL is the base URL for provider REST API calls
* (e.g., "https://api.github.com" for GitHub). Derived from
* defaults when not explicitly configured.
*/
readonly api_base_url: string;
/** /**
* DisplayName is shown in the UI to identify the auth config. * DisplayName is shown in the UI to identify the auth config.
*/ */
@@ -12,6 +12,7 @@ const meta: Meta<typeof ExternalAuthSettingsPageView> = {
type: "GitHub", type: "GitHub",
client_id: "client_id", client_id: "client_id",
regex: "regex", regex: "regex",
api_base_url: "",
auth_url: "", auth_url: "",
token_url: "", token_url: "",
validate_url: "", validate_url: "",