Files
coder/coderd/externalauth/gitprovider/gitlab.go
T
Cian Johnston 579daaff70 feat: add GitLab support to coderd/externalauth/gitprovider
Fixes CODAGT-146

Add GitLab support to the gitprovider package for gitsync/chatd PR
diff flows. This is a squashed stack of 3 PRs:

#25651 - refactor(coderd/externalauth): prepare gitprovider for multi-provider support
- Change gitprovider.New to return (Provider, error)
- Extract shared helpers (parseRetryAfter, checkRateLimitError,
  countDiffLines, escapePathPreserveSlashes) from github.go
- Update all callers (db2sdk, exp_chats, gitsync) for new signature
- Add error logging for provider construction failures
- Thread context through provider resolution

#25652 - feat(coderd/externalauth/gitprovider): add GitLab provider
- Implement full Provider interface: FetchPullRequestStatus,
  FetchPullRequestDiff, FetchBranchDiff, ResolveBranchPullRequest
- Handle nested groups, forks, and self-hosted instances
- Rate limit detection on both library and raw HTTP paths
- URL parsing/building with NormalizePullRequestURL support
- Unit tests covering error paths, URL parsing, state mapping
- Document GitLab configuration and known limitations

#25653 - test(coderd/externalauth/gitprovider): add GitLab VCR integration tests
- FetchPullRequestStatus: 4 fixtures (open, conflicts, merged, closed)
- FetchPullRequestDiff: 4 fixtures
- FetchBranchDiff: 3 fixtures (open, deleted, fork)
- ResolveBranchPullRequest: 3 fixtures
- go-vcr cassettes with sanitized GitLab API responses
2026-05-25 17:41:02 +01:00

682 lines
19 KiB
Go

package gitprovider
import (
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strconv"
"strings"
gitlab "gitlab.com/gitlab-org/api/client-go"
"golang.org/x/xerrors"
"github.com/coder/quartz"
)
type gitlabProvider struct {
webBaseURL string
client *gitlab.Client
clock quartz.Clock
}
func newGitLab(baseURL string, httpClient *http.Client, clock quartz.Clock) (*gitlabProvider, error) {
if baseURL == "" {
baseURL = "https://gitlab.com"
}
baseURL = strings.TrimRight(baseURL, "/")
baseURL = strings.TrimSuffix(baseURL, "/api/v4")
if httpClient == nil {
httpClient = http.DefaultClient
}
client, err := gitlab.NewClient("",
gitlab.WithBaseURL(baseURL),
gitlab.WithHTTPClient(httpClient),
gitlab.WithoutRetries(),
)
if err != nil {
return nil, xerrors.Errorf("create gitlab client: %w", err)
}
return &gitlabProvider{
webBaseURL: baseURL,
client: client,
clock: clock,
}, nil
}
var _ Provider = (*gitlabProvider)(nil)
// webHost returns the hostname (with port if present) of the GitLab web URL.
func (g *gitlabProvider) webHost() string {
u, err := url.Parse(g.webBaseURL)
if err != nil {
return "gitlab.com"
}
return u.Host
}
// reqOpts returns per-request options for authentication and context.
func reqOpts(ctx context.Context, token string) []gitlab.RequestOptionFunc {
opts := []gitlab.RequestOptionFunc{gitlab.WithContext(ctx)}
if token != "" {
opts = append(opts, gitlab.WithToken(gitlab.OAuthToken, token))
}
return opts
}
// gitLabPID returns the full project path (owner/repo) for use as a pid.
// The library handles URL encoding internally.
func gitLabPID(owner, repo string) string {
return owner + "/" + repo
}
func (g *gitlabProvider) FetchPullRequestStatus(
ctx context.Context,
token string,
ref PRRef,
) (*PRStatus, error) {
pid := gitLabPID(ref.Owner, ref.Repo)
opts := reqOpts(ctx, token)
// Fetch merge request details.
mr, _, err := g.client.MergeRequests.GetMergeRequest(pid, int64(ref.Number), nil, opts...)
if err != nil {
return nil, g.wrapError(err, "get merge request")
}
// Fetch approvals.
approvals, _, err := g.client.MergeRequests.GetMergeRequestApprovals(pid, int64(ref.Number), opts...)
if err != nil {
return nil, g.wrapError(err, "get merge request approvals")
}
// Fetch commits to get the commit count.
var totalCommits int32
commits, resp, err := g.client.MergeRequests.GetMergeRequestCommits(
pid, int64(ref.Number),
&gitlab.GetMergeRequestCommitsOptions{ListOptions: gitlab.ListOptions{PerPage: 100}},
opts...,
)
if err != nil {
return nil, g.wrapError(err, "get merge request commits")
}
if resp.TotalItems > 0 {
totalCommits = int32(resp.TotalItems)
} else {
totalCommits = int32(len(commits))
}
// Fetch MR diffs to compute additions/deletions.
// The commits endpoint does not return per-commit stats, so we
// count +/- lines from the unified diff returned by this endpoint.
var additions, deletions int32
diffs, _, err := g.client.MergeRequests.ListMergeRequestDiffs(
pid, int64(ref.Number),
// NOTE: fetches a single page of up to 100 diffs. MRs with more than
// 100 changed files will have correct ChangedFiles (from MR metadata)
// but undercounted Additions/Deletions. Pagination is omitted because
// the gitsync worker only uses ChangedFiles for its heuristics today.
&gitlab.ListMergeRequestDiffsOptions{ListOptions: gitlab.ListOptions{PerPage: 100}},
opts...,
)
if err != nil {
return nil, g.wrapError(err, "list merge request diffs")
}
for _, d := range diffs {
diffAdditions, diffDeletions := countDiffLines(d.Diff)
additions += diffAdditions
deletions += diffDeletions
}
// Map GitLab state to normalized state.
state := mapGitLabState(mr.State)
// Use diff_refs.head_sha if available, fall back to top-level sha.
headSHA := cmp.Or(mr.DiffRefs.HeadSha, mr.SHA)
// Parse changes_count (it's a string, possibly "1000+").
var changedFiles int32
if mr.ChangesCount != "" {
trimmed := strings.TrimSuffix(mr.ChangesCount, "+")
if n, err := strconv.Atoi(trimmed); err == nil {
changedFiles = int32(n)
}
}
// TODO(CODAGT-440): These fields have semantic gaps vs the GitHub
// provider. GitLab's "Approved" is threshold-based (not "at least one
// approval and no changes requested"), ChangesRequested has no GitLab
// equivalent, and ReviewerCount only counts approvers.
reviewerCount := int32(len(approvals.ApprovedBy))
var authorLogin, authorAvatarURL string
if mr.Author != nil {
authorLogin = mr.Author.Username
authorAvatarURL = mr.Author.AvatarURL
}
return &PRStatus{
Title: mr.Title,
State: state,
Draft: mr.Draft,
HeadSHA: headSHA,
HeadBranch: mr.SourceBranch,
DiffStats: DiffStats{
Additions: additions,
Deletions: deletions,
ChangedFiles: changedFiles,
},
ChangesRequested: false,
Approved: approvals.Approved,
ReviewerCount: reviewerCount,
AuthorLogin: authorLogin,
AuthorAvatarURL: authorAvatarURL,
BaseBranch: mr.TargetBranch,
PRNumber: int(mr.IID),
Commits: totalCommits,
FetchedAt: g.clock.Now().UTC(),
}, nil
}
func (g *gitlabProvider) ResolveBranchPullRequest(
ctx context.Context,
token string,
ref BranchRef,
) (*PRRef, error) {
if ref.Owner == "" || ref.Repo == "" || ref.Branch == "" {
return nil, nil
}
pid := gitLabPID(ref.Owner, ref.Repo)
opts := reqOpts(ctx, token)
mrs, _, err := g.client.MergeRequests.ListProjectMergeRequests(pid, &gitlab.ListProjectMergeRequestsOptions{
ListOptions: gitlab.ListOptions{PerPage: 1},
SourceBranch: gitlab.Ptr(ref.Branch),
State: gitlab.Ptr("opened"),
OrderBy: gitlab.Ptr("updated_at"),
Sort: gitlab.Ptr("desc"),
}, opts...)
if err != nil {
return nil, g.wrapError(err, "list merge requests by branch")
}
if len(mrs) == 0 {
return nil, nil
}
prRef, ok := g.ParsePullRequestURL(mrs[0].WebURL)
if !ok {
// Fallback: construct from known owner/repo and returned IID.
return &PRRef{
Owner: ref.Owner,
Repo: ref.Repo,
Number: int(mrs[0].IID),
}, nil
}
return &prRef, nil
}
func (g *gitlabProvider) FetchPullRequestDiff(
ctx context.Context,
token string,
ref PRRef,
) (string, error) {
pid := gitLabPID(ref.Owner, ref.Repo)
// Make a direct HTTP request instead of using the library's
// ShowMergeRequestRawDiffs, which reads the entire response
// into memory before returning. We use io.LimitReader to
// bound memory and reject diffs exceeding MaxDiffSize.
rawURL := fmt.Sprintf("%sprojects/%s/merge_requests/%d/raw_diffs",
g.client.BaseURL().String(), url.PathEscape(pid), ref.Number)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return "", g.wrapError(err, "create raw diffs request")
}
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
resp, err := g.client.HTTPClient().Do(req)
if err != nil {
return "", g.wrapError(err, "get merge request raw diffs")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if rlErr := checkRateLimitError(resp, g.clock, "RateLimit-Reset"); rlErr != nil {
return "", rlErr
}
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return "", g.wrapError(
xerrors.Errorf("unexpected status %d", resp.StatusCode),
"get merge request raw diffs",
)
}
return "", g.wrapError(
xerrors.Errorf("unexpected status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))),
"get merge request raw diffs",
)
}
buf, err := io.ReadAll(io.LimitReader(resp.Body, MaxDiffSize+1))
if err != nil {
return "", g.wrapError(err, "read merge request raw diffs")
}
if len(buf) > MaxDiffSize {
return "", ErrDiffTooLarge
}
return string(buf), nil
}
// compareResponse is the subset of GitLab's compare endpoint response
// that we need. We decode manually (instead of using the library) so
// we can bound memory with io.LimitReader before JSON parsing.
type compareResponse struct {
Diffs []struct {
Diff string `json:"diff"`
OldPath string `json:"old_path"`
NewPath string `json:"new_path"`
NewFile bool `json:"new_file"`
DeletedFile bool `json:"deleted_file"`
RenamedFile bool `json:"renamed_file"`
Collapsed bool `json:"collapsed"`
TooLarge bool `json:"too_large"`
} `json:"diffs"`
CompareTimeout bool `json:"compare_timeout"`
}
func (g *gitlabProvider) FetchBranchDiff(
ctx context.Context,
token string,
ref BranchRef,
) (string, error) {
if ref.Owner == "" || ref.Repo == "" || ref.Branch == "" {
return "", nil
}
pid := gitLabPID(ref.Owner, ref.Repo)
opts := reqOpts(ctx, token)
// Get the default branch from the project.
project, _, err := g.client.Projects.GetProject(pid, nil, opts...)
if err != nil {
return "", g.wrapError(err, "get project")
}
defaultBranch := strings.TrimSpace(project.DefaultBranch)
if defaultBranch == "" {
return "", xerrors.New("gitlab project default branch is empty")
}
// Use raw HTTP with io.LimitReader to bound memory. The library's
// Compare() decodes the full response before returning, which
// would allow a maliciously large diff to OOM the process.
compareURL := fmt.Sprintf("%sprojects/%s/repository/compare?from=%s&to=%s&unidiff=true",
g.client.BaseURL().String(),
url.PathEscape(pid),
url.QueryEscape(defaultBranch),
url.QueryEscape(ref.Branch),
)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, compareURL, nil)
if err != nil {
return "", g.wrapError(err, "create compare request")
}
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
resp, err := g.client.HTTPClient().Do(req)
if err != nil {
return "", g.wrapError(err, "compare branches")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if rlErr := checkRateLimitError(resp, g.clock, "RateLimit-Reset"); rlErr != nil {
return "", rlErr
}
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return "", g.wrapError(
xerrors.Errorf("unexpected status %d", resp.StatusCode),
"compare branches",
)
}
return "", g.wrapError(
xerrors.Errorf("unexpected status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))),
"compare branches",
)
}
// Bound the read to MaxDiffSize + overhead for JSON structure.
// The JSON envelope (commits, metadata) adds some overhead beyond
// the raw diff content, so we allow ~10% extra for framing.
maxRead := int64(MaxDiffSize) + int64(MaxDiffSize/10) + 4096
body, err := io.ReadAll(io.LimitReader(resp.Body, maxRead+1))
if err != nil {
return "", g.wrapError(err, "read compare response")
}
if int64(len(body)) > maxRead {
return "", ErrDiffTooLarge
}
var compare compareResponse
if err := json.Unmarshal(body, &compare); err != nil {
return "", g.wrapError(err, "decode compare response")
}
if compare.CompareTimeout {
return "", xerrors.New("gitlab compare timed out; diff may be incomplete")
}
// Reconstruct unified diff from individual file diffs.
var sb strings.Builder
var estimated int
for _, d := range compare.Diffs {
estimated += len(d.Diff) + len(d.OldPath) + len(d.NewPath) + 20
}
if estimated > MaxDiffSize {
return "", ErrDiffTooLarge
}
sb.Grow(estimated)
for _, d := range compare.Diffs {
if d.Collapsed || d.TooLarge {
slog.WarnContext(ctx, "gitlab compare: file diff truncated",
slog.String("path", d.NewPath),
slog.Bool("collapsed", d.Collapsed),
slog.Bool("too_large", d.TooLarge),
)
}
fmt.Fprintf(&sb, "diff --git a/%s b/%s\n", d.OldPath, d.NewPath)
// Add standard unified diff file headers.
switch {
case d.NewFile:
sb.WriteString("--- /dev/null\n")
fmt.Fprintf(&sb, "+++ b/%s\n", d.NewPath)
case d.DeletedFile:
fmt.Fprintf(&sb, "--- a/%s\n", d.OldPath)
sb.WriteString("+++ /dev/null\n")
default:
fmt.Fprintf(&sb, "--- a/%s\n", d.OldPath)
fmt.Fprintf(&sb, "+++ b/%s\n", d.NewPath)
}
sb.WriteString(d.Diff)
// Ensure each file diff ends with a newline.
if len(d.Diff) > 0 && d.Diff[len(d.Diff)-1] != '\n' {
sb.WriteByte('\n')
}
}
result := sb.String()
if len(result) > MaxDiffSize {
return "", ErrDiffTooLarge
}
return result, nil
}
// ParseRepositoryOrigin preserves slashes in owner because GitLab supports
// subgroup paths such as group/subgroup/repo.
//
// TODO: this does not handle GitLab instances installed under a relative URL
// prefix (e.g. https://example.com/gitlab/). See
// https://docs.gitlab.com/install/relative_url/ for details.
func (g *gitlabProvider) ParseRepositoryOrigin(raw string) (owner, repo, normalizedOrigin string, ok bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", "", "", false
}
host := g.webHost()
// Try SSH format: git@HOST:path.git or ssh://git@HOST/path.git
if path, matched := g.parseSSHOrigin(raw, host); matched {
owner, repo = splitOwnerRepo(path)
if owner == "" || repo == "" {
return "", "", "", false
}
normalized := fmt.Sprintf("%s/%s/%s", g.webBaseURL, owner, repo)
return owner, repo, normalized, true
}
// Try HTTPS format.
u, err := url.Parse(raw)
if err != nil {
return "", "", "", false
}
if !strings.EqualFold(u.Host, host) {
return "", "", "", false
}
if u.Scheme != "https" && u.Scheme != "http" {
return "", "", "", false
}
path := strings.TrimPrefix(u.Path, "/")
path = strings.TrimSuffix(path, "/")
path = strings.TrimSuffix(path, ".git")
if path == "" {
return "", "", "", false
}
owner, repo = splitOwnerRepo(path)
if owner == "" || repo == "" {
return "", "", "", false
}
normalized := fmt.Sprintf("%s/%s/%s", g.webBaseURL, owner, repo)
return owner, repo, normalized, true
}
func (g *gitlabProvider) ParsePullRequestURL(raw string) (PRRef, bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return PRRef{}, false
}
u, err := url.Parse(raw)
if err != nil {
return PRRef{}, false
}
host := g.webHost()
if !strings.EqualFold(u.Host, host) {
return PRRef{}, false
}
// GitLab MR URLs: /owner/repo/-/merge_requests/123
// or /group/subgroup/repo/-/merge_requests/123
path := strings.TrimPrefix(u.Path, "/")
path = strings.TrimSuffix(path, "/")
// Find "-/merge_requests/NUMBER" in the path.
const mrMarker = "-/merge_requests/"
idx := strings.Index(path, mrMarker)
if idx < 0 {
return PRRef{}, false
}
// Everything before the marker (minus trailing slash) is the project path.
projPath := path[:idx]
projPath = strings.TrimSuffix(projPath, "/")
if projPath == "" {
return PRRef{}, false
}
// The number comes after the marker.
afterMR := path[idx+len(mrMarker):]
// Strip any trailing path segments.
if slashIdx := strings.Index(afterMR, "/"); slashIdx >= 0 {
afterMR = afterMR[:slashIdx]
}
number, err := strconv.Atoi(afterMR)
if err != nil || number <= 0 {
return PRRef{}, false
}
owner, repo := splitOwnerRepo(projPath)
if owner == "" || repo == "" {
return PRRef{}, false
}
return PRRef{
Owner: owner,
Repo: repo,
Number: number,
}, true
}
// NormalizePullRequestURL normalizes a GitLab merge request URL.
func (g *gitlabProvider) NormalizePullRequestURL(raw string) string {
ref, ok := g.ParsePullRequestURL(strings.TrimRight(
strings.TrimSpace(raw),
trailingPunctuation,
))
if !ok {
return ""
}
return g.BuildPullRequestURL(ref)
}
// BuildBranchURL keeps owner and repo unescaped because GitLab owners can
// include subgroup paths with slashes.
func (g *gitlabProvider) BuildBranchURL(owner, repo, branch string) string {
owner = strings.TrimSpace(owner)
repo = strings.TrimSpace(repo)
branch = strings.TrimSpace(branch)
if owner == "" || repo == "" || branch == "" {
return ""
}
return fmt.Sprintf(
"%s/%s/%s/-/tree/%s",
g.webBaseURL,
owner,
repo,
escapePathPreserveSlashes(branch),
)
}
// BuildRepositoryURL keeps owner and repo unescaped because GitLab owners can
// include subgroup paths with slashes.
func (g *gitlabProvider) BuildRepositoryURL(owner, repo string) string {
owner = strings.TrimSpace(owner)
repo = strings.TrimSpace(repo)
if owner == "" || repo == "" {
return ""
}
return fmt.Sprintf("%s/%s/%s", g.webBaseURL, owner, repo)
}
func (g *gitlabProvider) BuildPullRequestURL(ref PRRef) string {
if ref.Owner == "" || ref.Repo == "" || ref.Number <= 0 {
return ""
}
return fmt.Sprintf("%s/%s/%s/-/merge_requests/%d", g.webBaseURL, ref.Owner, ref.Repo, ref.Number)
}
// wrapError converts library errors to our domain errors (e.g. rate limits).
func (g *gitlabProvider) wrapError(err error, action string) error {
if errResp, ok := errors.AsType[*gitlab.ErrorResponse](err); ok {
if rlErr := checkRateLimitError(errResp.Response, g.clock, "RateLimit-Reset"); rlErr != nil {
return rlErr
}
}
return xerrors.Errorf("gitlab %s: %w", action, err)
}
// mapGitLabState maps a GitLab merge request state string to a normalized PRState.
func mapGitLabState(state string) PRState {
switch strings.ToLower(strings.TrimSpace(state)) {
case "opened":
return PRStateOpen
case "merged":
return PRStateMerged
case "closed", "locked":
return PRStateClosed
default:
return PRStateClosed
}
}
// splitOwnerRepo splits a path like "group/subgroup/repo" into
// owner="group/subgroup" and repo="repo". The last segment is always
// the repo name, and everything before it is the owner.
func splitOwnerRepo(path string) (owner, repo string) {
path = strings.TrimPrefix(path, "/")
path = strings.TrimSuffix(path, "/")
if path == "" {
return "", ""
}
lastSlash := strings.LastIndex(path, "/")
if lastSlash < 0 {
// No slash means no owner/repo split possible.
return "", ""
}
owner = path[:lastSlash]
repo = path[lastSlash+1:]
if owner == "" || repo == "" {
return "", ""
}
return owner, repo
}
// parseSSHOrigin attempts to parse an SSH git remote URL for the given host.
// Returns the path (without .git suffix) and true if it matched.
func (g *gitlabProvider) parseSSHOrigin(raw string, host string) (string, bool) {
// Handle ssh://git@HOST/path.git format.
if strings.HasPrefix(raw, "ssh://") {
u, err := url.Parse(raw)
if err != nil {
return "", false
}
// The host in SSH URLs may include a port, so compare case-insensitively.
if !strings.EqualFold(u.Host, host) && !strings.EqualFold(u.Hostname(), hostWithoutPort(host)) {
return "", false
}
path := strings.TrimPrefix(u.Path, "/")
path = strings.TrimSuffix(path, ".git")
path = strings.TrimSuffix(path, "/")
if path == "" {
return "", false
}
return path, true
}
// Handle git@HOST:path.git format (SCP-like syntax).
prefix := "git@" + host + ":"
// Also try matching without port for host comparison.
prefixNoPort := "git@" + hostWithoutPort(host) + ":"
path, ok := strings.CutPrefix(raw, prefix)
if !ok {
path, ok = strings.CutPrefix(raw, prefixNoPort)
}
if !ok {
return "", false
}
path = strings.TrimSuffix(path, ".git")
path = strings.TrimSuffix(path, "/")
if path == "" {
return "", false
}
return path, true
}
// hostWithoutPort strips the port from a host:port string.
func hostWithoutPort(host string) string {
if idx := strings.LastIndex(host, ":"); idx >= 0 {
return host[:idx]
}
return host
}